- # Source - LangChain
- # URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851
- """**Text Splitters** are classes for splitting text.
- **Class hierarchy:**
- .. code-block::
- BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter
- RecursiveCharacterTextSplitter --> <name>TextSplitter
- Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.
- **Main helpers:**
- .. code-block::
- Document, Tokenizer, Language, LineType, HeaderType
- """ # noqa: E501
- from __future__ import annotations
- import copy
- import json
- import logging
- import pathlib
- import re
- from abc import ABC, abstractmethod
- from dataclasses import dataclass
- from enum import Enum
- from io import BytesIO, StringIO
- from typing import (
- AbstractSet,
- Any,
- Callable,
- Collection,
- Iterable,
- Literal,
- Optional,
- Sequence,
- Tuple,
- Type,
- TypedDict,
- TypeVar,
- cast,
- )
- import requests
- from pydantic import BaseModel, Field, PrivateAttr
- from typing_extensions import NotRequired
- logger = logging.getLogger(__name__)
- TS = TypeVar("TS", bound="TextSplitter")
- class BaseSerialized(TypedDict):
- """Base class for serialized objects."""
- lc: int
- id: list[str]
- name: NotRequired[str]
- graph: NotRequired[dict[str, Any]]
- class SerializedConstructor(BaseSerialized):
- """Serialized constructor."""
- type: Literal["constructor"]
- kwargs: dict[str, Any]
- class SerializedSecret(BaseSerialized):
- """Serialized secret."""
- type: Literal["secret"]
- class SerializedNotImplemented(BaseSerialized):
- """Serialized not implemented."""
- type: Literal["not_implemented"]
- repr: Optional[str]
- def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
- """Try to determine if a value is different from the default.
- Args:
- value: The value.
- key: The key.
- model: The model.
- Returns:
- Whether the value is different from the default.
- """
- try:
- return model.__fields__[key].get_default() != value
- except Exception:
- return True
- class Serializable(BaseModel, ABC):
- """Serializable base class."""
- @classmethod
- def is_lc_serializable(cls) -> bool:
- """Is this class serializable?"""
- return False
- @classmethod
- def get_lc_namespace(cls) -> list[str]:
- """Get the namespace of the langchain object.
- For example, if the class is `langchain.llms.openai.OpenAI`, then the
- namespace is ["langchain", "llms", "openai"]
- """
- return cls.__module__.split(".")
- @property
- def lc_secrets(self) -> dict[str, str]:
- """A map of constructor argument names to secret ids.
- For example,
- {"openai_api_key": "OPENAI_API_KEY"}
- """
- return {}
- @property
- def lc_attributes(self) -> dict:
- """List of attribute names that should be included in the serialized kwargs.
- These attributes must be accepted by the constructor.
- """
- return {}
- @classmethod
- def lc_id(cls) -> list[str]:
- """A unique identifier for this class for serialization purposes.
- The unique identifier is a list of strings that describes the path
- to the object.
- """
- return [*cls.get_lc_namespace(), cls.__name__]
- class Config:
- extra = "ignore"
- def __repr_args__(self) -> Any:
- return [
- (k, v)
- for k, v in super().__repr_args__()
- if (k not in self.__fields__ or try_neq_default(v, k, self))
- ]
- _lc_kwargs: dict[str, Any] = PrivateAttr(default_factory=dict)
- def __init__(self, **kwargs: Any) -> None:
- super().__init__(**kwargs)
- self._lc_kwargs = kwargs
- def to_json(
- self,
- ) -> SerializedConstructor | SerializedNotImplemented:
- if not self.is_lc_serializable():
- return self.to_json_not_implemented()
- secrets = dict()
- # Get latest values for kwargs if there is an attribute with same name
- lc_kwargs = {
- k: getattr(self, k, v)
- for k, v in self._lc_kwargs.items()
- if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
- }
- # Merge the lc_secrets and lc_attributes from every class in the MRO
- for cls in [None, *self.__class__.mro()]:
- # Once we get to Serializable, we're done
- if cls is Serializable:
- break
- if cls:
- deprecated_attributes = [
- "lc_namespace",
- "lc_serializable",
- ]
- for attr in deprecated_attributes:
- if hasattr(cls, attr):
- raise ValueError(
- f"Class {self.__class__} has a deprecated "
- f"attribute {attr}. Please use the corresponding "
- f"classmethod instead."
- )
- # Get a reference to self bound to each class in the MRO
- this = cast(
- Serializable, self if cls is None else super(cls, self)
- )
- secrets.update(this.lc_secrets)
- # Now also add the aliases for the secrets
- # This ensures known secret aliases are hidden.
- # Note: this does NOT hide any other extra kwargs
- # that are not present in the fields.
- for key in list(secrets):
- value = secrets[key]
- if key in this.__fields__:
- secrets[this.__fields__[key].alias] = value # type: ignore
- lc_kwargs.update(this.lc_attributes)
- # include all secrets, even if not specified in kwargs
- # as these secrets may be passed as an environment variable instead
- for key in secrets.keys():
- secret_value = getattr(self, key, None) or lc_kwargs.get(key)
- if secret_value is not None:
- lc_kwargs.update({key: secret_value})
- return {
- "lc": 1,
- "type": "constructor",
- "id": self.lc_id(),
- "kwargs": (
- lc_kwargs
- if not secrets
- else _replace_secrets(lc_kwargs, secrets)
- ),
- }
- def to_json_not_implemented(self) -> SerializedNotImplemented:
- return to_json_not_implemented(self)
- def _replace_secrets(
- root: dict[Any, Any], secrets_map: dict[str, str]
- ) -> dict[Any, Any]:
- result = root.copy()
- for path, secret_id in secrets_map.items():
- [*parts, last] = path.split(".")
- current = result
- for part in parts:
- if part not in current:
- break
- current[part] = current[part].copy()
- current = current[part]
- if last in current:
- current[last] = {
- "lc": 1,
- "type": "secret",
- "id": [secret_id],
- }
- return result
- def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
- """Serialize a "not implemented" object.
- Args:
- obj: object to serialize
- Returns:
- SerializedNotImplemented
- """
- _id: list[str] = []
- try:
- if hasattr(obj, "__name__"):
- _id = [*obj.__module__.split("."), obj.__name__]
- elif hasattr(obj, "__class__"):
- _id = [
- *obj.__class__.__module__.split("."),
- obj.__class__.__name__,
- ]
- except Exception:
- pass
- result: SerializedNotImplemented = {
- "lc": 1,
- "type": "not_implemented",
- "id": _id,
- "repr": None,
- }
- try:
- result["repr"] = repr(obj)
- except Exception:
- pass
- return result
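The serialization helpers above are easiest to follow with a concrete value. A minimal sketch (the dictionary and secret id are made up for illustration) of how `_replace_secrets` masks a dotted path:

```python
# Hypothetical input: a kwargs dict plus a secrets map keyed by dotted paths.
masked = _replace_secrets(
    {"client": {"api_key": "sk-123"}, "model": "gpt-4"},
    {"client.api_key": "OPENAI_API_KEY"},
)
# The nested value is replaced by a serialized secret marker:
# masked["client"]["api_key"] == {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]}
# while unrelated keys such as "model" are left untouched.
```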
- class SplitterDocument(Serializable):
- """Class for storing a piece of text and associated metadata."""
- page_content: str
- """String text."""
- metadata: dict = Field(default_factory=dict)
- """Arbitrary metadata about the page content (e.g., source, relationships to other
- documents, etc.).
- """
- type: Literal["Document"] = "Document"
- def __init__(self, page_content: str, **kwargs: Any) -> None:
- """Pass page_content in as positional or named arg."""
- super().__init__(page_content=page_content, **kwargs)
- @classmethod
- def is_lc_serializable(cls) -> bool:
- """Return whether this class is serializable."""
- return True
- @classmethod
- def get_lc_namespace(cls) -> list[str]:
- """Get the namespace of the langchain object."""
- return ["langchain", "schema", "document"]
- class BaseDocumentTransformer(ABC):
- """Abstract base class for document transformation systems.
- A document transformation system takes a sequence of Documents and returns a
- sequence of transformed Documents.
- Example:
- .. code-block:: python
- class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
- embeddings: Embeddings
- similarity_fn: Callable = cosine_similarity
- similarity_threshold: float = 0.95
- class Config:
- arbitrary_types_allowed = True
- def transform_documents(
- self, documents: Sequence[Document], **kwargs: Any
- ) -> Sequence[Document]:
- stateful_documents = get_stateful_documents(documents)
- embedded_documents = _get_embeddings_from_stateful_docs(
- self.embeddings, stateful_documents
- )
- included_idxs = _filter_similar_embeddings(
- embedded_documents, self.similarity_fn, self.similarity_threshold
- )
- return [stateful_documents[i] for i in sorted(included_idxs)]
- async def atransform_documents(
- self, documents: Sequence[Document], **kwargs: Any
- ) -> Sequence[Document]:
- raise NotImplementedError
- """ # noqa: E501
- @abstractmethod
- def transform_documents(
- self, documents: Sequence[SplitterDocument], **kwargs: Any
- ) -> Sequence[SplitterDocument]:
- """Transform a list of documents.
- Args:
- documents: A sequence of Documents to be transformed.
- Returns:
- A list of transformed Documents.
- """
- async def atransform_documents(
- self, documents: Sequence[SplitterDocument], **kwargs: Any
- ) -> Sequence[SplitterDocument]:
- """Asynchronously transform a list of documents.
- Args:
- documents: A sequence of Documents to be transformed.
- Returns:
- A list of transformed Documents.
- """
- raise NotImplementedError("This method is not implemented.")
- # return await langchain_core.runnables.config.run_in_executor(
- # None, self.transform_documents, documents, **kwargs
- # )
- def _make_spacy_pipe_for_splitting(
- pipe: str, *, max_length: int = 1_000_000
- ) -> Any: # avoid importing spacy
- try:
- import spacy
- except ImportError:
- raise ImportError(
- "Spacy is not installed, please install it with `pip install spacy`."
- )
- if pipe == "sentencizer":
- from spacy.lang.en import English
- sentencizer = English()
- sentencizer.add_pipe("sentencizer")
- else:
- sentencizer = spacy.load(pipe, exclude=["ner", "tagger"])
- sentencizer.max_length = max_length
- return sentencizer
- def _split_text_with_regex(
- text: str, separator: str, keep_separator: bool
- ) -> list[str]:
- # Now that we have the separator, split the text
- if separator:
- if keep_separator:
- # The parentheses in the pattern keep the delimiters in the result.
- _splits = re.split(f"({separator})", text)
- splits = [
- _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
- ]
- if len(_splits) % 2 == 0:
- splits += _splits[-1:]
- splits = [_splits[0]] + splits
- else:
- splits = re.split(separator, text)
- else:
- splits = list(text)
- return [s for s in splits if s != ""]
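A small sketch of how `keep_separator` changes the output of `_split_text_with_regex`; the sample string and pattern are illustrative:

```python
# With keep_separator=True the delimiter is glued onto the following piece.
print(_split_text_with_regex("a.b.c", r"\.", keep_separator=True))
# -> ['a', '.b', '.c']

# With keep_separator=False the delimiter is dropped entirely.
print(_split_text_with_regex("a.b.c", r"\.", keep_separator=False))
# -> ['a', 'b', 'c']
```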
- class TextSplitter(BaseDocumentTransformer, ABC):
- """Interface for splitting text into chunks."""
- def __init__(
- self,
- chunk_size: int = 4000,
- chunk_overlap: int = 200,
- length_function: Callable[[str], int] = len,
- keep_separator: bool = False,
- add_start_index: bool = False,
- strip_whitespace: bool = True,
- ) -> None:
- """Create a new TextSplitter.
- Args:
- chunk_size: Maximum size of chunks to return
- chunk_overlap: Overlap in characters between chunks
- length_function: Function that measures the length of given chunks
- keep_separator: Whether to keep the separator in the chunks
- add_start_index: If `True`, includes chunk's start index in metadata
- strip_whitespace: If `True`, strips whitespace from the start and end of
- every document
- """
- if chunk_overlap > chunk_size:
- raise ValueError(
- f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
- f"({chunk_size}), should be smaller."
- )
- self._chunk_size = chunk_size
- self._chunk_overlap = chunk_overlap
- self._length_function = length_function
- self._keep_separator = keep_separator
- self._add_start_index = add_start_index
- self._strip_whitespace = strip_whitespace
- @abstractmethod
- def split_text(self, text: str) -> list[str]:
- """Split text into multiple components."""
- def create_documents(
- self, texts: list[str], metadatas: Optional[list[dict]] = None
- ) -> list[SplitterDocument]:
- """Create documents from a list of texts."""
- _metadatas = metadatas or [{}] * len(texts)
- documents = []
- for i, text in enumerate(texts):
- index = 0
- previous_chunk_len = 0
- for chunk in self.split_text(text):
- metadata = copy.deepcopy(_metadatas[i])
- if self._add_start_index:
- offset = index + previous_chunk_len - self._chunk_overlap
- index = text.find(chunk, max(0, offset))
- metadata["start_index"] = index
- previous_chunk_len = len(chunk)
- new_doc = SplitterDocument(
- page_content=chunk, metadata=metadata
- )
- documents.append(new_doc)
- return documents
- def split_documents(
- self, documents: Iterable[SplitterDocument]
- ) -> list[SplitterDocument]:
- """Split documents."""
- texts, metadatas = [], []
- for doc in documents:
- texts.append(doc.page_content)
- metadatas.append(doc.metadata)
- return self.create_documents(texts, metadatas=metadatas)
- def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
- text = separator.join(docs)
- if self._strip_whitespace:
- text = text.strip()
- if text == "":
- return None
- else:
- return text
- def _merge_splits(
- self, splits: Iterable[str], separator: str
- ) -> list[str]:
- # We now want to combine these smaller pieces into medium size
- # chunks to send to the LLM.
- separator_len = self._length_function(separator)
- docs = []
- current_doc: list[str] = []
- total = 0
- for d in splits:
- _len = self._length_function(d)
- if (
- total + _len + (separator_len if len(current_doc) > 0 else 0)
- > self._chunk_size
- ):
- if total > self._chunk_size:
- logger.warning(
- f"Created a chunk of size {total}, "
- f"which is longer than the specified {self._chunk_size}"
- )
- if len(current_doc) > 0:
- doc = self._join_docs(current_doc, separator)
- if doc is not None:
- docs.append(doc)
- # Keep on popping if:
- # - we have a larger chunk than in the chunk overlap
- # - or if we still have any chunks and the length is long
- while total > self._chunk_overlap or (
- total
- + _len
- + (separator_len if len(current_doc) > 0 else 0)
- > self._chunk_size
- and total > 0
- ):
- total -= self._length_function(current_doc[0]) + (
- separator_len if len(current_doc) > 1 else 0
- )
- current_doc = current_doc[1:]
- current_doc.append(d)
- total += _len + (separator_len if len(current_doc) > 1 else 0)
- doc = self._join_docs(current_doc, separator)
- if doc is not None:
- docs.append(doc)
- return docs
- @classmethod
- def from_huggingface_tokenizer(
- cls, tokenizer: Any, **kwargs: Any
- ) -> TextSplitter:
- """Text splitter that uses HuggingFace tokenizer to count length."""
- try:
- from transformers import PreTrainedTokenizerBase
- if not isinstance(tokenizer, PreTrainedTokenizerBase):
- raise ValueError(
- "Tokenizer received was not an instance of PreTrainedTokenizerBase"
- )
- def _huggingface_tokenizer_length(text: str) -> int:
- return len(tokenizer.encode(text))
- except ImportError:
- raise ValueError(
- "Could not import transformers python package. "
- "Please install it with `pip install transformers`."
- )
- return cls(length_function=_huggingface_tokenizer_length, **kwargs)
- @classmethod
- def from_tiktoken_encoder(
- cls: Type[TS],
- encoding_name: str = "gpt2",
- model: Optional[str] = None,
- allowed_special: Literal["all"] | AbstractSet[str] = set(),
- disallowed_special: Literal["all"] | Collection[str] = "all",
- **kwargs: Any,
- ) -> TS:
- """Text splitter that uses tiktoken encoder to count length."""
- try:
- import tiktoken
- except ImportError:
- raise ImportError(
- "Could not import tiktoken python package. "
- "This is needed in order to calculate max_tokens_for_prompt. "
- "Please install it with `pip install tiktoken`."
- )
- if model is not None:
- enc = tiktoken.encoding_for_model(model)
- else:
- enc = tiktoken.get_encoding(encoding_name)
- def _tiktoken_encoder(text: str) -> int:
- return len(
- enc.encode(
- text,
- allowed_special=allowed_special,
- disallowed_special=disallowed_special,
- )
- )
- if issubclass(cls, TokenTextSplitter):
- extra_kwargs = {
- "encoding_name": encoding_name,
- "model": model,
- "allowed_special": allowed_special,
- "disallowed_special": disallowed_special,
- }
- kwargs = {**kwargs, **extra_kwargs}
- return cls(length_function=_tiktoken_encoder, **kwargs)
- def transform_documents(
- self, documents: Sequence[SplitterDocument], **kwargs: Any
- ) -> Sequence[SplitterDocument]:
- """Transform sequence of documents by splitting them."""
- return self.split_documents(list(documents))
- class CharacterTextSplitter(TextSplitter):
- """Splitting text that looks at characters."""
- DEFAULT_SEPARATOR: str = "\n\n"
- def __init__(
- self,
- separator: str = DEFAULT_SEPARATOR,
- is_separator_regex: bool = False,
- **kwargs: Any,
- ) -> None:
- """Create a new TextSplitter."""
- super().__init__(**kwargs)
- self._separator = separator
- self._is_separator_regex = is_separator_regex
- def split_text(self, text: str) -> list[str]:
- """Split incoming text and return chunks."""
- # First we naively split the large input into a bunch of smaller ones.
- separator = (
- self._separator
- if self._is_separator_regex
- else re.escape(self._separator)
- )
- splits = _split_text_with_regex(text, separator, self._keep_separator)
- _separator = "" if self._keep_separator else self._separator
- return self._merge_splits(splits, _separator)
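A minimal usage sketch of `CharacterTextSplitter` as defined above; the sample text and parameter values are illustrative, and the tiktoken-based variant at the end assumes the optional `tiktoken` package is installed.

```python
splitter = CharacterTextSplitter(
    separator="\n\n", chunk_size=100, chunk_overlap=20
)
docs = splitter.create_documents(
    ["First paragraph.\n\nSecond paragraph.\n\nThird paragraph."],
    metadatas=[{"source": "example.txt"}],
)
for doc in docs:
    print(doc.metadata, doc.page_content)

# Chunk length can also be measured in tokens instead of characters:
token_based = CharacterTextSplitter.from_tiktoken_encoder(
    encoding_name="cl100k_base", chunk_size=512, chunk_overlap=64
)
```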
- class LineType(TypedDict):
- """Line type as typed dict."""
- metadata: dict[str, str]
- content: str
- class HeaderType(TypedDict):
- """Header type as typed dict."""
- level: int
- name: str
- data: str
- class MarkdownHeaderTextSplitter:
- """Splitting markdown files based on specified headers."""
- def __init__(
- self,
- headers_to_split_on: list[Tuple[str, str]],
- return_each_line: bool = False,
- strip_headers: bool = True,
- ):
- """Create a new MarkdownHeaderTextSplitter.
- Args:
- headers_to_split_on: Headers we want to track
- return_each_line: Return each line w/ associated headers
- strip_headers: Strip split headers from the content of the chunk
- """
- # Output line-by-line or aggregated into chunks w/ common headers
- self.return_each_line = return_each_line
- # Given the headers we want to split on,
- # (e.g., "#, ##, etc") order by length
- self.headers_to_split_on = sorted(
- headers_to_split_on, key=lambda split: len(split[0]), reverse=True
- )
- # Strip split headers from the content of the chunk
- self.strip_headers = strip_headers
- def aggregate_lines_to_chunks(
- self, lines: list[LineType]
- ) -> list[SplitterDocument]:
- """Combine lines with common metadata into chunks
- Args:
- lines: Line of text / associated header metadata
- """
- aggregated_chunks: list[LineType] = []
- for line in lines:
- if (
- aggregated_chunks
- and aggregated_chunks[-1]["metadata"] == line["metadata"]
- ):
- # If the last line in the aggregated list
- # has the same metadata as the current line,
- # append the current content to the last line's content
- aggregated_chunks[-1]["content"] += " \n" + line["content"]
- elif (
- aggregated_chunks
- and aggregated_chunks[-1]["metadata"] != line["metadata"]
- # may be issues if other metadata is present
- and len(aggregated_chunks[-1]["metadata"])
- < len(line["metadata"])
- and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#"
- and not self.strip_headers
- ):
- # If the last line in the aggregated list
- # has different metadata from the current line,
- # and has shallower header level than the current line,
- # and the last line is a header,
- # and we are not stripping headers,
- # append the current content to the last line's content
- aggregated_chunks[-1]["content"] += " \n" + line["content"]
- # and update the last line's metadata
- aggregated_chunks[-1]["metadata"] = line["metadata"]
- else:
- # Otherwise, append the current line to the aggregated list
- aggregated_chunks.append(line)
- return [
- SplitterDocument(
- page_content=chunk["content"], metadata=chunk["metadata"]
- )
- for chunk in aggregated_chunks
- ]
- def split_text(self, text: str) -> list[SplitterDocument]:
- """Split markdown file
- Args:
- text: Markdown file"""
- # Split the input text by newline character ("\n").
- lines = text.split("\n")
- # Final output
- lines_with_metadata: list[LineType] = []
- # Content and metadata of the chunk currently being processed
- current_content: list[str] = []
- current_metadata: dict[str, str] = {}
- # Keep track of the nested header structure
- # header_stack: list[dict[str, int | str]] = []
- header_stack: list[HeaderType] = []
- initial_metadata: dict[str, str] = {}
- in_code_block = False
- opening_fence = ""
- for line in lines:
- stripped_line = line.strip()
- if not in_code_block:
- # Exclude inline code spans
- if (
- stripped_line.startswith("```")
- and stripped_line.count("```") == 1
- ):
- in_code_block = True
- opening_fence = "```"
- elif stripped_line.startswith("~~~"):
- in_code_block = True
- opening_fence = "~~~"
- else:
- if stripped_line.startswith(opening_fence):
- in_code_block = False
- opening_fence = ""
- if in_code_block:
- current_content.append(stripped_line)
- continue
- # Check each line against each of the header types (e.g., #, ##)
- for sep, name in self.headers_to_split_on:
- # Check if line starts with a header that we intend to split on
- if stripped_line.startswith(sep) and (
- # Header with no text OR header is followed by space
- # Both are valid conditions that sep is being used as a header
- len(stripped_line) == len(sep)
- or stripped_line[len(sep)] == " "
- ):
- # Ensure we are tracking the header as metadata
- if name is not None:
- # Get the current header level
- current_header_level = sep.count("#")
- # Pop out headers of lower or same level from the stack
- while (
- header_stack
- and header_stack[-1]["level"]
- >= current_header_level
- ):
- # We have encountered a new header
- # at the same or higher level
- popped_header = header_stack.pop()
- # Clear the metadata for the
- # popped header in initial_metadata
- if popped_header["name"] in initial_metadata:
- initial_metadata.pop(popped_header["name"])
- # Push the current header to the stack
- header: HeaderType = {
- "level": current_header_level,
- "name": name,
- "data": stripped_line[len(sep) :].strip(),
- }
- header_stack.append(header)
- # Update initial_metadata with the current header
- initial_metadata[name] = header["data"]
- # Add the previous line to the lines_with_metadata
- # only if current_content is not empty
- if current_content:
- lines_with_metadata.append(
- {
- "content": "\n".join(current_content),
- "metadata": current_metadata.copy(),
- }
- )
- current_content.clear()
- if not self.strip_headers:
- current_content.append(stripped_line)
- break
- else:
- if stripped_line:
- current_content.append(stripped_line)
- elif current_content:
- lines_with_metadata.append(
- {
- "content": "\n".join(current_content),
- "metadata": current_metadata.copy(),
- }
- )
- current_content.clear()
- current_metadata = initial_metadata.copy()
- if current_content:
- lines_with_metadata.append(
- {
- "content": "\n".join(current_content),
- "metadata": current_metadata,
- }
- )
- # lines_with_metadata has each line with associated header metadata
- # aggregate these into chunks based on common metadata
- if not self.return_each_line:
- return self.aggregate_lines_to_chunks(lines_with_metadata)
- else:
- return [
- SplitterDocument(
- page_content=chunk["content"], metadata=chunk["metadata"]
- )
- for chunk in lines_with_metadata
- ]
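An illustrative sketch of `MarkdownHeaderTextSplitter`; the markdown snippet and header mapping below are assumptions for demonstration.

```python
headers_to_split_on = [("#", "Header 1"), ("##", "Header 2")]
md_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
docs = md_splitter.split_text(
    "# Title\n\nIntro text.\n\n## Section\n\nBody text."
)
# Each SplitterDocument carries its enclosing headers as metadata, e.g. the
# chunk "Body text." gets {"Header 1": "Title", "Header 2": "Section"}.
```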
- class ElementType(TypedDict):
- """Element type as typed dict."""
- url: str
- xpath: str
- content: str
- metadata: dict[str, str]
- class HTMLHeaderTextSplitter:
- """
- Splitting HTML files based on specified headers.
- Requires lxml package.
- """
- def __init__(
- self,
- headers_to_split_on: list[Tuple[str, str]],
- return_each_element: bool = False,
- ):
- """Create a new HTMLHeaderTextSplitter.
- Args:
- headers_to_split_on: list of tuples of headers we want to track mapped to
- (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
- h5, h6, e.g. [("h1", "Header 1"), ("h2", "Header 2")].
- return_each_element: Return each element w/ associated headers.
- """
- # Output element-by-element or aggregated into chunks w/ common headers
- self.return_each_element = return_each_element
- self.headers_to_split_on = sorted(headers_to_split_on)
- def aggregate_elements_to_chunks(
- self, elements: list[ElementType]
- ) -> list[SplitterDocument]:
- """Combine elements with common metadata into chunks
- Args:
- elements: HTML element content with associated identifying info and metadata
- """
- aggregated_chunks: list[ElementType] = []
- for element in elements:
- if (
- aggregated_chunks
- and aggregated_chunks[-1]["metadata"] == element["metadata"]
- ):
- # If the last element in the aggregated list
- # has the same metadata as the current element,
- # append the current content to the last element's content
- aggregated_chunks[-1]["content"] += " \n" + element["content"]
- else:
- # Otherwise, append the current element to the aggregated list
- aggregated_chunks.append(element)
- return [
- SplitterDocument(
- page_content=chunk["content"], metadata=chunk["metadata"]
- )
- for chunk in aggregated_chunks
- ]
- def split_text_from_url(self, url: str) -> list[SplitterDocument]:
- """Split HTML from web URL
- Args:
- url: web URL
- """
- r = requests.get(url)
- return self.split_text_from_file(BytesIO(r.content))
- def split_text(self, text: str) -> list[SplitterDocument]:
- """Split HTML text string
- Args:
- text: HTML text
- """
- return self.split_text_from_file(StringIO(text))
- def split_text_from_file(self, file: Any) -> list[SplitterDocument]:
- """Split HTML file
- Args:
- file: HTML file
- """
- try:
- from lxml import etree
- except ImportError as e:
- raise ImportError(
- "Unable to import lxml, please install with `pip install lxml`."
- ) from e
- # use lxml library to parse html document and return xml ElementTree
- # Explicitly encoding in utf-8 allows non-English
- # html files to be processed without garbled characters
- parser = etree.HTMLParser(encoding="utf-8")
- tree = etree.parse(file, parser)
- # document transformation for "structure-aware" chunking is handled with xsl.
- # see comments in html_chunks_with_headers.xslt for more detailed information.
- xslt_path = (
- pathlib.Path(__file__).parent
- / "document_transformers/xsl/html_chunks_with_headers.xslt"
- )
- xslt_tree = etree.parse(xslt_path)
- transform = etree.XSLT(xslt_tree)
- result = transform(tree)
- result_dom = etree.fromstring(str(result))
- # create filter and mapping for header metadata
- header_filter = [header[0] for header in self.headers_to_split_on]
- header_mapping = dict(self.headers_to_split_on)
- # map xhtml namespace prefix
- ns_map = {"h": "http://www.w3.org/1999/xhtml"}
- # build list of elements from DOM
- elements = []
- for element in result_dom.findall("*//*", ns_map):
- if element.findall("*[@class='headers']") or element.findall(
- "*[@class='chunk']"
- ):
- elements.append(
- ElementType(
- url=file,
- xpath="".join(
- [
- node.text
- for node in element.findall(
- "*[@class='xpath']", ns_map
- )
- ]
- ),
- content="".join(
- [
- node.text
- for node in element.findall(
- "*[@class='chunk']", ns_map
- )
- ]
- ),
- metadata={
- # Add text of specified headers to metadata using header
- # mapping.
- header_mapping[node.tag]: node.text
- for node in filter(
- lambda x: x.tag in header_filter,
- element.findall(
- "*[@class='headers']/*", ns_map
- ),
- )
- },
- )
- )
- if not self.return_each_element:
- return self.aggregate_elements_to_chunks(elements)
- else:
- return [
- SplitterDocument(
- page_content=chunk["content"], metadata=chunk["metadata"]
- )
- for chunk in elements
- ]
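A hedged usage sketch of `HTMLHeaderTextSplitter`; the HTML string is made up for illustration, and the call requires `lxml` plus the bundled XSLT stylesheet referenced above.

```python
html_splitter = HTMLHeaderTextSplitter(
    headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
)
docs = html_splitter.split_text(
    "<html><body>"
    "<h1>Title</h1><p>Intro.</p>"
    "<h2>Section</h2><p>Body.</p>"
    "</body></html>"
)
# Each chunk's metadata records the h1/h2 headers that enclose it.
```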
- # should be in newer Python versions (3.10+)
- # @dataclass(frozen=True, kw_only=True, slots=True)
- @dataclass(frozen=True)
- class Tokenizer:
- """Tokenizer data class."""
- chunk_overlap: int
- """Overlap in tokens between chunks"""
- tokens_per_chunk: int
- """Maximum number of tokens per chunk"""
- decode: Callable[[list[int]], str]
- """ Function to decode a list of token ids to a string"""
- encode: Callable[[str], list[int]]
- """ Function to encode a string to a list of token ids"""
- def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
- """Split incoming text and return chunks using tokenizer."""
- splits: list[str] = []
- input_ids = tokenizer.encode(text)
- start_idx = 0
- cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
- chunk_ids = input_ids[start_idx:cur_idx]
- while start_idx < len(input_ids):
- splits.append(tokenizer.decode(chunk_ids))
- if cur_idx == len(input_ids):
- break
- start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
- cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
- chunk_ids = input_ids[start_idx:cur_idx]
- return splits
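The sliding-window logic of `split_text_on_tokens` can be seen with a toy tokenizer that just splits on whitespace (purely illustrative, not a real model tokenizer):

```python
toy = Tokenizer(
    chunk_overlap=1,
    tokens_per_chunk=3,
    decode=lambda ids: " ".join(str(i) for i in ids),
    encode=lambda text: [int(t) for t in text.split()],
)
print(split_text_on_tokens(text="1 2 3 4 5", tokenizer=toy))
# -> ['1 2 3', '3 4 5']  (windows of 3 tokens, overlapping by 1)
```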
- class TokenTextSplitter(TextSplitter):
- """Splitting text to tokens using model tokenizer."""
- def __init__(
- self,
- encoding_name: str = "gpt2",
- model: Optional[str] = None,
- allowed_special: Literal["all"] | AbstractSet[str] = set(),
- disallowed_special: Literal["all"] | Collection[str] = "all",
- **kwargs: Any,
- ) -> None:
- """Create a new TextSplitter."""
- super().__init__(**kwargs)
- try:
- import tiktoken
- except ImportError:
- raise ImportError(
- "Could not import tiktoken python package. "
- "This is needed in order to for TokenTextSplitter. "
- "Please install it with `pip install tiktoken`."
- )
- if model is not None:
- enc = tiktoken.encoding_for_model(model)
- else:
- enc = tiktoken.get_encoding(encoding_name)
- self._tokenizer = enc
- self._allowed_special = allowed_special
- self._disallowed_special = disallowed_special
- def split_text(self, text: str) -> list[str]:
- def _encode(_text: str) -> list[int]:
- return self._tokenizer.encode(
- _text,
- allowed_special=self._allowed_special,
- disallowed_special=self._disallowed_special,
- )
- tokenizer = Tokenizer(
- chunk_overlap=self._chunk_overlap,
- tokens_per_chunk=self._chunk_size,
- decode=self._tokenizer.decode,
- encode=_encode,
- )
- return split_text_on_tokens(text=text, tokenizer=tokenizer)
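A usage sketch of `TokenTextSplitter`; it needs the optional `tiktoken` dependency, and the model name and sizes below are illustrative.

```python
token_splitter = TokenTextSplitter(
    model="gpt-3.5-turbo", chunk_size=256, chunk_overlap=32
)
chunks = token_splitter.split_text("Some long document text ...")
# Each chunk is at most 256 tokens long, with 32 tokens of overlap.
```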
- class SentenceTransformersTokenTextSplitter(TextSplitter):
- """Splitting text to tokens using sentence model tokenizer."""
- def __init__(
- self,
- chunk_overlap: int = 50,
- model: str = "sentence-transformers/all-mpnet-base-v2",
- tokens_per_chunk: Optional[int] = None,
- **kwargs: Any,
- ) -> None:
- """Create a new TextSplitter."""
- super().__init__(**kwargs, chunk_overlap=chunk_overlap)
- try:
- from sentence_transformers import SentenceTransformer
- except ImportError:
- raise ImportError(
- "Could not import sentence_transformer python package. "
- "This is needed in order to for SentenceTransformersTokenTextSplitter. "
- "Please install it with `pip install sentence-transformers`."
- )
- self.model = model
- self._model = SentenceTransformer(self.model, trust_remote_code=True)
- self.tokenizer = self._model.tokenizer
- self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
- def _initialize_chunk_configuration(
- self, *, tokens_per_chunk: Optional[int]
- ) -> None:
- self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
- if tokens_per_chunk is None:
- self.tokens_per_chunk = self.maximum_tokens_per_chunk
- else:
- self.tokens_per_chunk = tokens_per_chunk
- if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
- raise ValueError(
- f"The token limit of the models '{self.model}'"
- f" is: {self.maximum_tokens_per_chunk}."
- f" Argument tokens_per_chunk={self.tokens_per_chunk}"
- f" > maximum token limit."
- )
- def split_text(self, text: str) -> list[str]:
- def encode_strip_start_and_stop_token_ids(text: str) -> list[int]:
- return self._encode(text)[1:-1]
- tokenizer = Tokenizer(
- chunk_overlap=self._chunk_overlap,
- tokens_per_chunk=self.tokens_per_chunk,
- decode=self.tokenizer.decode,
- encode=encode_strip_start_and_stop_token_ids,
- )
- return split_text_on_tokens(text=text, tokenizer=tokenizer)
- def count_tokens(self, *, text: str) -> int:
- return len(self._encode(text))
- _max_length_equal_32_bit_integer: int = 2**32
- def _encode(self, text: str) -> list[int]:
- token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
- text,
- max_length=self._max_length_equal_32_bit_integer,
- truncation="do_not_truncate",
- )
- return token_ids_with_start_and_end_token_ids
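A sketch of `SentenceTransformersTokenTextSplitter` usage; it requires the optional `sentence-transformers` dependency and will download the named model on first use.

```python
st_splitter = SentenceTransformersTokenTextSplitter(
    model="sentence-transformers/all-mpnet-base-v2",
    chunk_overlap=50,
)
print(st_splitter.count_tokens(text="How many tokens is this sentence?"))
chunks = st_splitter.split_text("A longer passage to split on model token limits ...")
```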
- class Language(str, Enum):
- """Enum of the programming languages."""
- CPP = "cpp"
- GO = "go"
- JAVA = "java"
- KOTLIN = "kotlin"
- JS = "js"
- TS = "ts"
- PHP = "php"
- PROTO = "proto"
- PYTHON = "python"
- RST = "rst"
- RUBY = "ruby"
- RUST = "rust"
- SCALA = "scala"
- SWIFT = "swift"
- MARKDOWN = "markdown"
- LATEX = "latex"
- HTML = "html"
- SOL = "sol"
- CSHARP = "csharp"
- COBOL = "cobol"
- C = "c"
- LUA = "lua"
- PERL = "perl"
- class RecursiveCharacterTextSplitter(TextSplitter):
- """Splitting text by recursively look at characters.
- Recursively tries to split by different characters to find one
- that works.
- """
- def __init__(
- self,
- separators: Optional[list[str]] = None,
- keep_separator: bool = True,
- is_separator_regex: bool = False,
- chunk_size: int = 4000,
- chunk_overlap: int = 200,
- **kwargs: Any,
- ) -> None:
- """Create a new TextSplitter."""
- super().__init__(
- chunk_size=chunk_size,
- chunk_overlap=chunk_overlap,
- keep_separator=keep_separator,
- **kwargs,
- )
- self._separators = separators or ["\n\n", "\n", " ", ""]
- self._is_separator_regex = is_separator_regex
- self.chunk_size = chunk_size
- self.chunk_overlap = chunk_overlap
- def _split_text(self, text: str, separators: list[str]) -> list[str]:
- """Split incoming text and return chunks."""
- final_chunks = []
- # Get appropriate separator to use
- separator = separators[-1]
- new_separators = []
- for i, _s in enumerate(separators):
- _separator = _s if self._is_separator_regex else re.escape(_s)
- if _s == "":
- separator = _s
- break
- if re.search(_separator, text):
- separator = _s
- new_separators = separators[i + 1 :]
- break
- _separator = (
- separator if self._is_separator_regex else re.escape(separator)
- )
- splits = _split_text_with_regex(text, _separator, self._keep_separator)
- # Now go merging things, recursively splitting longer texts.
- _good_splits = []
- _separator = "" if self._keep_separator else separator
- for s in splits:
- if self._length_function(s) < self._chunk_size:
- _good_splits.append(s)
- else:
- if _good_splits:
- merged_text = self._merge_splits(_good_splits, _separator)
- final_chunks.extend(merged_text)
- _good_splits = []
- if not new_separators:
- final_chunks.append(s)
- else:
- other_info = self._split_text(s, new_separators)
- final_chunks.extend(other_info)
- if _good_splits:
- merged_text = self._merge_splits(_good_splits, _separator)
- final_chunks.extend(merged_text)
- return final_chunks
- def split_text(self, text: str) -> list[str]:
- return self._split_text(text, self._separators)
- @classmethod
- def from_language(
- cls, language: Language, **kwargs: Any
- ) -> RecursiveCharacterTextSplitter:
- separators = cls.get_separators_for_language(language)
- return cls(separators=separators, is_separator_regex=True, **kwargs)
- @staticmethod
- def get_separators_for_language(language: Language) -> list[str]:
- if language == Language.CPP:
- return [
- # Split along class definitions
- "\nclass ",
- # Split along function definitions
- "\nvoid ",
- "\nint ",
- "\nfloat ",
- "\ndouble ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.GO:
- return [
- # Split along function definitions
- "\nfunc ",
- "\nvar ",
- "\nconst ",
- "\ntype ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.JAVA:
- return [
- # Split along class definitions
- "\nclass ",
- # Split along method definitions
- "\npublic ",
- "\nprotected ",
- "\nprivate ",
- "\nstatic ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.KOTLIN:
- return [
- # Split along class definitions
- "\nclass ",
- # Split along method definitions
- "\npublic ",
- "\nprotected ",
- "\nprivate ",
- "\ninternal ",
- "\ncompanion ",
- "\nfun ",
- "\nval ",
- "\nvar ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nwhen ",
- "\ncase ",
- "\nelse ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.JS:
- return [
- # Split along function definitions
- "\nfunction ",
- "\nconst ",
- "\nlet ",
- "\nvar ",
- "\nclass ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nswitch ",
- "\ncase ",
- "\ndefault ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.TS:
- return [
- "\nenum ",
- "\ninterface ",
- "\nnamespace ",
- "\ntype ",
- # Split along class definitions
- "\nclass ",
- # Split along function definitions
- "\nfunction ",
- "\nconst ",
- "\nlet ",
- "\nvar ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nswitch ",
- "\ncase ",
- "\ndefault ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.PHP:
- return [
- # Split along function definitions
- "\nfunction ",
- # Split along class definitions
- "\nclass ",
- # Split along control flow statements
- "\nif ",
- "\nforeach ",
- "\nwhile ",
- "\ndo ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.PROTO:
- return [
- # Split along message definitions
- "\nmessage ",
- # Split along service definitions
- "\nservice ",
- # Split along enum definitions
- "\nenum ",
- # Split along option definitions
- "\noption ",
- # Split along import statements
- "\nimport ",
- # Split along syntax declarations
- "\nsyntax ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.PYTHON:
- return [
- # First, try to split along class definitions
- "\nclass ",
- "\ndef ",
- "\n\tdef ",
- # Now split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.RST:
- return [
- # Split along section titles
- "\n=+\n",
- "\n-+\n",
- "\n\\*+\n",
- # Split along directive markers
- "\n\n.. *\n\n",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.RUBY:
- return [
- # Split along method definitions
- "\ndef ",
- "\nclass ",
- # Split along control flow statements
- "\nif ",
- "\nunless ",
- "\nwhile ",
- "\nfor ",
- "\ndo ",
- "\nbegin ",
- "\nrescue ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.RUST:
- return [
- # Split along function definitions
- "\nfn ",
- "\nconst ",
- "\nlet ",
- # Split along control flow statements
- "\nif ",
- "\nwhile ",
- "\nfor ",
- "\nloop ",
- "\nmatch ",
- "\nconst ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.SCALA:
- return [
- # Split along class definitions
- "\nclass ",
- "\nobject ",
- # Split along method definitions
- "\ndef ",
- "\nval ",
- "\nvar ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\nmatch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.SWIFT:
- return [
- # Split along function definitions
- "\nfunc ",
- # Split along class definitions
- "\nclass ",
- "\nstruct ",
- "\nenum ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\ndo ",
- "\nswitch ",
- "\ncase ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.MARKDOWN:
- return [
- # First, try to split along Markdown headings (levels 1 through 6)
- "\n#{1,6} ",
- # Note the alternative syntax for headings (below) is not handled here
- # Heading level 2
- # ---------------
- # End of code block
- "```\n",
- # Horizontal lines
- "\n\\*\\*\\*+\n",
- "\n---+\n",
- "\n___+\n",
- # Note: horizontal rules must consist of three or more of the same
- # character (***, ---, ___) to match the patterns above
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.LATEX:
- return [
- # First, try to split along Latex sections
- "\n\\\\chapter{",
- "\n\\\\section{",
- "\n\\\\subsection{",
- "\n\\\\subsubsection{",
- # Now split by environments
- "\n\\\\begin{enumerate}",
- "\n\\\\begin{itemize}",
- "\n\\\\begin{description}",
- "\n\\\\begin{list}",
- "\n\\\\begin{quote}",
- "\n\\\\begin{quotation}",
- "\n\\\\begin{verse}",
- "\n\\\\begin{verbatim}",
- # Now split by math environments
- "\n\\\begin{align}",
- "$$",
- "$",
- # Now split by the normal type of lines
- " ",
- "",
- ]
- elif language == Language.HTML:
- return [
- # First, try to split along HTML tags
- "<body",
- "<div",
- "<p",
- "<br",
- "<li",
- "<h1",
- "<h2",
- "<h3",
- "<h4",
- "<h5",
- "<h6",
- "<span",
- "<table",
- "<tr",
- "<td",
- "<th",
- "<ul",
- "<ol",
- "<header",
- "<footer",
- "<nav",
- # Head
- "<head",
- "<style",
- "<script",
- "<meta",
- "<title",
- "",
- ]
- elif language == Language.CSHARP:
- return [
- "\ninterface ",
- "\nenum ",
- "\nimplements ",
- "\ndelegate ",
- "\nevent ",
- # Split along class definitions
- "\nclass ",
- "\nabstract ",
- # Split along method definitions
- "\npublic ",
- "\nprotected ",
- "\nprivate ",
- "\nstatic ",
- "\nreturn ",
- # Split along control flow statements
- "\nif ",
- "\ncontinue ",
- "\nfor ",
- "\nforeach ",
- "\nwhile ",
- "\nswitch ",
- "\nbreak ",
- "\ncase ",
- "\nelse ",
- # Split by exceptions
- "\ntry ",
- "\nthrow ",
- "\nfinally ",
- "\ncatch ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.SOL:
- return [
- # Split along compiler information definitions
- "\npragma ",
- "\nusing ",
- # Split along contract definitions
- "\ncontract ",
- "\ninterface ",
- "\nlibrary ",
- # Split along method definitions
- "\nconstructor ",
- "\ntype ",
- "\nfunction ",
- "\nevent ",
- "\nmodifier ",
- "\nerror ",
- "\nstruct ",
- "\nenum ",
- # Split along control flow statements
- "\nif ",
- "\nfor ",
- "\nwhile ",
- "\ndo while ",
- "\nassembly ",
- # Split by the normal type of lines
- "\n\n",
- "\n",
- " ",
- "",
- ]
- elif language == Language.COBOL:
- return [
- # Split along divisions
- "\nIDENTIFICATION DIVISION.",
- "\nENVIRONMENT DIVISION.",
- "\nDATA DIVISION.",
- "\nPROCEDURE DIVISION.",
- # Split along sections within DATA DIVISION
- "\nWORKING-STORAGE SECTION.",
- "\nLINKAGE SECTION.",
- "\nFILE SECTION.",
- # Split along sections within PROCEDURE DIVISION
- "\nINPUT-OUTPUT SECTION.",
- # Split along paragraphs and common statements
- "\nOPEN ",
- "\nCLOSE ",
- "\nREAD ",
- "\nWRITE ",
- "\nIF ",
- "\nELSE ",
- "\nMOVE ",
- "\nPERFORM ",
- "\nUNTIL ",
- "\nVARYING ",
- "\nACCEPT ",
- "\nDISPLAY ",
- "\nSTOP RUN.",
- # Split by the normal type of lines
- "\n",
- " ",
- "",
- ]
- else:
- raise ValueError(
- f"Language {language} is not supported! "
- f"Please choose from {list(Language)}"
- )
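A sketch of the two common ways to construct a `RecursiveCharacterTextSplitter`: with the default separators, or with the language-specific separators listed above (the sample inputs and sizes are illustrative).

```python
# Default separators: paragraphs, then lines, then words, then characters.
rc_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = rc_splitter.split_text("A long passage of plain text ...")

# Language-aware splitting, e.g. keeping Python defs/classes together
# where possible.
py_splitter = RecursiveCharacterTextSplitter.from_language(
    Language.PYTHON, chunk_size=500, chunk_overlap=0
)
code_chunks = py_splitter.split_text("class Foo:\n    def bar(self):\n        pass\n")
```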
- class NLTKTextSplitter(TextSplitter):
- """Splitting text using NLTK package."""
- def __init__(
- self, separator: str = "\n\n", language: str = "english", **kwargs: Any
- ) -> None:
- """Initialize the NLTK splitter."""
- super().__init__(**kwargs)
- try:
- from nltk.tokenize import sent_tokenize
- self._tokenizer = sent_tokenize
- except ImportError:
- raise ImportError(
- "NLTK is not installed, please install it with `pip install nltk`."
- )
- self._separator = separator
- self._language = language
- def split_text(self, text: str) -> list[str]:
- """Split incoming text and return chunks."""
- # First we naively split the large input into a bunch of smaller ones.
- splits = self._tokenizer(text, language=self._language)
- return self._merge_splits(splits, self._separator)
- class SpacyTextSplitter(TextSplitter):
- """Splitting text using Spacy package.
- By default, Spacy's `en_core_web_sm` pipeline is used; its default
- max_length is 1,000,000 (the maximum number of characters the model
- will process, which can be increased for large files). For faster, but
- potentially less accurate splitting, you can use `pipe='sentencizer'`.
- """
- def __init__(
- self,
- separator: str = "\n\n",
- pipe: str = "en_core_web_sm",
- max_length: int = 1_000_000,
- **kwargs: Any,
- ) -> None:
- """Initialize the spacy text splitter."""
- super().__init__(**kwargs)
- self._tokenizer = _make_spacy_pipe_for_splitting(
- pipe, max_length=max_length
- )
- self._separator = separator
- def split_text(self, text: str) -> list[str]:
- """Split incoming text and return chunks."""
- splits = (s.text for s in self._tokenizer(text).sents)
- return self._merge_splits(splits, self._separator)
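An illustrative `SpacyTextSplitter` sketch; it requires `spacy`, and the `pipe='sentencizer'` variant avoids downloading the `en_core_web_sm` model.

```python
spacy_splitter = SpacyTextSplitter(
    pipe="sentencizer", chunk_size=1000, chunk_overlap=0
)
chunks = spacy_splitter.split_text(
    "First sentence. Second sentence. Third sentence."
)
```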
- class KonlpyTextSplitter(TextSplitter):
- """Splitting text using Konlpy package.
- It is good for splitting Korean text.
- """
- def __init__(
- self,
- separator: str = "\n\n",
- **kwargs: Any,
- ) -> None:
- """Initialize the Konlpy text splitter."""
- super().__init__(**kwargs)
- self._separator = separator
- try:
- from konlpy.tag import Kkma
- except ImportError:
- raise ImportError(
- """
- Konlpy is not installed, please install it with
- `pip install konlpy`
- """
- )
- self.kkma = Kkma()
- def split_text(self, text: str) -> list[str]:
- """Split incoming text and return chunks."""
- splits = self.kkma.sentences(text)
- return self._merge_splits(splits, self._separator)
- # For backwards compatibility
- class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
- """Attempts to split the text along Python syntax."""
- def __init__(self, **kwargs: Any) -> None:
- """Initialize a PythonCodeTextSplitter."""
- separators = self.get_separators_for_language(Language.PYTHON)
- super().__init__(separators=separators, **kwargs)
- class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
- """Attempts to split the text along Markdown-formatted headings."""
- def __init__(self, **kwargs: Any) -> None:
- """Initialize a MarkdownTextSplitter."""
- separators = self.get_separators_for_language(Language.MARKDOWN)
- super().__init__(separators=separators, **kwargs)
- class LatexTextSplitter(RecursiveCharacterTextSplitter):
- """Attempts to split the text along Latex-formatted layout elements."""
- def __init__(self, **kwargs: Any) -> None:
- """Initialize a LatexTextSplitter."""
- separators = self.get_separators_for_language(Language.LATEX)
- super().__init__(separators=separators, **kwargs)
- class RecursiveJsonSplitter:
- def __init__(
- self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
- ):
- super().__init__()
- self.max_chunk_size = max_chunk_size
- self.min_chunk_size = (
- min_chunk_size
- if min_chunk_size is not None
- else max(max_chunk_size - 200, 50)
- )
- @staticmethod
- def _json_size(data: dict) -> int:
- """Calculate the size of the serialized JSON object."""
- return len(json.dumps(data))
- @staticmethod
- def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
- """Set a value in a nested dictionary based on the given path."""
- for key in path[:-1]:
- d = d.setdefault(key, {})
- d[path[-1]] = value
- def _list_to_dict_preprocessing(self, data: Any) -> Any:
- if isinstance(data, dict):
- # Process each key-value pair in the dictionary
- return {
- k: self._list_to_dict_preprocessing(v) for k, v in data.items()
- }
- elif isinstance(data, list):
- # Convert the list to a dictionary with index-based keys
- return {
- str(i): self._list_to_dict_preprocessing(item)
- for i, item in enumerate(data)
- }
- else:
- # Base case: the item is neither a dict nor a list, so return it unchanged
- return data
- def _json_split(
- self,
- data: dict[str, Any],
- current_path: Optional[list[str]] = None,
- chunks: Optional[list[dict]] = None,
- ) -> list[dict]:
- """
- Split json into maximum size dictionaries while preserving structure.
- """
- # Avoid mutable default arguments, which would otherwise leak state
- # between separate calls to this method.
- current_path = current_path if current_path is not None else []
- chunks = chunks if chunks is not None else [{}]
- if isinstance(data, dict):
- for key, value in data.items():
- new_path = current_path + [key]
- chunk_size = self._json_size(chunks[-1])
- size = self._json_size({key: value})
- remaining = self.max_chunk_size - chunk_size
- if size < remaining:
- # Add item to current chunk
- self._set_nested_dict(chunks[-1], new_path, value)
- else:
- if chunk_size >= self.min_chunk_size:
- # Chunk is big enough, start a new chunk
- chunks.append({})
- # Iterate
- self._json_split(value, new_path, chunks)
- else:
- # handle single item
- self._set_nested_dict(chunks[-1], current_path, data)
- return chunks
- def split_json(
- self,
- json_data: dict[str, Any],
- convert_lists: bool = False,
- ) -> list[dict]:
- """Splits JSON into a list of JSON chunks"""
- if convert_lists:
- chunks = self._json_split(
- self._list_to_dict_preprocessing(json_data)
- )
- else:
- chunks = self._json_split(json_data)
- # Remove the last chunk if it's empty
- if not chunks[-1]:
- chunks.pop()
- return chunks
- def split_text(
- self, json_data: dict[str, Any], convert_lists: bool = False
- ) -> list[str]:
- """Splits JSON into a list of JSON formatted strings"""
- chunks = self.split_json(
- json_data=json_data, convert_lists=convert_lists
- )
- # Convert to string
- return [json.dumps(chunk) for chunk in chunks]
- def create_documents(
- self,
- texts: list[dict],
- convert_lists: bool = False,
- metadatas: Optional[list[dict]] = None,
- ) -> list[SplitterDocument]:
- """Create documents from a list of json objects (dict)."""
- _metadatas = metadatas or [{}] * len(texts)
- documents = []
- for i, text in enumerate(texts):
- for chunk in self.split_text(
- json_data=text, convert_lists=convert_lists
- ):
- metadata = copy.deepcopy(_metadatas[i])
- new_doc = SplitterDocument(
- page_content=chunk, metadata=metadata
- )
- documents.append(new_doc)
- return documents
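Finally, a sketch of `RecursiveJsonSplitter`; the nested payload is made up for illustration.

```python
json_splitter = RecursiveJsonSplitter(max_chunk_size=300)
payload = {
    "users": [{"name": "Ada", "role": "admin"}, {"name": "Bob", "role": "viewer"}],
    "settings": {"theme": "dark", "notifications": {"email": True, "sms": False}},
}
# convert_lists=True turns lists into index-keyed dicts so they can be split too.
docs = json_splitter.create_documents([payload], convert_lists=True)
json_strings = json_splitter.split_text(payload, convert_lists=True)
```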