# Source - LangChain
# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851
"""**Text Splitters** are classes for splitting text.

**Class hierarchy:**

.. code-block::

    BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter  # Example: CharacterTextSplitter
                                                 RecursiveCharacterTextSplitter --> <name>TextSplitter

Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.

**Main helpers:**

.. code-block::

    Document, Tokenizer, Language, LineType, HeaderType
"""  # noqa: E501
from __future__ import annotations

import copy
import json
import logging
import pathlib
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from io import BytesIO, StringIO
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    cast,
)

import requests
from pydantic import BaseModel, Field, PrivateAttr
from typing_extensions import NotRequired

logger = logging.getLogger()

TS = TypeVar("TS", bound="TextSplitter")


class BaseSerialized(TypedDict):
    """Base class for serialized objects."""

    lc: int
    id: list[str]
    name: NotRequired[str]
    graph: NotRequired[dict[str, Any]]


class SerializedConstructor(BaseSerialized):
    """Serialized constructor."""

    type: Literal["constructor"]
    kwargs: dict[str, Any]


class SerializedSecret(BaseSerialized):
    """Serialized secret."""

    type: Literal["secret"]


class SerializedNotImplemented(BaseSerialized):
    """Serialized not implemented."""

    type: Literal["not_implemented"]
    repr: Optional[str]


def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
    """Try to determine if a value is different from the default.

    Args:
        value: The value.
        key: The key.
        model: The model.

    Returns:
        Whether the value is different from the default.
    """
    try:
        return model.__fields__[key].get_default() != value
    except Exception:
        return True


class Serializable(BaseModel, ABC):
    """Serializable base class."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Is this class serializable?"""
        return False

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.

        For example, if the class is `langchain.llms.openai.OpenAI`, then the
        namespace is ["langchain", "llms", "openai"]
        """
        return cls.__module__.split(".")

    @property
    def lc_secrets(self) -> dict[str, str]:
        """A map of constructor argument names to secret ids.

        For example,
            {"openai_api_key": "OPENAI_API_KEY"}
        """
        return {}

    @property
    def lc_attributes(self) -> dict:
        """List of attribute names that should be included in the serialized kwargs.

        These attributes must be accepted by the constructor.
        """
        return {}

    @classmethod
    def lc_id(cls) -> list[str]:
        """A unique identifier for this class for serialization purposes.

        The unique identifier is a list of strings that describes the path
        to the object.
        """
        return [*cls.get_lc_namespace(), cls.__name__]

    class Config:
        extra = "ignore"

    def __repr_args__(self) -> Any:
        return [
            (k, v)
            for k, v in super().__repr_args__()
            if (k not in self.__fields__ or try_neq_default(v, k, self))
        ]

    _lc_kwargs: dict[str, Any] = PrivateAttr(default_factory=dict)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._lc_kwargs = kwargs

    def to_json(
        self,
    ) -> SerializedConstructor | SerializedNotImplemented:
        if not self.is_lc_serializable():
            return self.to_json_not_implemented()
        secrets = dict()
        # Get latest values for kwargs if there is an attribute with same name
        lc_kwargs = {
            k: getattr(self, k, v)
            for k, v in self._lc_kwargs.items()
            if not (self.__exclude_fields__ or {}).get(k, False)  # type: ignore
        }
        # Merge the lc_secrets and lc_attributes from every class in the MRO
        for cls in [None, *self.__class__.mro()]:
            # Once we get to Serializable, we're done
            if cls is Serializable:
                break
            if cls:
                deprecated_attributes = [
                    "lc_namespace",
                    "lc_serializable",
                ]
                for attr in deprecated_attributes:
                    if hasattr(cls, attr):
                        raise ValueError(
                            f"Class {self.__class__} has a deprecated "
                            f"attribute {attr}. Please use the corresponding "
                            f"classmethod instead."
                        )
            # Get a reference to self bound to each class in the MRO
            this = cast(
                Serializable, self if cls is None else super(cls, self)
            )
            secrets.update(this.lc_secrets)
            # Now also add the aliases for the secrets
            # This ensures known secret aliases are hidden.
            # Note: this does NOT hide any other extra kwargs
            # that are not present in the fields.
            for key in list(secrets):
                value = secrets[key]
                if key in this.__fields__:
                    secrets[this.__fields__[key].alias] = value  # type: ignore
            lc_kwargs.update(this.lc_attributes)
        # include all secrets, even if not specified in kwargs
        # as these secrets may be passed as an environment variable instead
        for key in secrets.keys():
            secret_value = getattr(self, key, None) or lc_kwargs.get(key)
            if secret_value is not None:
                lc_kwargs.update({key: secret_value})
        return {
            "lc": 1,
            "type": "constructor",
            "id": self.lc_id(),
            "kwargs": (
                lc_kwargs
                if not secrets
                else _replace_secrets(lc_kwargs, secrets)
            ),
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        return to_json_not_implemented(self)


def _replace_secrets(
    root: dict[Any, Any], secrets_map: dict[str, str]
) -> dict[Any, Any]:
    result = root.copy()
    for path, secret_id in secrets_map.items():
        [*parts, last] = path.split(".")
        current = result
        for part in parts:
            if part not in current:
                break
            current[part] = current[part].copy()
            current = current[part]
        if last in current:
            current[last] = {
                "lc": 1,
                "type": "secret",
                "id": [secret_id],
            }
    return result


def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    """Serialize a "not implemented" object.

    Args:
        obj: object to serialize

    Returns:
        SerializedNotImplemented
    """
    _id: list[str] = []
    try:
        if hasattr(obj, "__name__"):
            _id = [*obj.__module__.split("."), obj.__name__]
        elif hasattr(obj, "__class__"):
            _id = [
                *obj.__class__.__module__.split("."),
                obj.__class__.__name__,
            ]
    except Exception:
        pass
    result: SerializedNotImplemented = {
        "lc": 1,
        "type": "not_implemented",
        "id": _id,
        "repr": None,
    }
    try:
        result["repr"] = repr(obj)
    except Exception:
        pass
    return result


class SplitterDocument(Serializable):
    """Class for storing a piece of text and associated metadata."""

    page_content: str
    """String text."""
    metadata: dict = Field(default_factory=dict)
    """Arbitrary metadata about the page content (e.g., source, relationships to other
    documents, etc.).
    """
    type: Literal["Document"] = "Document"

    def __init__(self, page_content: str, **kwargs: Any) -> None:
        """Pass page_content in as positional or named arg."""
        super().__init__(page_content=page_content, **kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "document"]
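

# Illustrative sketch (not part of the original LangChain source): for a
# serializable class such as SplitterDocument, Serializable.to_json() emits a
# "constructor" record roughly shaped like the one below (kwargs abbreviated).
#
#   SplitterDocument(page_content="hello").to_json()
#   -> {"lc": 1, "type": "constructor",
#       "id": ["langchain", "schema", "document", "SplitterDocument"],
#       "kwargs": {"page_content": "hello"}}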


class BaseDocumentTransformer(ABC):
    """Abstract base class for document transformation systems.

    A document transformation system takes a sequence of Documents and returns a
    sequence of transformed Documents.

    Example:
        .. code-block:: python

            class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
                embeddings: Embeddings
                similarity_fn: Callable = cosine_similarity
                similarity_threshold: float = 0.95

                class Config:
                    arbitrary_types_allowed = True

                def transform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    stateful_documents = get_stateful_documents(documents)
                    embedded_documents = _get_embeddings_from_stateful_docs(
                        self.embeddings, stateful_documents
                    )
                    included_idxs = _filter_similar_embeddings(
                        embedded_documents, self.similarity_fn, self.similarity_threshold
                    )
                    return [stateful_documents[i] for i in sorted(included_idxs)]

                async def atransform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    raise NotImplementedError
    """  # noqa: E501

    @abstractmethod
    def transform_documents(
        self, documents: Sequence[SplitterDocument], **kwargs: Any
    ) -> Sequence[SplitterDocument]:
        """Transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """

    async def atransform_documents(
        self, documents: Sequence[SplitterDocument], **kwargs: Any
    ) -> Sequence[SplitterDocument]:
        """Asynchronously transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """
        raise NotImplementedError("This method is not implemented.")
        # return await langchain_core.runnables.config.run_in_executor(
        #     None, self.transform_documents, documents, **kwargs
        # )


def _make_spacy_pipe_for_splitting(
    pipe: str, *, max_length: int = 1_000_000
) -> Any:  # avoid importing spacy
    try:
        import spacy
    except ImportError:
        raise ImportError(
            "Spacy is not installed, please install it with `pip install spacy`."
        )
    if pipe == "sentencizer":
        from spacy.lang.en import English

        sentencizer = English()
        sentencizer.add_pipe("sentencizer")
    else:
        sentencizer = spacy.load(pipe, exclude=["ner", "tagger"])
        sentencizer.max_length = max_length
    return sentencizer


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [
                _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
            ]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
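

# Illustrative sketch (not part of the original LangChain source): how the
# helper above behaves for a simple "." separator. With keep_separator=True
# the delimiter is re-attached to the chunk that follows it.
#
#   _split_text_with_regex("a.b.c", re.escape("."), keep_separator=False)
#   -> ["a", "b", "c"]
#   _split_text_with_regex("a.b.c", re.escape("."), keep_separator=True)
#   -> ["a", ".b", ".c"]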


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: list[str], metadatas: Optional[list[dict]] = None
    ) -> list[SplitterDocument]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    offset = index + previous_chunk_len - self._chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_doc = SplitterDocument(
                    page_content=chunk, metadata=metadata
                )
                documents.append(new_doc)
        return documents

    def split_documents(
        self, documents: Iterable[SplitterDocument]
    ) -> list[SplitterDocument]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(
        self, splits: Iterable[str], separator: str
    ) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)
        docs = []
        current_doc: list[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total
                        + _len
                        + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(
        cls, tokenizer: Any, **kwargs: Any
    ) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model: Optional[str] = None,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )
        if model is not None:
            enc = tiktoken.encoding_for_model(model)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model": model,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}
        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[SplitterDocument], **kwargs: Any
    ) -> Sequence[SplitterDocument]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    DEFAULT_SEPARATOR: str = "\n\n"

    def __init__(
        self,
        separator: str = DEFAULT_SEPARATOR,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._is_separator_regex = is_separator_regex

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        separator = (
            self._separator
            if self._is_separator_regex
            else re.escape(self._separator)
        )
        splits = _split_text_with_regex(text, separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)
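

# Illustrative usage sketch (not part of the original LangChain source): chunk
# a plain-text string on blank lines with CharacterTextSplitter defined above.
# The sample text and sizes are arbitrary assumptions.
def _example_character_split() -> list[str]:
    splitter = CharacterTextSplitter(
        separator="\n\n", chunk_size=100, chunk_overlap=20
    )
    return splitter.split_text(
        "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
    )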


class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(
        self,
        headers_to_split_on: list[Tuple[str, str]],
        return_each_line: bool = False,
        strip_headers: bool = True,
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
            strip_headers: Strip split headers from the content of the chunk
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on,
        # (e.g., "#, ##, etc") order by length
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )
        # Strip split headers from the content of the chunk
        self.strip_headers = strip_headers

    def aggregate_lines_to_chunks(
        self, lines: list[LineType]
    ) -> list[SplitterDocument]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: list[LineType] = []
        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            elif (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] != line["metadata"]
                # may be issues if other metadata is present
                and len(aggregated_chunks[-1]["metadata"])
                < len(line["metadata"])
                and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#"
                and not self.strip_headers
            ):
                # If the last line in the aggregated list
                # has different metadata than the current line,
                # and has a shallower header level than the current line,
                # and the last line is a header,
                # and we are not stripping headers,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
                # and update the last line's metadata
                aggregated_chunks[-1]["metadata"] = line["metadata"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)
        return [
            SplitterDocument(
                page_content=chunk["content"], metadata=chunk["metadata"]
            )
            for chunk in aggregated_chunks
        ]

    def split_text(self, text: str) -> list[SplitterDocument]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        # header_stack: list[dict[str, int | str]] = []
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        in_code_block = False
        opening_fence = ""
        for line in lines:
            stripped_line = line.strip()
            if not in_code_block:
                # Exclude inline code spans
                if (
                    stripped_line.startswith("```")
                    and stripped_line.count("```") == 1
                ):
                    in_code_block = True
                    opening_fence = "```"
                elif stripped_line.startswith("~~~"):
                    in_code_block = True
                    opening_fence = "~~~"
            else:
                if stripped_line.startswith(opening_fence):
                    in_code_block = False
                    opening_fence = ""

            if in_code_block:
                current_content.append(stripped_line)
                continue

            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header is followed by space
                    # Both are valid conditions that sep is being used as a header
                    len(stripped_line) == len(sep)
                    or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")
                        # Pop out headers of lower or same level from the stack
                        while (
                            header_stack
                            and header_stack[-1]["level"]
                            >= current_header_level
                        ):
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])
                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep) :].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    if not self.strip_headers:
                        current_content.append(stripped_line)

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {
                    "content": "\n".join(current_content),
                    "metadata": current_metadata,
                }
            )

        # lines_with_metadata has each line with associated header metadata
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                SplitterDocument(
                    page_content=chunk["content"], metadata=chunk["metadata"]
                )
                for chunk in lines_with_metadata
            ]
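

# Illustrative usage sketch (not part of the original LangChain source): split
# a small markdown document on "#" and "##" headers; the header text ends up
# in each chunk's metadata under the supplied keys. The sample document and
# key names are assumptions.
def _example_markdown_header_split() -> list[SplitterDocument]:
    md = "# Title\n\nIntro text.\n\n## Section A\n\nBody of section A."
    splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
    )
    return splitter.split_text(md)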


class ElementType(TypedDict):
    """Element type as typed dict."""

    url: str
    xpath: str
    content: str
    metadata: dict[str, str]


class HTMLHeaderTextSplitter:
    """
    Splitting HTML files based on specified headers.
    Requires lxml package.
    """

    def __init__(
        self,
        headers_to_split_on: list[Tuple[str, str]],
        return_each_element: bool = False,
    ):
        """Create a new HTMLHeaderTextSplitter.

        Args:
            headers_to_split_on: list of tuples of headers we want to track mapped to
                (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
                h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2")].
            return_each_element: Return each element w/ associated headers.
        """
        # Output element-by-element or aggregated into chunks w/ common headers
        self.return_each_element = return_each_element
        self.headers_to_split_on = sorted(headers_to_split_on)

    def aggregate_elements_to_chunks(
        self, elements: list[ElementType]
    ) -> list[SplitterDocument]:
        """Combine elements with common metadata into chunks.

        Args:
            elements: HTML element content with associated identifying info and metadata
        """
        aggregated_chunks: list[ElementType] = []
        for element in elements:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == element["metadata"]
            ):
                # If the last element in the aggregated list
                # has the same metadata as the current element,
                # append the current content to the last element's content
                aggregated_chunks[-1]["content"] += " \n" + element["content"]
            else:
                # Otherwise, append the current element to the aggregated list
                aggregated_chunks.append(element)
        return [
            SplitterDocument(
                page_content=chunk["content"], metadata=chunk["metadata"]
            )
            for chunk in aggregated_chunks
        ]

    def split_text_from_url(self, url: str) -> list[SplitterDocument]:
        """Split HTML from web URL.

        Args:
            url: web URL
        """
        r = requests.get(url)
        return self.split_text_from_file(BytesIO(r.content))

    def split_text(self, text: str) -> list[SplitterDocument]:
        """Split HTML text string.

        Args:
            text: HTML text
        """
        return self.split_text_from_file(StringIO(text))

    def split_text_from_file(self, file: Any) -> list[SplitterDocument]:
        """Split HTML file.

        Args:
            file: HTML file
        """
        try:
            from lxml import etree
        except ImportError as e:
            raise ImportError(
                "Unable to import lxml, please install with `pip install lxml`."
            ) from e
        # use lxml library to parse html document and return xml ElementTree
        # Explicitly encoding in utf-8 allows non-English
        # html files to be processed without garbled characters
        parser = etree.HTMLParser(encoding="utf-8")
        tree = etree.parse(file, parser)

        # document transformation for "structure-aware" chunking is handled with xsl.
        # see comments in html_chunks_with_headers.xslt for more detailed information.
        xslt_path = (
            pathlib.Path(__file__).parent
            / "document_transformers/xsl/html_chunks_with_headers.xslt"
        )
        xslt_tree = etree.parse(xslt_path)
        transform = etree.XSLT(xslt_tree)
        result = transform(tree)
        result_dom = etree.fromstring(str(result))

        # create filter and mapping for header metadata
        header_filter = [header[0] for header in self.headers_to_split_on]
        header_mapping = dict(self.headers_to_split_on)

        # map xhtml namespace prefix
        ns_map = {"h": "http://www.w3.org/1999/xhtml"}

        # build list of elements from DOM
        elements = []
        for element in result_dom.findall("*//*", ns_map):
            if element.findall("*[@class='headers']") or element.findall(
                "*[@class='chunk']"
            ):
                elements.append(
                    ElementType(
                        url=file,
                        xpath="".join(
                            [
                                node.text
                                for node in element.findall(
                                    "*[@class='xpath']", ns_map
                                )
                            ]
                        ),
                        content="".join(
                            [
                                node.text
                                for node in element.findall(
                                    "*[@class='chunk']", ns_map
                                )
                            ]
                        ),
                        metadata={
                            # Add text of specified headers to metadata using header
                            # mapping.
                            header_mapping[node.tag]: node.text
                            for node in filter(
                                lambda x: x.tag in header_filter,
                                element.findall(
                                    "*[@class='headers']/*", ns_map
                                ),
                            )
                        },
                    )
                )

        if not self.return_each_element:
            return self.aggregate_elements_to_chunks(elements)
        else:
            return [
                SplitterDocument(
                    page_content=chunk["content"], metadata=chunk["metadata"]
                )
                for chunk in elements
            ]
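

# Illustrative usage sketch (not part of the original LangChain source): split
# an HTML string on <h1>/<h2> headers. Requires the optional lxml dependency
# and the html_chunks_with_headers.xslt stylesheet referenced above; the
# sample markup is an assumption.
def _example_html_header_split() -> list[SplitterDocument]:
    html = (
        "<html><body><h1>Title</h1><p>Intro.</p>"
        "<h2>Part</h2><p>Body.</p></body></html>"
    )
    splitter = HTMLHeaderTextSplitter(
        headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
    )
    return splitter.split_text(html)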


# should be in newer Python versions (3.11+)
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk"""
    decode: Callable[[list[int]], str]
    """Function to decode a list of token ids to a string"""
    encode: Callable[[str], list[int]]
    """Function to encode a string to a list of token ids"""


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        if cur_idx == len(input_ids):
            break
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
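

# Illustrative sketch (not part of the original LangChain source): drive
# split_text_on_tokens with a toy whitespace "tokenizer" instead of a real
# model tokenizer, to show the sliding window formed by tokens_per_chunk and
# chunk_overlap. The sizes below are arbitrary assumptions.
def _example_split_on_tokens(text: str) -> list[str]:
    vocab: list[str] = text.split()
    toy = Tokenizer(
        chunk_overlap=1,
        tokens_per_chunk=4,
        decode=lambda ids: " ".join(vocab[i] for i in ids),
        encode=lambda s: list(range(len(s.split()))),
    )
    return split_text_on_tokens(text=text, tokenizer=toy)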


class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model: Optional[str] = None,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        if model is not None:
            enc = tiktoken.encoding_for_model(model)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )
        return split_text_on_tokens(text=text, tokenizer=tokenizer)
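

# Illustrative usage sketch (not part of the original LangChain source):
# token-based splitting with TokenTextSplitter. Requires the optional tiktoken
# dependency; the encoding name and sizes below are assumptions.
def _example_token_split(text: str) -> list[str]:
    splitter = TokenTextSplitter(
        encoding_name="gpt2", chunk_size=256, chunk_overlap=32
    )
    return splitter.split_text(text)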


class SentenceTransformersTokenTextSplitter(TextSplitter):
    """Splitting text to tokens using sentence model tokenizer."""

    def __init__(
        self,
        chunk_overlap: int = 50,
        model: str = "sentence-transformers/all-mpnet-base-v2",
        tokens_per_chunk: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs, chunk_overlap=chunk_overlap)
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "This is needed for SentenceTransformersTokenTextSplitter. "
                "Please install it with `pip install sentence-transformers`."
            )
        self.model = model
        self._model = SentenceTransformer(self.model, trust_remote_code=True)
        self.tokenizer = self._model.tokenizer
        self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)

    def _initialize_chunk_configuration(
        self, *, tokens_per_chunk: Optional[int]
    ) -> None:
        self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
        if tokens_per_chunk is None:
            self.tokens_per_chunk = self.maximum_tokens_per_chunk
        else:
            self.tokens_per_chunk = tokens_per_chunk
        if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
            raise ValueError(
                f"The token limit of the model '{self.model}'"
                f" is: {self.maximum_tokens_per_chunk}."
                f" Argument tokens_per_chunk={self.tokens_per_chunk}"
                f" > maximum token limit."
            )

    def split_text(self, text: str) -> list[str]:
        def encode_strip_start_and_stop_token_ids(text: str) -> list[int]:
            return self._encode(text)[1:-1]

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self.tokens_per_chunk,
            decode=self.tokenizer.decode,
            encode=encode_strip_start_and_stop_token_ids,
        )
        return split_text_on_tokens(text=text, tokenizer=tokenizer)

    def count_tokens(self, *, text: str) -> int:
        return len(self._encode(text))

    _max_length_equal_32_bit_integer: int = 2**32

    def _encode(self, text: str) -> list[int]:
        token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
            text,
            max_length=self._max_length_equal_32_bit_integer,
            truncation="do_not_truncate",
        )
        return token_ids_with_start_and_end_token_ids
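

# Illustrative usage sketch (not part of the original LangChain source): count
# tokens and split with the sentence-transformers tokenizer. Requires the
# optional sentence-transformers dependency; the default model is downloaded
# on first use, and the overlap below is an arbitrary assumption.
def _example_sentence_transformers_split(text: str) -> list[str]:
    splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=20)
    return splitter.split_text(text)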


class Language(str, Enum):
    """Enum of the programming languages."""

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    KOTLIN = "kotlin"
    JS = "js"
    TS = "ts"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"
    CSHARP = "csharp"
    COBOL = "cobol"
    C = "c"
    LUA = "lua"
    PERL = "perl"


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            keep_separator=keep_separator,
            **kwargs,
        )
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = (
            separator if self._is_separator_regex else re.escape(separator)
        )
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)

    @classmethod
    def from_language(
        cls, language: Language, **kwargs: Any
    ) -> RecursiveCharacterTextSplitter:
        separators = cls.get_separators_for_language(language)
        return cls(separators=separators, is_separator_regex=True, **kwargs)

    @staticmethod
    def get_separators_for_language(language: Language) -> list[str]:
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.KOTLIN:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\ninternal ",
                "\ncompanion ",
                "\nfun ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nwhen ",
                "\ncase ",
                "\nelse ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.TS:
            return [
                "\nenum ",
                "\ninterface ",
                "\nnamespace ",
                "\ntype ",
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message definitions
                "\nmessage ",
                # Split along service definitions
                "\nservice ",
                # Split along enum definitions
                "\nenum ",
                # Split along option definitions
                "\noption ",
                # Split along import statements
                "\nimport ",
                # Split along syntax declarations
                "\nsyntax ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class definitions
                "\nclass ",
                "\ndef ",
                "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles
                "\n=+\n",
                "\n-+\n",
                "\n\\*+\n",
                # Split along directive markers
                "\n\n.. *\n\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ",
                "\nconst ",
                "\nlet ",
                # Split along control flow statements
                "\nif ",
                "\nwhile ",
                "\nfor ",
                "\nloop ",
                "\nmatch ",
                "\nconst ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ",
                "\nobject ",
                # Split along method definitions
                "\ndef ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings
                "\n#{1,6} ",
                # Note the alternative (setext) syntax for headings is not handled here:
                # Heading level 2
                # ---------------
                # End of code block
                "```\n",
                # Horizontal lines made of *three or more* of ***, ---, or ___
                "\n\\*\\*\\*+\n",
                "\n---+\n",
                "\n___+\n",
                # Note that other thematic-break styles (e.g., spaced variants)
                # are not handled by the patterns above
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.LATEX:
            return [
                # First, try to split along Latex sections
                "\n\\\\chapter{",
                "\n\\\\section{",
                "\n\\\\subsection{",
                "\n\\\\subsubsection{",
                # Now split by environments
                "\n\\\\begin{enumerate}",
                "\n\\\\begin{itemize}",
                "\n\\\\begin{description}",
                "\n\\\\begin{list}",
                "\n\\\\begin{quote}",
                "\n\\\\begin{quotation}",
                "\n\\\\begin{verse}",
                "\n\\\\begin{verbatim}",
                # Now split by math environments
                "\n\\\\begin{align}",
                "$$",
                "$",
                # Now split by the normal type of lines
                " ",
                "",
            ]
  1420. elif language == Language.HTML:
  1421. return [
  1422. # First, try to split along HTML tags
  1423. "<body",
  1424. "<div",
  1425. "<p",
  1426. "<br",
  1427. "<li",
  1428. "<h1",
  1429. "<h2",
  1430. "<h3",
  1431. "<h4",
  1432. "<h5",
  1433. "<h6",
  1434. "<span",
  1435. "<table",
  1436. "<tr",
  1437. "<td",
  1438. "<th",
  1439. "<ul",
  1440. "<ol",
  1441. "<header",
  1442. "<footer",
  1443. "<nav",
  1444. # Head
  1445. "<head",
  1446. "<style",
  1447. "<script",
  1448. "<meta",
  1449. "<title",
  1450. "",
  1451. ]
  1452. elif language == Language.CSHARP:
  1453. return [
  1454. "\ninterface ",
  1455. "\nenum ",
  1456. "\nimplements ",
  1457. "\ndelegate ",
  1458. "\nevent ",
  1459. # Split along class definitions
  1460. "\nclass ",
  1461. "\nabstract ",
  1462. # Split along method definitions
  1463. "\npublic ",
  1464. "\nprotected ",
  1465. "\nprivate ",
  1466. "\nstatic ",
  1467. "\nreturn ",
  1468. # Split along control flow statements
  1469. "\nif ",
  1470. "\ncontinue ",
  1471. "\nfor ",
  1472. "\nforeach ",
  1473. "\nwhile ",
  1474. "\nswitch ",
  1475. "\nbreak ",
  1476. "\ncase ",
  1477. "\nelse ",
  1478. # Split by exceptions
  1479. "\ntry ",
  1480. "\nthrow ",
  1481. "\nfinally ",
  1482. "\ncatch ",
  1483. # Split by the normal type of lines
  1484. "\n\n",
  1485. "\n",
  1486. " ",
  1487. "",
  1488. ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ",
                "\nusing ",
                # Split along contract definitions
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                # Split along method definitions
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.COBOL:
            return [
                # Split along divisions
                "\nIDENTIFICATION DIVISION.",
                "\nENVIRONMENT DIVISION.",
                "\nDATA DIVISION.",
                "\nPROCEDURE DIVISION.",
                # Split along sections within DATA DIVISION
                "\nWORKING-STORAGE SECTION.",
                "\nLINKAGE SECTION.",
                "\nFILE SECTION.",
                # Split along sections within ENVIRONMENT DIVISION
                "\nINPUT-OUTPUT SECTION.",
                # Split along paragraphs and common statements
                "\nOPEN ",
                "\nCLOSE ",
                "\nREAD ",
                "\nWRITE ",
                "\nIF ",
                "\nELSE ",
                "\nMOVE ",
                "\nPERFORM ",
                "\nUNTIL ",
                "\nVARYING ",
                "\nACCEPT ",
                "\nDISPLAY ",
                "\nSTOP RUN.",
                # Split by the normal type of lines
                "\n",
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
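

# Illustrative sketch (not part of the original module): the separator lists
# above are ordered from coarse structural markers down to "\n", " ", and "",
# and are intended to be passed to RecursiveCharacterTextSplitter. This sketch
# assumes get_separators_for_language can be called on the class (the
# subclasses below call it via `self`) and that the base splitter accepts
# chunk_size/chunk_overlap; the values are arbitrary examples.
#
#     seps = RecursiveCharacterTextSplitter.get_separators_for_language(Language.HTML)
#     html_splitter = RecursiveCharacterTextSplitter(
#         separators=seps, chunk_size=500, chunk_overlap=50
#     )
#     chunks = html_splitter.split_text(html_document)  # html_document: any str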


class NLTKTextSplitter(TextSplitter):
    """Splitting text using NLTK package."""

    def __init__(
        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
    ) -> None:
        """Initialize the NLTK splitter."""
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize

            self._tokenizer = sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            )
        self._separator = separator
        self._language = language

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = self._tokenizer(text, language=self._language)
        return self._merge_splits(splits, self._separator)
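

# Illustrative usage sketch (not part of the original module): NLTKTextSplitter
# breaks text into sentences with nltk.tokenize.sent_tokenize and then merges
# them into chunks bounded by the splitter's configured chunk size. Assumes
# nltk and its sentence tokenizer data (e.g. `punkt`) are installed; the
# parameter values are arbitrary.
#
#     splitter = NLTKTextSplitter(chunk_size=1000, chunk_overlap=100)
#     chunks = splitter.split_text(long_english_text)  # long_english_text: any str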


class SpacyTextSplitter(TextSplitter):
    """Splitting text using Spacy package.

    By default, Spacy's `en_core_web_sm` pipeline is used, with a `max_length`
    of 1,000,000 characters (the maximum document length the pipeline accepts;
    increase it for larger files). For faster but potentially less accurate
    splitting, you can use `pipe="sentencizer"`.
    """

    def __init__(
        self,
        separator: str = "\n\n",
        pipe: str = "en_core_web_sm",
        max_length: int = 1_000_000,
        **kwargs: Any,
    ) -> None:
        """Initialize the spacy text splitter."""
        super().__init__(**kwargs)
        self._tokenizer = _make_spacy_pipe_for_splitting(
            pipe, max_length=max_length
        )
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        splits = (s.text for s in self._tokenizer(text).sents)
        return self._merge_splits(splits, self._separator)
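

# Illustrative usage sketch (not part of the original module): assumes spaCy
# and the en_core_web_sm pipeline are installed; values are arbitrary examples.
#
#     splitter = SpacyTextSplitter(chunk_size=1000, chunk_overlap=100)
#     chunks = splitter.split_text(article_text)  # article_text: any str
#
#     # Faster, rule-based sentence splitting without a statistical model:
#     fast_splitter = SpacyTextSplitter(pipe="sentencizer")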


class KonlpyTextSplitter(TextSplitter):
    """Splitting text using Konlpy package.

    It is good for splitting Korean text.
    """

    def __init__(
        self,
        separator: str = "\n\n",
        **kwargs: Any,
    ) -> None:
        """Initialize the Konlpy text splitter."""
        super().__init__(**kwargs)
        self._separator = separator
        try:
            from konlpy.tag import Kkma
        except ImportError:
            raise ImportError(
                "Konlpy is not installed, please install it with `pip install konlpy`."
            )
        self.kkma = Kkma()

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        splits = self.kkma.sentences(text)
        return self._merge_splits(splits, self._separator)
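

# Illustrative usage sketch (not part of the original module): Kkma performs
# Korean sentence segmentation, so this splitter is intended for Korean input.
# Assumes konlpy (and the Java runtime it requires) is installed; values are
# arbitrary examples.
#
#     splitter = KonlpyTextSplitter(chunk_size=500, chunk_overlap=50)
#     chunks = splitter.split_text(korean_text)  # korean_text: any Korean str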


# For backwards compatibility
class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Python syntax."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a PythonCodeTextSplitter."""
        separators = self.get_separators_for_language(Language.PYTHON)
        super().__init__(separators=separators, **kwargs)


class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Markdown-formatted headings."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a MarkdownTextSplitter."""
        separators = self.get_separators_for_language(Language.MARKDOWN)
        super().__init__(separators=separators, **kwargs)


class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Latex-formatted layout elements."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a LatexTextSplitter."""
        separators = self.get_separators_for_language(Language.LATEX)
        super().__init__(separators=separators, **kwargs)
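

# Illustrative usage sketch (not part of the original module): the three
# convenience classes above only pre-select a separator list; all other
# behaviour comes from RecursiveCharacterTextSplitter. Values are arbitrary.
#
#     md_splitter = MarkdownTextSplitter(chunk_size=500, chunk_overlap=50)
#     chunks = md_splitter.split_text(readme_text)  # readme_text: any str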


class RecursiveJsonSplitter:
    """Splits JSON data into smaller, structure-preserving JSON chunks."""

    def __init__(
        self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
    ):
        super().__init__()
        self.max_chunk_size = max_chunk_size
        self.min_chunk_size = (
            min_chunk_size
            if min_chunk_size is not None
            else max(max_chunk_size - 200, 50)
        )

    @staticmethod
    def _json_size(data: dict) -> int:
        """Calculate the size of the serialized JSON object."""
        return len(json.dumps(data))

    @staticmethod
    def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
        """Set a value in a nested dictionary based on the given path."""
        for key in path[:-1]:
            d = d.setdefault(key, {})
        d[path[-1]] = value

    def _list_to_dict_preprocessing(self, data: Any) -> Any:
        if isinstance(data, dict):
            # Process each key-value pair in the dictionary
            return {
                k: self._list_to_dict_preprocessing(v) for k, v in data.items()
            }
        elif isinstance(data, list):
            # Convert the list to a dictionary with index-based keys
            return {
                str(i): self._list_to_dict_preprocessing(item)
                for i, item in enumerate(data)
            }
        else:
            # Base case: the item is neither a dict nor a list, so return it unchanged
            return data
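
    # Illustrative example (not part of the original module) of what the
    # preprocessing above produces: lists become dicts keyed by stringified
    # indices, so nested data can then be split key by key.
    #
    #     _list_to_dict_preprocessing({"tags": ["a", "b"]})
    #     # -> {"tags": {"0": "a", "1": "b"}}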

    def _json_split(
        self,
        data: Any,
        current_path: Optional[list[str]] = None,
        chunks: Optional[list[dict]] = None,
    ) -> list[dict]:
        """Split json into maximum size dictionaries while preserving structure."""
        # Avoid mutable default arguments: start with fresh state on each
        # top-level call so repeated splits don't accumulate chunks.
        current_path = current_path if current_path is not None else []
        chunks = chunks if chunks is not None else [{}]
        if isinstance(data, dict):
            for key, value in data.items():
                new_path = current_path + [key]
                chunk_size = self._json_size(chunks[-1])
                size = self._json_size({key: value})
                remaining = self.max_chunk_size - chunk_size
                if size < remaining:
                    # Add item to current chunk
                    self._set_nested_dict(chunks[-1], new_path, value)
                else:
                    if chunk_size >= self.min_chunk_size:
                        # Chunk is big enough, start a new chunk
                        chunks.append({})
                    # Iterate
                    self._json_split(value, new_path, chunks)
        else:
            # handle single item
            self._set_nested_dict(chunks[-1], current_path, data)
        return chunks

    def split_json(
        self,
        json_data: dict[str, Any],
        convert_lists: bool = False,
    ) -> list[dict]:
        """Splits JSON into a list of JSON chunks."""
        if convert_lists:
            chunks = self._json_split(
                self._list_to_dict_preprocessing(json_data)
            )
        else:
            chunks = self._json_split(json_data)
        # Remove the last chunk if it's empty
        if not chunks[-1]:
            chunks.pop()
        return chunks

    def split_text(
        self, json_data: dict[str, Any], convert_lists: bool = False
    ) -> list[str]:
        """Splits JSON into a list of JSON formatted strings."""
        chunks = self.split_json(
            json_data=json_data, convert_lists=convert_lists
        )
        # Convert to string
        return [json.dumps(chunk) for chunk in chunks]

    def create_documents(
        self,
        texts: list[dict],
        convert_lists: bool = False,
        metadatas: Optional[list[dict]] = None,
    ) -> list[SplitterDocument]:
        """Create documents from a list of json objects (dict)."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(
                json_data=text, convert_lists=convert_lists
            ):
                metadata = copy.deepcopy(_metadatas[i])
                new_doc = SplitterDocument(
                    page_content=chunk, metadata=metadata
                )
                documents.append(new_doc)
        return documents
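

# Illustrative usage sketch (not part of the original module): chunk a nested
# JSON payload into pieces of at most ~300 serialized characters. The sample
# data and sizes are arbitrary.
#
#     splitter = RecursiveJsonSplitter(max_chunk_size=300)
#     data = {"users": [{"id": 1, "name": "Ada"}, {"id": 2, "name": "Alan"}]}
#     json_chunks = splitter.split_json(data, convert_lists=True)   # list[dict]
#     text_chunks = splitter.split_text(data, convert_lists=True)   # list[str]
#     docs = splitter.create_documents([data], convert_lists=True)  # list[SplitterDocument]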