# Source - LangChain
# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851
"""**Text Splitters** are classes for splitting text.

**Class hierarchy:**

.. code-block::

    BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter  # Example: CharacterTextSplitter
                                                 RecursiveCharacterTextSplitter --> <name>TextSplitter

Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.

**Main helpers:**

.. code-block::

    Document, Tokenizer, Language, LineType, HeaderType
"""  # noqa: E501

from __future__ import annotations

import copy
import json
import logging
import pathlib
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from io import BytesIO, StringIO
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    cast,
)

import requests
from pydantic import BaseModel, Field, PrivateAttr
from typing_extensions import NotRequired

logger = logging.getLogger()

TS = TypeVar("TS", bound="TextSplitter")


class BaseSerialized(TypedDict):
    """Base class for serialized objects."""

    lc: int
    id: list[str]
    name: NotRequired[str]
    graph: NotRequired[dict[str, Any]]


class SerializedConstructor(BaseSerialized):
    """Serialized constructor."""

    type: Literal["constructor"]
    kwargs: dict[str, Any]


class SerializedSecret(BaseSerialized):
    """Serialized secret."""

    type: Literal["secret"]


class SerializedNotImplemented(BaseSerialized):
    """Serialized not implemented."""

    type: Literal["not_implemented"]
    repr: Optional[str]


def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
    """Try to determine if a value is different from the default.

    Args:
        value: The value.
        key: The key.
        model: The model.

    Returns:
        Whether the value is different from the default.
    """
    try:
        return model.__fields__[key].get_default() != value
    except Exception:
        return True


class Serializable(BaseModel, ABC):
    """Serializable base class."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Is this class serializable?"""
        return False

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.

        For example, if the class is `langchain.llms.openai.OpenAI`, then the
        namespace is ["langchain", "llms", "openai"]
        """
        return cls.__module__.split(".")

    @property
    def lc_secrets(self) -> dict[str, str]:
        """A map of constructor argument names to secret ids.

        For example, {"openai_api_key": "OPENAI_API_KEY"}
        """
        return {}

    @property
    def lc_attributes(self) -> dict:
        """List of attribute names that should be included in the
        serialized kwargs.

        These attributes must be accepted by the constructor.
        """
        return {}

    @classmethod
    def lc_id(cls) -> list[str]:
        """A unique identifier for this class for serialization purposes.

        The unique identifier is a list of strings that describes the path
        to the object.
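        For example, for the class `langchain.llms.openai.OpenAI`, the
        identifier is ["langchain", "llms", "openai", "OpenAI"].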
""" return [*cls.get_lc_namespace(), cls.__name__] class Config: extra = "ignore" def __repr_args__(self) -> Any: return [ (k, v) for k, v in super().__repr_args__() if (k not in self.__fields__ or try_neq_default(v, k, self)) ] _lc_kwargs: dict[str, Any] = PrivateAttr(default_factory=dict) def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._lc_kwargs = kwargs def to_json( self, ) -> SerializedConstructor | SerializedNotImplemented: if not self.is_lc_serializable(): return self.to_json_not_implemented() secrets = dict() # Get latest values for kwargs if there is an attribute with same name lc_kwargs = { k: getattr(self, k, v) for k, v in self._lc_kwargs.items() if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore } # Merge the lc_secrets and lc_attributes from every class in the MRO for cls in [None, *self.__class__.mro()]: # Once we get to Serializable, we're done if cls is Serializable: break if cls: deprecated_attributes = [ "lc_namespace", "lc_serializable", ] for attr in deprecated_attributes: if hasattr(cls, attr): raise ValueError( f"Class {self.__class__} has a deprecated " f"attribute {attr}. Please use the corresponding " f"classmethod instead." ) # Get a reference to self bound to each class in the MRO this = cast( Serializable, self if cls is None else super(cls, self) ) secrets.update(this.lc_secrets) # Now also add the aliases for the secrets # This ensures known secret aliases are hidden. # Note: this does NOT hide any other extra kwargs # that are not present in the fields. for key in list(secrets): value = secrets[key] if key in this.__fields__: secrets[this.__fields__[key].alias] = value # type: ignore lc_kwargs.update(this.lc_attributes) # include all secrets, even if not specified in kwargs # as these secrets may be passed as an environment variable instead for key in secrets.keys(): secret_value = getattr(self, key, None) or lc_kwargs.get(key) if secret_value is not None: lc_kwargs.update({key: secret_value}) return { "lc": 1, "type": "constructor", "id": self.lc_id(), "kwargs": ( lc_kwargs if not secrets else _replace_secrets(lc_kwargs, secrets) ), } def to_json_not_implemented(self) -> SerializedNotImplemented: return to_json_not_implemented(self) def _replace_secrets( root: dict[Any, Any], secrets_map: dict[str, str] ) -> dict[Any, Any]: result = root.copy() for path, secret_id in secrets_map.items(): [*parts, last] = path.split(".") current = result for part in parts: if part not in current: break current[part] = current[part].copy() current = current[part] if last in current: current[last] = { "lc": 1, "type": "secret", "id": [secret_id], } return result def to_json_not_implemented(obj: object) -> SerializedNotImplemented: """Serialize a "not implemented" object. Args: obj: object to serialize Returns: SerializedNotImplemented """ _id: list[str] = [] try: if hasattr(obj, "__name__"): _id = [*obj.__module__.split("."), obj.__name__] elif hasattr(obj, "__class__"): _id = [ *obj.__class__.__module__.split("."), obj.__class__.__name__, ] except Exception: pass result: SerializedNotImplemented = { "lc": 1, "type": "not_implemented", "id": _id, "repr": None, } try: result["repr"] = repr(obj) except Exception: pass return result class SplitterDocument(Serializable): """Class for storing a piece of text and associated metadata.""" page_content: str """String text.""" metadata: dict = Field(default_factory=dict) """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). 
""" type: Literal["Document"] = "Document" def __init__(self, page_content: str, **kwargs: Any) -> None: """Pass page_content in as positional or named arg.""" super().__init__(page_content=page_content, **kwargs) @classmethod def is_lc_serializable(cls) -> bool: """Return whether this class is serializable.""" return True @classmethod def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "document"] class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. A document transformation system takes a sequence of Documents and returns a sequence of transformed Documents. Example: .. code-block:: python class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): embeddings: Embeddings similarity_fn: Callable = cosine_similarity similarity_threshold: float = 0.95 class Config: arbitrary_types_allowed = True def transform_documents( self, documents: Sequence[Document], **kwargs: Any ) -> Sequence[Document]: stateful_documents = get_stateful_documents(documents) embedded_documents = _get_embeddings_from_stateful_docs( self.embeddings, stateful_documents ) included_idxs = _filter_similar_embeddings( embedded_documents, self.similarity_fn, self.similarity_threshold ) return [stateful_documents[i] for i in sorted(included_idxs)] async def atransform_documents( self, documents: Sequence[Document], **kwargs: Any ) -> Sequence[Document]: raise NotImplementedError """ # noqa: E501 @abstractmethod def transform_documents( self, documents: Sequence[SplitterDocument], **kwargs: Any ) -> Sequence[SplitterDocument]: """Transform a list of documents. Args: documents: A sequence of Documents to be transformed. Returns: A list of transformed Documents. """ async def atransform_documents( self, documents: Sequence[SplitterDocument], **kwargs: Any ) -> Sequence[SplitterDocument]: """Asynchronously transform a list of documents. Args: documents: A sequence of Documents to be transformed. Returns: A list of transformed Documents. """ raise NotImplementedError("This method is not implemented.") # return await langchain_core.runnables.config.run_in_executor( # None, self.transform_documents, documents, **kwargs # ) def _make_spacy_pipe_for_splitting( pipe: str, *, max_length: int = 1_000_000 ) -> Any: # avoid importing spacy try: import spacy except ImportError: raise ImportError( "Spacy is not installed, please install it with `pip install spacy`." ) if pipe == "sentencizer": from spacy.lang.en import English sentencizer = English() sentencizer.add_pipe("sentencizer") else: sentencizer = spacy.load(pipe, exclude=["ner", "tagger"]) sentencizer.max_length = max_length return sentencizer def _split_text_with_regex( text: str, separator: str, keep_separator: bool ) -> list[str]: # Now that we have the separator, split the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. 
_splits = re.split(f"({separator})", text) splits = [ _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2) ] if len(_splits) % 2 == 0: splits += _splits[-1:] splits = [_splits[0]] + splits else: splits = re.split(separator, text) else: splits = list(text) return [s for s in splits if s != ""] class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" def __init__( self, chunk_size: int = 4000, chunk_overlap: int = 200, length_function: Callable[[str], int] = len, keep_separator: bool = False, add_start_index: bool = False, strip_whitespace: bool = True, ) -> None: """Create a new TextSplitter. Args: chunk_size: Maximum size of chunks to return chunk_overlap: Overlap in characters between chunks length_function: Function that measures the length of given chunks keep_separator: Whether to keep the separator in the chunks add_start_index: If `True`, includes chunk's start index in metadata strip_whitespace: If `True`, strips whitespace from the start and end of every document """ if chunk_overlap > chunk_size: raise ValueError( f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller." ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap self._length_function = length_function self._keep_separator = keep_separator self._add_start_index = add_start_index self._strip_whitespace = strip_whitespace @abstractmethod def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" def create_documents( self, texts: list[str], metadatas: Optional[list[dict]] = None ) -> list[SplitterDocument]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] for i, text in enumerate(texts): index = 0 previous_chunk_len = 0 for chunk in self.split_text(text): metadata = copy.deepcopy(_metadatas[i]) if self._add_start_index: offset = index + previous_chunk_len - self._chunk_overlap index = text.find(chunk, max(0, offset)) metadata["start_index"] = index previous_chunk_len = len(chunk) new_doc = SplitterDocument( page_content=chunk, metadata=metadata ) documents.append(new_doc) return documents def split_documents( self, documents: Iterable[SplitterDocument] ) -> list[SplitterDocument]: """Split documents.""" texts, metadatas = [], [] for doc in documents: texts.append(doc.page_content) metadatas.append(doc.metadata) return self.create_documents(texts, metadatas=metadatas) def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: text = separator.join(docs) if self._strip_whitespace: text = text.strip() if text == "": return None else: return text def _merge_splits( self, splits: Iterable[str], separator: str ) -> list[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. 
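        # Illustrative trace (assuming length_function=len): with
        # splits=["aaa", "bbb", "ccc"], separator=" ", chunk_size=7 and
        # chunk_overlap=3, "aaa bbb" is emitted once adding "ccc" would
        # overflow the chunk size, "aaa" is then popped to honour the
        # overlap budget, and the final chunk becomes "bbb ccc".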
separator_len = self._length_function(separator) docs = [] current_doc: list[str] = [] total = 0 for d in splits: _len = self._length_function(d) if ( total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size ): if total > self._chunk_size: logger.warning( f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}" ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) # Keep on popping if: # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 ): total -= self._length_function(current_doc[0]) + ( separator_len if len(current_doc) > 1 else 0 ) current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) return docs @classmethod def from_huggingface_tokenizer( cls, tokenizer: Any, **kwargs: Any ) -> TextSplitter: """Text splitter that uses HuggingFace tokenizer to count length.""" try: from transformers import PreTrainedTokenizerBase if not isinstance(tokenizer, PreTrainedTokenizerBase): raise ValueError( "Tokenizer received was not an instance of PreTrainedTokenizerBase" ) def _huggingface_tokenizer_length(text: str) -> int: return len(tokenizer.encode(text)) except ImportError: raise ValueError( "Could not import transformers python package. " "Please install it with `pip install transformers`." ) return cls(length_function=_huggingface_tokenizer_length, **kwargs) @classmethod def from_tiktoken_encoder( cls: Type[TS], encoding_name: str = "gpt2", model: Optional[str] = None, allowed_special: Literal["all"] | AbstractSet[str] = set(), disallowed_special: Literal["all"] | Collection[str] = "all", **kwargs: Any, ) -> TS: """Text splitter that uses tiktoken encoder to count length.""" try: import tiktoken except ImportError: raise ImportError( "Could not import tiktoken python package. " "This is needed in order to calculate max_tokens_for_prompt. " "Please install it with `pip install tiktoken`." ) if model is not None: enc = tiktoken.encoding_for_model(model) else: enc = tiktoken.get_encoding(encoding_name) def _tiktoken_encoder(text: str) -> int: return len( enc.encode( text, allowed_special=allowed_special, disallowed_special=disallowed_special, ) ) if issubclass(cls, TokenTextSplitter): extra_kwargs = { "encoding_name": encoding_name, "model": model, "allowed_special": allowed_special, "disallowed_special": disallowed_special, } kwargs = {**kwargs, **extra_kwargs} return cls(length_function=_tiktoken_encoder, **kwargs) def transform_documents( self, documents: Sequence[SplitterDocument], **kwargs: Any ) -> Sequence[SplitterDocument]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) class CharacterTextSplitter(TextSplitter): """Splitting text that looks at characters.""" DEFAULT_SEPARATOR: str = "\n\n" def __init__( self, separator: str = DEFAULT_SEPARATOR, is_separator_regex: bool = False, **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) self._separator = separator self._is_separator_regex = is_separator_regex def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" # First we naively split the large input into a bunch of smaller ones. 
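        # A literal separator such as "." is escaped to r"\." before being
        # handed to re.split; with is_separator_regex=True the pattern
        # (e.g. r"\d+\.") is used as-is.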
separator = ( self._separator if self._is_separator_regex else re.escape(self._separator) ) splits = _split_text_with_regex(text, separator, self._keep_separator) _separator = "" if self._keep_separator else self._separator return self._merge_splits(splits, _separator) class LineType(TypedDict): """Line type as typed dict.""" metadata: dict[str, str] content: str class HeaderType(TypedDict): """Header type as typed dict.""" level: int name: str data: str class MarkdownHeaderTextSplitter: """Splitting markdown files based on specified headers.""" def __init__( self, headers_to_split_on: list[Tuple[str, str]], return_each_line: bool = False, strip_headers: bool = True, ): """Create a new MarkdownHeaderTextSplitter. Args: headers_to_split_on: Headers we want to track return_each_line: Return each line w/ associated headers strip_headers: Strip split headers from the content of the chunk """ # Output line-by-line or aggregated into chunks w/ common headers self.return_each_line = return_each_line # Given the headers we want to split on, # (e.g., "#, ##, etc") order by length self.headers_to_split_on = sorted( headers_to_split_on, key=lambda split: len(split[0]), reverse=True ) # Strip headers split headers from the content of the chunk self.strip_headers = strip_headers def aggregate_lines_to_chunks( self, lines: list[LineType] ) -> list[SplitterDocument]: """Combine lines with common metadata into chunks Args: lines: Line of text / associated header metadata """ aggregated_chunks: list[LineType] = [] for line in lines: if ( aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"] ): # If the last line in the aggregated list # has the same metadata as the current line, # append the current content to the last lines's content aggregated_chunks[-1]["content"] += " \n" + line["content"] elif ( aggregated_chunks and aggregated_chunks[-1]["metadata"] != line["metadata"] # may be issues if other metadata is present and len(aggregated_chunks[-1]["metadata"]) < len(line["metadata"]) and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#" and not self.strip_headers ): # If the last line in the aggregated list # has different metadata as the current line, # and has shallower header level than the current line, # and the last line is a header, # and we are not stripping headers, # append the current content to the last line's content aggregated_chunks[-1]["content"] += " \n" + line["content"] # and update the last line's metadata aggregated_chunks[-1]["metadata"] = line["metadata"] else: # Otherwise, append the current line to the aggregated list aggregated_chunks.append(line) return [ SplitterDocument( page_content=chunk["content"], metadata=chunk["metadata"] ) for chunk in aggregated_chunks ] def split_text(self, text: str) -> list[SplitterDocument]: """Split markdown file Args: text: Markdown file""" # Split the input text by newline character ("\n"). 
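        # Illustrative example: with headers_to_split_on=[("#", "Header 1"),
        # ("##", "Header 2")] and the default strip_headers=True, the input
        # "# Foo\n\n## Bar\n\nHi this is Jim" yields a single document whose
        # page_content is "Hi this is Jim" and whose metadata is
        # {"Header 1": "Foo", "Header 2": "Bar"}.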
lines = text.split("\n") # Final output lines_with_metadata: list[LineType] = [] # Content and metadata of the chunk currently being processed current_content: list[str] = [] current_metadata: dict[str, str] = {} # Keep track of the nested header structure # header_stack: list[dict[str, int | str]] = [] header_stack: list[HeaderType] = [] initial_metadata: dict[str, str] = {} in_code_block = False opening_fence = "" for line in lines: stripped_line = line.strip() if not in_code_block: # Exclude inline code spans if ( stripped_line.startswith("```") and stripped_line.count("```") == 1 ): in_code_block = True opening_fence = "```" elif stripped_line.startswith("~~~"): in_code_block = True opening_fence = "~~~" else: if stripped_line.startswith(opening_fence): in_code_block = False opening_fence = "" if in_code_block: current_content.append(stripped_line) continue # Check each line against each of the header types (e.g., #, ##) for sep, name in self.headers_to_split_on: # Check if line starts with a header that we intend to split on if stripped_line.startswith(sep) and ( # Header with no text OR header is followed by space # Both are valid conditions that sep is being used a header len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " ): # Ensure we are tracking the header as metadata if name is not None: # Get the current header level current_header_level = sep.count("#") # Pop out headers of lower or same level from the stack while ( header_stack and header_stack[-1]["level"] >= current_header_level ): # We have encountered a new header # at the same or higher level popped_header = header_stack.pop() # Clear the metadata for the # popped header in initial_metadata if popped_header["name"] in initial_metadata: initial_metadata.pop(popped_header["name"]) # Push the current header to the stack header: HeaderType = { "level": current_header_level, "name": name, "data": stripped_line[len(sep) :].strip(), } header_stack.append(header) # Update initial_metadata with the current header initial_metadata[name] = header["data"] # Add the previous line to the lines_with_metadata # only if current_content is not empty if current_content: lines_with_metadata.append( { "content": "\n".join(current_content), "metadata": current_metadata.copy(), } ) current_content.clear() if not self.strip_headers: current_content.append(stripped_line) break else: if stripped_line: current_content.append(stripped_line) elif current_content: lines_with_metadata.append( { "content": "\n".join(current_content), "metadata": current_metadata.copy(), } ) current_content.clear() current_metadata = initial_metadata.copy() if current_content: lines_with_metadata.append( { "content": "\n".join(current_content), "metadata": current_metadata, } ) # lines_with_metadata has each line with associated header metadata # aggregate these into chunks based on common metadata if not self.return_each_line: return self.aggregate_lines_to_chunks(lines_with_metadata) else: return [ SplitterDocument( page_content=chunk["content"], metadata=chunk["metadata"] ) for chunk in lines_with_metadata ] class ElementType(TypedDict): """Element type as typed dict.""" url: str xpath: str content: str metadata: dict[str, str] class HTMLHeaderTextSplitter: """ Splitting HTML files based on specified headers. Requires lxml package. """ def __init__( self, headers_to_split_on: list[Tuple[str, str]], return_each_element: bool = False, ): """Create a new HTMLHeaderTextSplitter. 
Args: headers_to_split_on: list of tuples of headers we want to track mapped to (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4, h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2)]. return_each_element: Return each element w/ associated headers. """ # Output element-by-element or aggregated into chunks w/ common headers self.return_each_element = return_each_element self.headers_to_split_on = sorted(headers_to_split_on) def aggregate_elements_to_chunks( self, elements: list[ElementType] ) -> list[SplitterDocument]: """Combine elements with common metadata into chunks Args: elements: HTML element content with associated identifying info and metadata """ aggregated_chunks: list[ElementType] = [] for element in elements: if ( aggregated_chunks and aggregated_chunks[-1]["metadata"] == element["metadata"] ): # If the last element in the aggregated list # has the same metadata as the current element, # append the current content to the last element's content aggregated_chunks[-1]["content"] += " \n" + element["content"] else: # Otherwise, append the current element to the aggregated list aggregated_chunks.append(element) return [ SplitterDocument( page_content=chunk["content"], metadata=chunk["metadata"] ) for chunk in aggregated_chunks ] def split_text_from_url(self, url: str) -> list[SplitterDocument]: """Split HTML from web URL Args: url: web URL """ r = requests.get(url) return self.split_text_from_file(BytesIO(r.content)) def split_text(self, text: str) -> list[SplitterDocument]: """Split HTML text string Args: text: HTML text """ return self.split_text_from_file(StringIO(text)) def split_text_from_file(self, file: Any) -> list[SplitterDocument]: """Split HTML file Args: file: HTML file """ try: from lxml import etree except ImportError as e: raise ImportError( "Unable to import lxml, please install with `pip install lxml`." ) from e # use lxml library to parse html document and return xml ElementTree # Explicitly encoding in utf-8 allows non-English # html files to be processed without garbled characters parser = etree.HTMLParser(encoding="utf-8") tree = etree.parse(file, parser) # document transformation for "structure-aware" chunking is handled with xsl. # see comments in html_chunks_with_headers.xslt for more detailed information. xslt_path = ( pathlib.Path(__file__).parent / "document_transformers/xsl/html_chunks_with_headers.xslt" ) xslt_tree = etree.parse(xslt_path) transform = etree.XSLT(xslt_tree) result = transform(tree) result_dom = etree.fromstring(str(result)) # create filter and mapping for header metadata header_filter = [header[0] for header in self.headers_to_split_on] header_mapping = dict(self.headers_to_split_on) # map xhtml namespace prefix ns_map = {"h": "http://www.w3.org/1999/xhtml"} # build list of elements from DOM elements = [] for element in result_dom.findall("*//*", ns_map): if element.findall("*[@class='headers']") or element.findall( "*[@class='chunk']" ): elements.append( ElementType( url=file, xpath="".join( [ node.text for node in element.findall( "*[@class='xpath']", ns_map ) ] ), content="".join( [ node.text for node in element.findall( "*[@class='chunk']", ns_map ) ] ), metadata={ # Add text of specified headers to metadata using header # mapping. 
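                            # e.g. an ancestor <h2>Setup</h2> collected by the
                            # XSLT pass becomes {"Header 2": "Setup"} when
                            # ("h2", "Header 2") is in headers_to_split_on
                            # (the header text "Setup" is illustrative).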
header_mapping[node.tag]: node.text for node in filter( lambda x: x.tag in header_filter, element.findall( "*[@class='headers']/*", ns_map ), ) }, ) ) if not self.return_each_element: return self.aggregate_elements_to_chunks(elements) else: return [ SplitterDocument( page_content=chunk["content"], metadata=chunk["metadata"] ) for chunk in elements ] # should be in newer Python versions (3.11+) # @dataclass(frozen=True, kw_only=True, slots=True) @dataclass(frozen=True) class Tokenizer: """Tokenizer data class.""" chunk_overlap: int """Overlap in tokens between chunks""" tokens_per_chunk: int """Maximum number of tokens per chunk""" decode: Callable[[list[int]], str] """ Function to decode a list of token ids to a string""" encode: Callable[[str], list[int]] """ Function to encode a string to a list of token ids""" def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]: """Split incoming text and return chunks using tokenizer.""" splits: list[str] = [] input_ids = tokenizer.encode(text) start_idx = 0 cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] while start_idx < len(input_ids): splits.append(tokenizer.decode(chunk_ids)) if cur_idx == len(input_ids): break start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] return splits class TokenTextSplitter(TextSplitter): """Splitting text to tokens using model tokenizer.""" def __init__( self, encoding_name: str = "gpt2", model: Optional[str] = None, allowed_special: Literal["all"] | AbstractSet[str] = set(), disallowed_special: Literal["all"] | Collection[str] = "all", **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) try: import tiktoken except ImportError: raise ImportError( "Could not import tiktoken python package. " "This is needed in order to for TokenTextSplitter. " "Please install it with `pip install tiktoken`." ) if model is not None: enc = tiktoken.encoding_for_model(model) else: enc = tiktoken.get_encoding(encoding_name) self._tokenizer = enc self._allowed_special = allowed_special self._disallowed_special = disallowed_special def split_text(self, text: str) -> list[str]: def _encode(_text: str) -> list[int]: return self._tokenizer.encode( _text, allowed_special=self._allowed_special, disallowed_special=self._disallowed_special, ) tokenizer = Tokenizer( chunk_overlap=self._chunk_overlap, tokens_per_chunk=self._chunk_size, decode=self._tokenizer.decode, encode=_encode, ) return split_text_on_tokens(text=text, tokenizer=tokenizer) class SentenceTransformersTokenTextSplitter(TextSplitter): """Splitting text to tokens using sentence model tokenizer.""" def __init__( self, chunk_overlap: int = 50, model: str = "sentence-transformers/all-mpnet-base-v2", tokens_per_chunk: Optional[int] = None, **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs, chunk_overlap=chunk_overlap) try: from sentence_transformers import SentenceTransformer except ImportError: raise ImportError( "Could not import sentence_transformer python package. " "This is needed in order to for SentenceTransformersTokenTextSplitter. " "Please install it with `pip install sentence-transformers`." 
) self.model = model self._model = SentenceTransformer(self.model, trust_remote_code=True) self.tokenizer = self._model.tokenizer self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk) def _initialize_chunk_configuration( self, *, tokens_per_chunk: Optional[int] ) -> None: self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length) if tokens_per_chunk is None: self.tokens_per_chunk = self.maximum_tokens_per_chunk else: self.tokens_per_chunk = tokens_per_chunk if self.tokens_per_chunk > self.maximum_tokens_per_chunk: raise ValueError( f"The token limit of the models '{self.model}'" f" is: {self.maximum_tokens_per_chunk}." f" Argument tokens_per_chunk={self.tokens_per_chunk}" f" > maximum token limit." ) def split_text(self, text: str) -> list[str]: def encode_strip_start_and_stop_token_ids(text: str) -> list[int]: return self._encode(text)[1:-1] tokenizer = Tokenizer( chunk_overlap=self._chunk_overlap, tokens_per_chunk=self.tokens_per_chunk, decode=self.tokenizer.decode, encode=encode_strip_start_and_stop_token_ids, ) return split_text_on_tokens(text=text, tokenizer=tokenizer) def count_tokens(self, *, text: str) -> int: return len(self._encode(text)) _max_length_equal_32_bit_integer: int = 2**32 def _encode(self, text: str) -> list[int]: token_ids_with_start_and_end_token_ids = self.tokenizer.encode( text, max_length=self._max_length_equal_32_bit_integer, truncation="do_not_truncate", ) return token_ids_with_start_and_end_token_ids class Language(str, Enum): """Enum of the programming languages.""" CPP = "cpp" GO = "go" JAVA = "java" KOTLIN = "kotlin" JS = "js" TS = "ts" PHP = "php" PROTO = "proto" PYTHON = "python" RST = "rst" RUBY = "ruby" RUST = "rust" SCALA = "scala" SWIFT = "swift" MARKDOWN = "markdown" LATEX = "latex" HTML = "html" SOL = "sol" CSHARP = "csharp" COBOL = "cobol" C = "c" LUA = "lua" PERL = "perl" class RecursiveCharacterTextSplitter(TextSplitter): """Splitting text by recursively look at characters. Recursively tries to split by different characters to find one that works. """ def __init__( self, separators: Optional[list[str]] = None, keep_separator: bool = True, is_separator_regex: bool = False, chunk_size: int = 4000, chunk_overlap: int = 200, **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__( chunk_size=chunk_size, chunk_overlap=chunk_overlap, keep_separator=keep_separator, **kwargs, ) self._separators = separators or ["\n\n", "\n", " ", ""] self._is_separator_regex = is_separator_regex self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap def _split_text(self, text: str, separators: list[str]) -> list[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use separator = separators[-1] new_separators = [] for i, _s in enumerate(separators): _separator = _s if self._is_separator_regex else re.escape(_s) if _s == "": separator = _s break if re.search(_separator, text): separator = _s new_separators = separators[i + 1 :] break _separator = ( separator if self._is_separator_regex else re.escape(separator) ) splits = _split_text_with_regex(text, _separator, self._keep_separator) # Now go merging things, recursively splitting longer texts. 
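        # Pieces that already fit within chunk_size are buffered in
        # _good_splits and merged (with overlap) via _merge_splits; any piece
        # that is still too long is re-split with the remaining, finer
        # separators in new_separators, or appended as-is when none are left.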
_good_splits = [] _separator = "" if self._keep_separator else separator for s in splits: if self._length_function(s) < self._chunk_size: _good_splits.append(s) else: if _good_splits: merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) _good_splits = [] if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) final_chunks.extend(other_info) if _good_splits: merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) return final_chunks def split_text(self, text: str) -> list[str]: return self._split_text(text, self._separators) @classmethod def from_language( cls, language: Language, **kwargs: Any ) -> RecursiveCharacterTextSplitter: separators = cls.get_separators_for_language(language) return cls(separators=separators, is_separator_regex=True, **kwargs) @staticmethod def get_separators_for_language(language: Language) -> list[str]: if language == Language.CPP: return [ # Split along class definitions "\nclass ", # Split along function definitions "\nvoid ", "\nint ", "\nfloat ", "\ndouble ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.GO: return [ # Split along function definitions "\nfunc ", "\nvar ", "\nconst ", "\ntype ", # Split along control flow statements "\nif ", "\nfor ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.JAVA: return [ # Split along class definitions "\nclass ", # Split along method definitions "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.KOTLIN: return [ # Split along class definitions "\nclass ", # Split along method definitions "\npublic ", "\nprotected ", "\nprivate ", "\ninternal ", "\ncompanion ", "\nfun ", "\nval ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nwhen ", "\ncase ", "\nelse ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.JS: return [ # Split along function definitions "\nfunction ", "\nconst ", "\nlet ", "\nvar ", "\nclass ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.TS: return [ "\nenum ", "\ninterface ", "\nnamespace ", "\ntype ", # Split along class definitions "\nclass ", # Split along function definitions "\nfunction ", "\nconst ", "\nlet ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.PHP: return [ # Split along function definitions "\nfunction ", # Split along class definitions "\nclass ", # Split along control flow statements "\nif ", "\nforeach ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.PROTO: return [ # Split along message definitions "\nmessage ", # Split along service definitions "\nservice ", # Split along enum definitions "\nenum ", # Split along option definitions "\noption ", # Split along import statements "\nimport ", # Split along syntax declarations 
"\nsyntax ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.PYTHON: return [ # First, try to split along class definitions "\nclass ", "\ndef ", "\n\tdef ", # Now split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.RST: return [ # Split along section titles "\n=+\n", "\n-+\n", "\n\\*+\n", # Split along directive markers "\n\n.. *\n\n", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.RUBY: return [ # Split along method definitions "\ndef ", "\nclass ", # Split along control flow statements "\nif ", "\nunless ", "\nwhile ", "\nfor ", "\ndo ", "\nbegin ", "\nrescue ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.RUST: return [ # Split along function definitions "\nfn ", "\nconst ", "\nlet ", # Split along control flow statements "\nif ", "\nwhile ", "\nfor ", "\nloop ", "\nmatch ", "\nconst ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.SCALA: return [ # Split along class definitions "\nclass ", "\nobject ", # Split along method definitions "\ndef ", "\nval ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nmatch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.SWIFT: return [ # Split along function definitions "\nfunc ", # Split along class definitions "\nclass ", "\nstruct ", "\nenum ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.MARKDOWN: return [ # First, try to split along Markdown headings (starting with level 2) "\n#{1,6} ", # Note the alternative syntax for headings (below) is not handled here # Heading level 2 # --------------- # End of code block "```\n", # Horizontal lines "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n", # Note that this splitter doesn't handle horizontal lines defined # by *three or more* of ***, ---, or ___, but this is not handled "\n\n", "\n", " ", "", ] elif language == Language.LATEX: return [ # First, try to split along Latex sections "\n\\\\chapter{", "\n\\\\section{", "\n\\\\subsection{", "\n\\\\subsubsection{", # Now split by environments "\n\\\\begin{enumerate}", "\n\\\\begin{itemize}", "\n\\\\begin{description}", "\n\\\\begin{list}", "\n\\\\begin{quote}", "\n\\\\begin{quotation}", "\n\\\\begin{verse}", "\n\\\\begin{verbatim}", # Now split by math environments "\n\\\begin{align}", "$$", "$", # Now split by the normal type of lines " ", "", ] elif language == Language.HTML: return [ # First, try to split along HTML tags " None: """Initialize the NLTK splitter.""" super().__init__(**kwargs) try: from nltk.tokenize import sent_tokenize self._tokenizer = sent_tokenize except ImportError: raise ImportError( "NLTK is not installed, please install it with `pip install nltk`." ) self._separator = separator self._language = language def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" # First we naively split the large input into a bunch of smaller ones. splits = self._tokenizer(text, language=self._language) return self._merge_splits(splits, self._separator) class SpacyTextSplitter(TextSplitter): """Splitting text using Spacy package. 
Per default, Spacy's `en_core_web_sm` model is used and its default max_length is 1000000 (it is the length of maximum character this model takes which can be increased for large files). For a faster, but potentially less accurate splitting, you can use `pipe='sentencizer'`. """ def __init__( self, separator: str = "\n\n", pipe: str = "en_core_web_sm", max_length: int = 1_000_000, **kwargs: Any, ) -> None: """Initialize the spacy text splitter.""" super().__init__(**kwargs) self._tokenizer = _make_spacy_pipe_for_splitting( pipe, max_length=max_length ) self._separator = separator def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" splits = (s.text for s in self._tokenizer(text).sents) return self._merge_splits(splits, self._separator) class KonlpyTextSplitter(TextSplitter): """Splitting text using Konlpy package. It is good for splitting Korean text. """ def __init__( self, separator: str = "\n\n", **kwargs: Any, ) -> None: """Initialize the Konlpy text splitter.""" super().__init__(**kwargs) self._separator = separator try: from konlpy.tag import Kkma except ImportError: raise ImportError( """ Konlpy is not installed, please install it with `pip install konlpy` """ ) self.kkma = Kkma() def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" splits = self.kkma.sentences(text) return self._merge_splits(splits, self._separator) # For backwards compatibility class PythonCodeTextSplitter(RecursiveCharacterTextSplitter): """Attempts to split the text along Python syntax.""" def __init__(self, **kwargs: Any) -> None: """Initialize a PythonCodeTextSplitter.""" separators = self.get_separators_for_language(Language.PYTHON) super().__init__(separators=separators, **kwargs) class MarkdownTextSplitter(RecursiveCharacterTextSplitter): """Attempts to split the text along Markdown-formatted headings.""" def __init__(self, **kwargs: Any) -> None: """Initialize a MarkdownTextSplitter.""" separators = self.get_separators_for_language(Language.MARKDOWN) super().__init__(separators=separators, **kwargs) class LatexTextSplitter(RecursiveCharacterTextSplitter): """Attempts to split the text along Latex-formatted layout elements.""" def __init__(self, **kwargs: Any) -> None: """Initialize a LatexTextSplitter.""" separators = self.get_separators_for_language(Language.LATEX) super().__init__(separators=separators, **kwargs) class RecursiveJsonSplitter: def __init__( self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None ): super().__init__() self.max_chunk_size = max_chunk_size self.min_chunk_size = ( min_chunk_size if min_chunk_size is not None else max(max_chunk_size - 200, 50) ) @staticmethod def _json_size(data: dict) -> int: """Calculate the size of the serialized JSON object.""" return len(json.dumps(data)) @staticmethod def _set_nested_dict(d: dict, path: list[str], value: Any) -> None: """Set a value in a nested dictionary based on the given path.""" for key in path[:-1]: d = d.setdefault(key, {}) d[path[-1]] = value def _list_to_dict_preprocessing(self, data: Any) -> Any: if isinstance(data, dict): # Process each key-value pair in the dictionary return { k: self._list_to_dict_preprocessing(v) for k, v in data.items() } elif isinstance(data, list): # Convert the list to a dictionary with index-based keys return { str(i): self._list_to_dict_preprocessing(item) for i, item in enumerate(data) } else: # Base case: the item is neither a dict nor a list, so return it unchanged return data def _json_split( self, data: 
        dict[str, Any],
        current_path: Optional[list[str]] = None,
        chunks: Optional[list[dict]] = None,
    ) -> list[dict]:
        """Split json into maximum size dictionaries while preserving
        structure.
        """
        # Use None defaults so repeated calls do not share mutable state.
        current_path = current_path if current_path is not None else []
        chunks = chunks if chunks is not None else [{}]
        if isinstance(data, dict):
            for key, value in data.items():
                new_path = current_path + [key]
                chunk_size = self._json_size(chunks[-1])
                size = self._json_size({key: value})
                remaining = self.max_chunk_size - chunk_size

                if size < remaining:
                    # Add item to current chunk
                    self._set_nested_dict(chunks[-1], new_path, value)
                else:
                    if chunk_size >= self.min_chunk_size:
                        # Chunk is big enough, start a new chunk
                        chunks.append({})

                    # Iterate
                    self._json_split(value, new_path, chunks)
        else:
            # handle single item
            self._set_nested_dict(chunks[-1], current_path, data)
        return chunks

    def split_json(
        self,
        json_data: dict[str, Any],
        convert_lists: bool = False,
    ) -> list[dict]:
        """Splits JSON into a list of JSON chunks."""
        if convert_lists:
            chunks = self._json_split(
                self._list_to_dict_preprocessing(json_data)
            )
        else:
            chunks = self._json_split(json_data)

        # Remove the last chunk if it's empty
        if not chunks[-1]:
            chunks.pop()
        return chunks

    def split_text(
        self, json_data: dict[str, Any], convert_lists: bool = False
    ) -> list[str]:
        """Splits JSON into a list of JSON formatted strings."""
        chunks = self.split_json(
            json_data=json_data, convert_lists=convert_lists
        )

        # Convert to string
        return [json.dumps(chunk) for chunk in chunks]

    def create_documents(
        self,
        texts: list[dict],
        convert_lists: bool = False,
        metadatas: Optional[list[dict]] = None,
    ) -> list[SplitterDocument]:
        """Create documents from a list of json objects (dict)."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(
                json_data=text, convert_lists=convert_lists
            ):
                metadata = copy.deepcopy(_metadatas[i])
                new_doc = SplitterDocument(
                    page_content=chunk, metadata=metadata
                )
                documents.append(new_doc)
        return documents
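
# The sketch below is not part of the library; it is a minimal, illustrative
# usage example of the splitters defined above. The sample text, JSON payload
# and size limits are hypothetical.
def _example_usage() -> None:
    text = (
        "# Title\n\n"
        "First paragraph of prose.\n\n"
        "Second paragraph of prose."
    )

    # Recursive character splitting: chunks respect chunk_size, with overlap.
    char_splitter = RecursiveCharacterTextSplitter(
        chunk_size=40, chunk_overlap=10
    )
    text_chunks = char_splitter.split_text(text)

    # Structure-aware JSON splitting: nested dicts are carved into
    # sub-dictionaries whose serialized size stays under max_chunk_size.
    json_splitter = RecursiveJsonSplitter(max_chunk_size=100)
    json_docs = json_splitter.create_documents(
        texts=[{"a": {"b": "x" * 50, "c": "y" * 50}}]
    )

    for chunk in text_chunks:
        print(repr(chunk))
    for doc in json_docs:
        print(doc.page_content)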