  1. # Source - LangChain
  2. # URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851
  3. """**Text Splitters** are classes for splitting text.
  4. **Class hierarchy:**
  5. .. code-block::
  6. BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter
  7. RecursiveCharacterTextSplitter --> <name>TextSplitter
8. Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.
  9. **Main helpers:**
  10. .. code-block::
  11. Document, Tokenizer, Language, LineType, HeaderType
  12. """ # noqa: E501
  13. from __future__ import annotations
  14. import copy
  15. import json
  16. import logging
  17. import pathlib
  18. import re
  19. from abc import ABC, abstractmethod
  20. from dataclasses import dataclass
  21. from enum import Enum
  22. from io import BytesIO, StringIO
  23. from typing import (
  24. AbstractSet,
  25. Any,
  26. Callable,
  27. Collection,
  28. Iterable,
  29. Literal,
  30. Optional,
  31. Sequence,
  32. Tuple,
  33. Type,
  34. TypedDict,
  35. TypeVar,
  36. cast,
  37. )
  38. import requests
  39. from pydantic import BaseModel, Field, PrivateAttr
  40. from typing_extensions import NotRequired
41. logger = logging.getLogger(__name__)
  42. TS = TypeVar("TS", bound="TextSplitter")
  43. class BaseSerialized(TypedDict):
  44. """Base class for serialized objects."""
  45. lc: int
  46. id: list[str]
  47. name: NotRequired[str]
  48. graph: NotRequired[dict[str, Any]]
  49. class SerializedConstructor(BaseSerialized):
  50. """Serialized constructor."""
  51. type: Literal["constructor"]
  52. kwargs: dict[str, Any]
  53. class SerializedSecret(BaseSerialized):
  54. """Serialized secret."""
  55. type: Literal["secret"]
  56. class SerializedNotImplemented(BaseSerialized):
  57. """Serialized not implemented."""
  58. type: Literal["not_implemented"]
  59. repr: Optional[str]
  60. def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
  61. """Try to determine if a value is different from the default.
  62. Args:
  63. value: The value.
  64. key: The key.
  65. model: The model.
  66. Returns:
  67. Whether the value is different from the default.
  68. """
  69. try:
  70. return model.__fields__[key].get_default() != value
  71. except Exception:
  72. return True
  73. class Serializable(BaseModel, ABC):
  74. """Serializable base class."""
  75. @classmethod
  76. def is_lc_serializable(cls) -> bool:
  77. """Is this class serializable?"""
  78. return False
  79. @classmethod
  80. def get_lc_namespace(cls) -> list[str]:
  81. """Get the namespace of the langchain object.
  82. For example, if the class is `langchain.llms.openai.OpenAI`, then the
  83. namespace is ["langchain", "llms", "openai"]
  84. """
  85. return cls.__module__.split(".")
  86. @property
  87. def lc_secrets(self) -> dict[str, str]:
  88. """A map of constructor argument names to secret ids.
  89. For example, {"openai_api_key": "OPENAI_API_KEY"}
  90. """
  91. return {}
  92. @property
  93. def lc_attributes(self) -> dict:
  94. """List of attribute names that should be included in the serialized
  95. kwargs.
  96. These attributes must be accepted by the constructor.
  97. """
  98. return {}
  99. @classmethod
  100. def lc_id(cls) -> list[str]:
  101. """A unique identifier for this class for serialization purposes.
  102. The unique identifier is a list of strings that describes the path to
  103. the object.
  104. """
  105. return [*cls.get_lc_namespace(), cls.__name__]
  106. class Config:
  107. extra = "ignore"
  108. def __repr_args__(self) -> Any:
  109. return [
  110. (k, v)
  111. for k, v in super().__repr_args__()
  112. if (k not in self.__fields__ or try_neq_default(v, k, self))
  113. ]
  114. _lc_kwargs: dict[str, Any] = PrivateAttr(default_factory=dict)
  115. def __init__(self, **kwargs: Any) -> None:
  116. super().__init__(**kwargs)
  117. self._lc_kwargs = kwargs
  118. def to_json(
  119. self,
  120. ) -> SerializedConstructor | SerializedNotImplemented:
  121. if not self.is_lc_serializable():
  122. return self.to_json_not_implemented()
  123. secrets = dict()
  124. # Get latest values for kwargs if there is an attribute with same name
  125. lc_kwargs = {
  126. k: getattr(self, k, v)
  127. for k, v in self._lc_kwargs.items()
  128. if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
  129. }
  130. # Merge the lc_secrets and lc_attributes from every class in the MRO
  131. for cls in [None, *self.__class__.mro()]:
  132. # Once we get to Serializable, we're done
  133. if cls is Serializable:
  134. break
  135. if cls:
  136. deprecated_attributes = [
  137. "lc_namespace",
  138. "lc_serializable",
  139. ]
  140. for attr in deprecated_attributes:
  141. if hasattr(cls, attr):
  142. raise ValueError(
  143. f"Class {self.__class__} has a deprecated "
  144. f"attribute {attr}. Please use the corresponding "
  145. f"classmethod instead."
  146. )
  147. # Get a reference to self bound to each class in the MRO
  148. this = cast(
  149. Serializable, self if cls is None else super(cls, self)
  150. )
  151. secrets.update(this.lc_secrets)
  152. # Now also add the aliases for the secrets
  153. # This ensures known secret aliases are hidden.
  154. # Note: this does NOT hide any other extra kwargs
  155. # that are not present in the fields.
  156. for key in list(secrets):
  157. value = secrets[key]
  158. if key in this.__fields__:
  159. secrets[this.__fields__[key].alias] = value # type: ignore
  160. lc_kwargs.update(this.lc_attributes)
  161. # include all secrets, even if not specified in kwargs
  162. # as these secrets may be passed as an environment variable instead
  163. for key in secrets.keys():
  164. secret_value = getattr(self, key, None) or lc_kwargs.get(key)
  165. if secret_value is not None:
  166. lc_kwargs.update({key: secret_value})
  167. return {
  168. "lc": 1,
  169. "type": "constructor",
  170. "id": self.lc_id(),
  171. "kwargs": (
  172. lc_kwargs
  173. if not secrets
  174. else _replace_secrets(lc_kwargs, secrets)
  175. ),
  176. }
  177. def to_json_not_implemented(self) -> SerializedNotImplemented:
  178. return to_json_not_implemented(self)
  179. def _replace_secrets(
  180. root: dict[Any, Any], secrets_map: dict[str, str]
  181. ) -> dict[Any, Any]:
  182. result = root.copy()
  183. for path, secret_id in secrets_map.items():
  184. [*parts, last] = path.split(".")
  185. current = result
  186. for part in parts:
  187. if part not in current:
  188. break
  189. current[part] = current[part].copy()
  190. current = current[part]
  191. if last in current:
  192. current[last] = {
  193. "lc": 1,
  194. "type": "secret",
  195. "id": [secret_id],
  196. }
  197. return result
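# Illustrative sketch (not part of the original source): with
# secrets_map={"openai_api_key": "OPENAI_API_KEY"}, the matching kwarg is
# replaced by a serialized secret marker:
#
#     _replace_secrets(
#         {"openai_api_key": "sk-PLACEHOLDER", "model": "gpt-4"},
#         {"openai_api_key": "OPENAI_API_KEY"},
#     )
#     # -> {"openai_api_key": {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]},
#     #     "model": "gpt-4"}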
  198. def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
  199. """Serialize a "not implemented" object.
  200. Args:
  201. obj: object to serialize
  202. Returns:
  203. SerializedNotImplemented
  204. """
  205. _id: list[str] = []
  206. try:
  207. if hasattr(obj, "__name__"):
  208. _id = [*obj.__module__.split("."), obj.__name__]
  209. elif hasattr(obj, "__class__"):
  210. _id = [
  211. *obj.__class__.__module__.split("."),
  212. obj.__class__.__name__,
  213. ]
  214. except Exception:
  215. pass
  216. result: SerializedNotImplemented = {
  217. "lc": 1,
  218. "type": "not_implemented",
  219. "id": _id,
  220. "repr": None,
  221. }
  222. try:
  223. result["repr"] = repr(obj)
  224. except Exception:
  225. pass
  226. return result
  227. class SplitterDocument(Serializable):
  228. """Class for storing a piece of text and associated metadata."""
  229. page_content: str
  230. """String text."""
  231. metadata: dict = Field(default_factory=dict)
  232. """Arbitrary metadata about the page content (e.g., source, relationships
  233. to other documents, etc.)."""
  234. type: Literal["Document"] = "Document"
  235. def __init__(self, page_content: str, **kwargs: Any) -> None:
  236. """Pass page_content in as positional or named arg."""
  237. super().__init__(page_content=page_content, **kwargs)
  238. @classmethod
  239. def is_lc_serializable(cls) -> bool:
  240. """Return whether this class is serializable."""
  241. return True
  242. @classmethod
  243. def get_lc_namespace(cls) -> list[str]:
  244. """Get the namespace of the langchain object."""
  245. return ["langchain", "schema", "document"]
  246. class BaseDocumentTransformer(ABC):
  247. """Abstract base class for document transformation systems.
  248. A document transformation system takes a sequence of Documents and returns a
  249. sequence of transformed Documents.
  250. Example:
  251. .. code-block:: python
  252. class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
  253. embeddings: Embeddings
  254. similarity_fn: Callable = cosine_similarity
  255. similarity_threshold: float = 0.95
  256. class Config:
  257. arbitrary_types_allowed = True
  258. def transform_documents(
  259. self, documents: Sequence[Document], **kwargs: Any
  260. ) -> Sequence[Document]:
  261. stateful_documents = get_stateful_documents(documents)
  262. embedded_documents = _get_embeddings_from_stateful_docs(
  263. self.embeddings, stateful_documents
  264. )
  265. included_idxs = _filter_similar_embeddings(
  266. embedded_documents, self.similarity_fn, self.similarity_threshold
  267. )
  268. return [stateful_documents[i] for i in sorted(included_idxs)]
  269. async def atransform_documents(
  270. self, documents: Sequence[Document], **kwargs: Any
  271. ) -> Sequence[Document]:
  272. raise NotImplementedError
  273. """ # noqa: E501
  274. @abstractmethod
  275. def transform_documents(
  276. self, documents: Sequence[SplitterDocument], **kwargs: Any
  277. ) -> Sequence[SplitterDocument]:
  278. """Transform a list of documents.
  279. Args:
  280. documents: A sequence of Documents to be transformed.
  281. Returns:
  282. A list of transformed Documents.
  283. """
  284. async def atransform_documents(
  285. self, documents: Sequence[SplitterDocument], **kwargs: Any
  286. ) -> Sequence[SplitterDocument]:
  287. """Asynchronously transform a list of documents.
  288. Args:
  289. documents: A sequence of Documents to be transformed.
  290. Returns:
  291. A list of transformed Documents.
  292. """
  293. raise NotImplementedError("This method is not implemented.")
  294. # return await langchain_core.runnables.config.run_in_executor(
  295. # None, self.transform_documents, documents, **kwargs
  296. # )
  297. def _make_spacy_pipe_for_splitting(
  298. pipe: str, *, max_length: int = 1_000_000
  299. ) -> Any: # avoid importing spacy
  300. try:
  301. import spacy
  302. except ImportError:
  303. raise ImportError(
  304. "Spacy is not installed, run `pip install spacy`."
  305. ) from None
  306. if pipe == "sentencizer":
  307. from spacy.lang.en import English
  308. sentencizer = English()
  309. sentencizer.add_pipe("sentencizer")
  310. else:
  311. sentencizer = spacy.load(pipe, exclude=["ner", "tagger"])
  312. sentencizer.max_length = max_length
  313. return sentencizer
  314. def _split_text_with_regex(
  315. text: str, separator: str, keep_separator: bool
  316. ) -> list[str]:
  317. # Now that we have the separator, split the text
  318. if separator:
  319. if keep_separator:
  320. # The parentheses in the pattern keep the delimiters in the result.
  321. _splits = re.split(f"({separator})", text)
  322. splits = [
  323. _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
  324. ]
  325. if len(_splits) % 2 == 0:
  326. splits += _splits[-1:]
  327. splits = [_splits[0]] + splits
  328. else:
  329. splits = re.split(separator, text)
  330. else:
  331. splits = list(text)
  332. return [s for s in splits if s != ""]
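# Illustrative sketch (not part of the original source): keep_separator
# re-attaches each matched separator to the piece that follows it, so the
# original text can be reassembled from the splits:
#
#     _split_text_with_regex("a.b.c", r"\.", keep_separator=False)  # -> ["a", "b", "c"]
#     _split_text_with_regex("a.b.c", r"\.", keep_separator=True)   # -> ["a", ".b", ".c"]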
  333. class TextSplitter(BaseDocumentTransformer, ABC):
  334. """Interface for splitting text into chunks."""
  335. def __init__(
  336. self,
  337. chunk_size: int = 4000,
  338. chunk_overlap: int = 200,
  339. length_function: Callable[[str], int] = len,
  340. keep_separator: bool = False,
  341. add_start_index: bool = False,
  342. strip_whitespace: bool = True,
  343. ) -> None:
  344. """Create a new TextSplitter.
  345. Args:
  346. chunk_size: Maximum size of chunks to return
  347. chunk_overlap: Overlap in characters between chunks
  348. length_function: Function that measures the length of given chunks
  349. keep_separator: Whether to keep the separator in the chunks
  350. add_start_index: If `True`, includes chunk's start index in
  351. metadata
  352. strip_whitespace: If `True`, strips whitespace from the start and
  353. end of every document
  354. """
  355. if chunk_overlap > chunk_size:
  356. raise ValueError(
  357. f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
  358. f"({chunk_size}), should be smaller."
  359. )
  360. self._chunk_size = chunk_size
  361. self._chunk_overlap = chunk_overlap
  362. self._length_function = length_function
  363. self._keep_separator = keep_separator
  364. self._add_start_index = add_start_index
  365. self._strip_whitespace = strip_whitespace
  366. @abstractmethod
  367. def split_text(self, text: str) -> list[str]:
  368. """Split text into multiple components."""
  369. def create_documents(
  370. self, texts: list[str], metadatas: Optional[list[dict]] = None
  371. ) -> list[SplitterDocument]:
  372. """Create documents from a list of texts."""
  373. _metadatas = metadatas or [{}] * len(texts)
  374. documents = []
  375. for i, text in enumerate(texts):
  376. index = 0
  377. previous_chunk_len = 0
  378. for chunk in self.split_text(text):
  379. metadata = copy.deepcopy(_metadatas[i])
  380. if self._add_start_index:
  381. offset = index + previous_chunk_len - self._chunk_overlap
  382. index = text.find(chunk, max(0, offset))
  383. metadata["start_index"] = index
  384. previous_chunk_len = len(chunk)
  385. new_doc = SplitterDocument(
  386. page_content=chunk, metadata=metadata
  387. )
  388. documents.append(new_doc)
  389. return documents
  390. def split_documents(
  391. self, documents: Iterable[SplitterDocument]
  392. ) -> list[SplitterDocument]:
  393. """Split documents."""
  394. texts, metadatas = [], []
  395. for doc in documents:
  396. texts.append(doc.page_content)
  397. metadatas.append(doc.metadata)
  398. return self.create_documents(texts, metadatas=metadatas)
  399. def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
  400. text = separator.join(docs)
  401. if self._strip_whitespace:
  402. text = text.strip()
  403. if text == "":
  404. return None
  405. else:
  406. return text
  407. def _merge_splits(
  408. self, splits: Iterable[str], separator: str
  409. ) -> list[str]:
  410. # We now want to combine these smaller pieces into medium size
  411. # chunks to send to the LLM.
  412. separator_len = self._length_function(separator)
  413. docs = []
  414. current_doc: list[str] = []
  415. total = 0
  416. for d in splits:
  417. _len = self._length_function(d)
  418. if (
  419. total + _len + (separator_len if len(current_doc) > 0 else 0)
  420. > self._chunk_size
  421. ):
  422. if total > self._chunk_size:
  423. logger.warning(
  424. f"Created a chunk of size {total}, "
  425. f"which is longer than the specified {self._chunk_size}"
  426. )
  427. if len(current_doc) > 0:
  428. doc = self._join_docs(current_doc, separator)
  429. if doc is not None:
  430. docs.append(doc)
  431. # Keep on popping if:
  432. # - we have a larger chunk than in the chunk overlap
  433. # - or if we still have any chunks and the length is long
  434. while total > self._chunk_overlap or (
  435. total
  436. + _len
  437. + (separator_len if len(current_doc) > 0 else 0)
  438. > self._chunk_size
  439. and total > 0
  440. ):
  441. total -= self._length_function(current_doc[0]) + (
  442. separator_len if len(current_doc) > 1 else 0
  443. )
  444. current_doc = current_doc[1:]
  445. current_doc.append(d)
  446. total += _len + (separator_len if len(current_doc) > 1 else 0)
  447. doc = self._join_docs(current_doc, separator)
  448. if doc is not None:
  449. docs.append(doc)
  450. return docs
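# Illustrative trace (not part of the original source): with
# self._chunk_size=10, self._chunk_overlap=4 and self._length_function=len,
#
#     self._merge_splits(["aaaa", "bbbb", "cccc"], " ")
#     # -> ["aaaa bbbb", "bbbb cccc"]
#
# "bbbb" is carried into the second chunk as overlap because, once "aaaa" is
# popped, the running length (4) no longer exceeds the chunk overlap.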
  451. @classmethod
  452. def from_huggingface_tokenizer(
  453. cls, tokenizer: Any, **kwargs: Any
  454. ) -> TextSplitter:
  455. """Text splitter that uses HuggingFace tokenizer to count length."""
  456. try:
  457. from transformers import PreTrainedTokenizerBase
  458. if not isinstance(tokenizer, PreTrainedTokenizerBase):
  459. raise ValueError(
  460. "Tokenizer received was not an instance of PreTrainedTokenizerBase"
  461. )
  462. def _huggingface_tokenizer_length(text: str) -> int:
  463. return len(tokenizer.encode(text))
  464. except ImportError:
  465. raise ValueError(
  466. "Could not import transformers python package. "
  467. "Please install it with `pip install transformers`."
  468. ) from None
  469. return cls(length_function=_huggingface_tokenizer_length, **kwargs)
  470. @classmethod
  471. def from_tiktoken_encoder(
  472. cls: Type[TS],
  473. encoding_name: str = "gpt2",
  474. model: Optional[str] = None,
  475. allowed_special: Literal["all"] | AbstractSet[str] = set(),
  476. disallowed_special: Literal["all"] | Collection[str] = "all",
  477. **kwargs: Any,
  478. ) -> TS:
  479. """Text splitter that uses tiktoken encoder to count length."""
  480. try:
  481. import tiktoken
  482. except ImportError:
  483. raise ImportError("""Could not import tiktoken python package.
484. This is needed to count tokens with the tiktoken encoder.
  485. Please install it with `pip install tiktoken`.""") from None
  486. if model is not None:
  487. enc = tiktoken.encoding_for_model(model)
  488. else:
  489. enc = tiktoken.get_encoding(encoding_name)
  490. def _tiktoken_encoder(text: str) -> int:
  491. return len(
  492. enc.encode(
  493. text,
  494. allowed_special=allowed_special,
  495. disallowed_special=disallowed_special,
  496. )
  497. )
  498. if issubclass(cls, TokenTextSplitter):
  499. extra_kwargs = {
  500. "encoding_name": encoding_name,
  501. "model": model,
  502. "allowed_special": allowed_special,
  503. "disallowed_special": disallowed_special,
  504. }
  505. kwargs = {**kwargs, **extra_kwargs}
  506. return cls(length_function=_tiktoken_encoder, **kwargs)
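# Example usage (sketch, assumes the optional tiktoken package is installed;
# `long_text` is a placeholder string):
#
#     splitter = CharacterTextSplitter.from_tiktoken_encoder(
#         model="gpt-3.5-turbo", chunk_size=512, chunk_overlap=64
#     )
#     chunks = splitter.split_text(long_text)  # chunk lengths measured in tokens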
  507. def transform_documents(
  508. self, documents: Sequence[SplitterDocument], **kwargs: Any
  509. ) -> Sequence[SplitterDocument]:
  510. """Transform sequence of documents by splitting them."""
  511. return self.split_documents(list(documents))
  512. class CharacterTextSplitter(TextSplitter):
  513. """Splitting text that looks at characters."""
  514. DEFAULT_SEPARATOR: str = "\n\n"
  515. def __init__(
  516. self,
  517. separator: str = DEFAULT_SEPARATOR,
  518. is_separator_regex: bool = False,
  519. **kwargs: Any,
  520. ) -> None:
  521. """Create a new TextSplitter."""
  522. super().__init__(**kwargs)
  523. self._separator = separator
  524. self._is_separator_regex = is_separator_regex
  525. def split_text(self, text: str) -> list[str]:
  526. """Split incoming text and return chunks."""
  527. # First we naively split the large input into a bunch of smaller ones.
  528. separator = (
  529. self._separator
  530. if self._is_separator_regex
  531. else re.escape(self._separator)
  532. )
  533. splits = _split_text_with_regex(text, separator, self._keep_separator)
  534. _separator = "" if self._keep_separator else self._separator
  535. return self._merge_splits(splits, _separator)
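# Example usage (sketch, not part of the original source; `long_text` is a
# placeholder string):
#
#     splitter = CharacterTextSplitter(
#         separator="\n\n", chunk_size=1000, chunk_overlap=100
#     )
#     chunks = splitter.split_text(long_text)  # -> list[str]
#     docs = splitter.create_documents(
#         [long_text], metadatas=[{"source": "example.txt"}]
#     )  # -> list[SplitterDocument]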
  536. class LineType(TypedDict):
  537. """Line type as typed dict."""
  538. metadata: dict[str, str]
  539. content: str
  540. class HeaderType(TypedDict):
  541. """Header type as typed dict."""
  542. level: int
  543. name: str
  544. data: str
  545. class MarkdownHeaderTextSplitter:
  546. """Splitting markdown files based on specified headers."""
  547. def __init__(
  548. self,
  549. headers_to_split_on: list[Tuple[str, str]],
  550. return_each_line: bool = False,
  551. strip_headers: bool = True,
  552. ):
  553. """Create a new MarkdownHeaderTextSplitter.
  554. Args:
  555. headers_to_split_on: Headers we want to track
  556. return_each_line: Return each line w/ associated headers
  557. strip_headers: Strip split headers from the content of the chunk
  558. """
  559. # Output line-by-line or aggregated into chunks w/ common headers
  560. self.return_each_line = return_each_line
  561. # Given the headers we want to split on,
  562. # (e.g., "#, ##, etc") order by length
  563. self.headers_to_split_on = sorted(
  564. headers_to_split_on, key=lambda split: len(split[0]), reverse=True
  565. )
566. # Strip split headers from the content of the chunk
  567. self.strip_headers = strip_headers
  568. def aggregate_lines_to_chunks(
  569. self, lines: list[LineType]
  570. ) -> list[SplitterDocument]:
  571. """Combine lines with common metadata into chunks
  572. Args:
  573. lines: Line of text / associated header metadata
  574. """
  575. aggregated_chunks: list[LineType] = []
  576. for line in lines:
  577. if (
  578. aggregated_chunks
  579. and aggregated_chunks[-1]["metadata"] == line["metadata"]
  580. ):
  581. # If the last line in the aggregated list
  582. # has the same metadata as the current line,
583. # append the current content to the last line's content
  584. aggregated_chunks[-1]["content"] += " \n" + line["content"]
  585. elif (
  586. aggregated_chunks
  587. and aggregated_chunks[-1]["metadata"] != line["metadata"]
  588. # may be issues if other metadata is present
  589. and len(aggregated_chunks[-1]["metadata"])
  590. < len(line["metadata"])
  591. and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#"
  592. and not self.strip_headers
  593. ):
  594. # If the last line in the aggregated list
595. # has different metadata than the current line,
  596. # and has shallower header level than the current line,
  597. # and the last line is a header,
  598. # and we are not stripping headers,
  599. # append the current content to the last line's content
  600. aggregated_chunks[-1]["content"] += " \n" + line["content"]
  601. # and update the last line's metadata
  602. aggregated_chunks[-1]["metadata"] = line["metadata"]
  603. else:
  604. # Otherwise, append the current line to the aggregated list
  605. aggregated_chunks.append(line)
  606. return [
  607. SplitterDocument(
  608. page_content=chunk["content"], metadata=chunk["metadata"]
  609. )
  610. for chunk in aggregated_chunks
  611. ]
  612. def split_text(self, text: str) -> list[SplitterDocument]:
  613. """Split markdown file
  614. Args:
  615. text: Markdown file"""
  616. # Split the input text by newline character ("\n").
  617. lines = text.split("\n")
  618. # Final output
  619. lines_with_metadata: list[LineType] = []
  620. # Content and metadata of the chunk currently being processed
  621. current_content: list[str] = []
  622. current_metadata: dict[str, str] = {}
  623. # Keep track of the nested header structure
  624. # header_stack: list[dict[str, int | str]] = []
  625. header_stack: list[HeaderType] = []
  626. initial_metadata: dict[str, str] = {}
  627. in_code_block = False
  628. opening_fence = ""
  629. for line in lines:
  630. stripped_line = line.strip()
  631. if not in_code_block:
  632. # Exclude inline code spans
  633. if (
  634. stripped_line.startswith("```")
  635. and stripped_line.count("```") == 1
  636. ):
  637. in_code_block = True
  638. opening_fence = "```"
  639. elif stripped_line.startswith("~~~"):
  640. in_code_block = True
  641. opening_fence = "~~~"
  642. else:
  643. if stripped_line.startswith(opening_fence):
  644. in_code_block = False
  645. opening_fence = ""
  646. if in_code_block:
  647. current_content.append(stripped_line)
  648. continue
  649. # Check each line against each of the header types (e.g., #, ##)
  650. for sep, name in self.headers_to_split_on:
  651. # Check if line starts with a header that we intend to split on
  652. if stripped_line.startswith(sep) and (
  653. # Header with no text OR header is followed by space
654. # Both are valid conditions that sep is being used as a header
  655. len(stripped_line) == len(sep)
  656. or stripped_line[len(sep)] == " "
  657. ):
  658. # Ensure we are tracking the header as metadata
  659. if name is not None:
  660. # Get the current header level
  661. current_header_level = sep.count("#")
  662. # Pop out headers of lower or same level from the stack
  663. while (
  664. header_stack
  665. and header_stack[-1]["level"]
  666. >= current_header_level
  667. ):
  668. # We have encountered a new header
  669. # at the same or higher level
  670. popped_header = header_stack.pop()
  671. # Clear the metadata for the
  672. # popped header in initial_metadata
  673. if popped_header["name"] in initial_metadata:
  674. initial_metadata.pop(popped_header["name"])
  675. # Push the current header to the stack
  676. header: HeaderType = {
  677. "level": current_header_level,
  678. "name": name,
  679. "data": stripped_line[len(sep) :].strip(),
  680. }
  681. header_stack.append(header)
  682. # Update initial_metadata with the current header
  683. initial_metadata[name] = header["data"]
  684. # Add the previous line to the lines_with_metadata
  685. # only if current_content is not empty
  686. if current_content:
  687. lines_with_metadata.append(
  688. {
  689. "content": "\n".join(current_content),
  690. "metadata": current_metadata.copy(),
  691. }
  692. )
  693. current_content.clear()
  694. if not self.strip_headers:
  695. current_content.append(stripped_line)
  696. break
  697. else:
  698. if stripped_line:
  699. current_content.append(stripped_line)
  700. elif current_content:
  701. lines_with_metadata.append(
  702. {
  703. "content": "\n".join(current_content),
  704. "metadata": current_metadata.copy(),
  705. }
  706. )
  707. current_content.clear()
  708. current_metadata = initial_metadata.copy()
  709. if current_content:
  710. lines_with_metadata.append(
  711. {
  712. "content": "\n".join(current_content),
  713. "metadata": current_metadata,
  714. }
  715. )
  716. # lines_with_metadata has each line with associated header metadata
  717. # aggregate these into chunks based on common metadata
  718. if not self.return_each_line:
  719. return self.aggregate_lines_to_chunks(lines_with_metadata)
  720. else:
  721. return [
  722. SplitterDocument(
  723. page_content=chunk["content"], metadata=chunk["metadata"]
  724. )
  725. for chunk in lines_with_metadata
  726. ]
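# Example usage (sketch, not part of the original source):
#
#     md_splitter = MarkdownHeaderTextSplitter(
#         headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
#     )
#     docs = md_splitter.split_text("# Title\n\nIntro\n\n## Section\n\nBody")
#     # The "Body" chunk carries {"Header 1": "Title", "Header 2": "Section"}
#     # in its metadata.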
  727. class ElementType(TypedDict):
  728. """Element type as typed dict."""
  729. url: str
  730. xpath: str
  731. content: str
  732. metadata: dict[str, str]
  733. class HTMLHeaderTextSplitter:
  734. """Splitting HTML files based on specified headers.
  735. Requires lxml package.
  736. """
  737. def __init__(
  738. self,
  739. headers_to_split_on: list[Tuple[str, str]],
  740. return_each_element: bool = False,
  741. ):
  742. """Create a new HTMLHeaderTextSplitter.
  743. Args:
  744. headers_to_split_on: list of tuples of headers we want to track
  745. mapped to (arbitrary) keys for metadata. Allowed header values:
  746. h1, h2, h3, h4, h5, h6
  747. e.g. [("h1", "Header 1"), ("h2", "Header 2)].
  748. return_each_element: Return each element w/ associated headers.
  749. """
  750. # Output element-by-element or aggregated into chunks w/ common headers
  751. self.return_each_element = return_each_element
  752. self.headers_to_split_on = sorted(headers_to_split_on)
  753. def aggregate_elements_to_chunks(
  754. self, elements: list[ElementType]
  755. ) -> list[SplitterDocument]:
  756. """Combine elements with common metadata into chunks.
  757. Args:
  758. elements: HTML element content with associated identifying
  759. info and metadata
  760. """
  761. aggregated_chunks: list[ElementType] = []
  762. for element in elements:
  763. if (
  764. aggregated_chunks
  765. and aggregated_chunks[-1]["metadata"] == element["metadata"]
  766. ):
  767. # If the last element in the aggregated list
  768. # has the same metadata as the current element,
  769. # append the current content to the last element's content
  770. aggregated_chunks[-1]["content"] += " \n" + element["content"]
  771. else:
  772. # Otherwise, append the current element to the aggregated list
  773. aggregated_chunks.append(element)
  774. return [
  775. SplitterDocument(
  776. page_content=chunk["content"], metadata=chunk["metadata"]
  777. )
  778. for chunk in aggregated_chunks
  779. ]
  780. def split_text_from_url(self, url: str) -> list[SplitterDocument]:
  781. """Split HTML from web URL.
  782. Args:
  783. url: web URL
  784. """
  785. r = requests.get(url)
  786. return self.split_text_from_file(BytesIO(r.content))
  787. def split_text(self, text: str) -> list[SplitterDocument]:
  788. """Split HTML text string.
  789. Args:
  790. text: HTML text
  791. """
  792. return self.split_text_from_file(StringIO(text))
  793. def split_text_from_file(self, file: Any) -> list[SplitterDocument]:
  794. """Split HTML file.
  795. Args:
  796. file: HTML file
  797. """
  798. try:
  799. from lxml import etree
  800. except ImportError:
  801. raise ImportError(
  802. "Unable to import lxml, run `pip install lxml`."
  803. ) from None
  804. # use lxml library to parse html document and return xml ElementTree
  805. # Explicitly encoding in utf-8 allows non-English
  806. # html files to be processed without garbled characters
  807. parser = etree.HTMLParser(encoding="utf-8")
  808. tree = etree.parse(file, parser)
  809. # document transformation for "structure-aware" chunking is handled
  810. # with xsl. See comments in html_chunks_with_headers.xslt for more
  811. # detailed information.
  812. xslt_path = (
  813. pathlib.Path(__file__).parent
  814. / "document_transformers/xsl/html_chunks_with_headers.xslt"
  815. )
  816. xslt_tree = etree.parse(xslt_path)
  817. transform = etree.XSLT(xslt_tree)
  818. result = transform(tree)
  819. result_dom = etree.fromstring(str(result))
  820. # create filter and mapping for header metadata
  821. header_filter = [header[0] for header in self.headers_to_split_on]
  822. header_mapping = dict(self.headers_to_split_on)
  823. # map xhtml namespace prefix
  824. ns_map = {"h": "http://www.w3.org/1999/xhtml"}
  825. # build list of elements from DOM
  826. elements = []
  827. for element in result_dom.findall("*//*", ns_map):
  828. if element.findall("*[@class='headers']") or element.findall(
  829. "*[@class='chunk']"
  830. ):
  831. elements.append(
  832. ElementType(
  833. url=file,
  834. xpath="".join(
  835. [
  836. node.text
  837. for node in element.findall(
  838. "*[@class='xpath']", ns_map
  839. )
  840. ]
  841. ),
  842. content="".join(
  843. [
  844. node.text
  845. for node in element.findall(
  846. "*[@class='chunk']", ns_map
  847. )
  848. ]
  849. ),
  850. metadata={
  851. # Add text of specified headers to
  852. # metadata using header mapping.
  853. header_mapping[node.tag]: node.text
  854. for node in filter(
  855. lambda x: x.tag in header_filter,
  856. element.findall(
  857. "*[@class='headers']/*", ns_map
  858. ),
  859. )
  860. },
  861. )
  862. )
  863. if not self.return_each_element:
  864. return self.aggregate_elements_to_chunks(elements)
  865. else:
  866. return [
  867. SplitterDocument(
  868. page_content=chunk["content"], metadata=chunk["metadata"]
  869. )
  870. for chunk in elements
  871. ]
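# Example usage (sketch, assumes the optional lxml package and the bundled
# html_chunks_with_headers.xslt stylesheet are available):
#
#     html_splitter = HTMLHeaderTextSplitter(
#         headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
#     )
#     docs = html_splitter.split_text(
#         "<html><body><h1>Intro</h1><p>Hello</p></body></html>"
#     )
#     # or: html_splitter.split_text_from_url("https://example.com")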
872. # In newer Python versions (3.10+) this could be
873. # @dataclass(frozen=True, kw_only=True, slots=True)
  874. @dataclass(frozen=True)
  875. class Tokenizer:
  876. """Tokenizer data class."""
  877. chunk_overlap: int
  878. """Overlap in tokens between chunks."""
  879. tokens_per_chunk: int
  880. """Maximum number of tokens per chunk."""
  881. decode: Callable[[list[int]], str]
  882. """Function to decode a list of token ids to a string."""
  883. encode: Callable[[str], list[int]]
  884. """Function to encode a string to a list of token ids."""
  885. def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
  886. """Split incoming text and return chunks using tokenizer."""
  887. splits: list[str] = []
  888. input_ids = tokenizer.encode(text)
  889. start_idx = 0
  890. cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
  891. chunk_ids = input_ids[start_idx:cur_idx]
  892. while start_idx < len(input_ids):
  893. splits.append(tokenizer.decode(chunk_ids))
  894. if cur_idx == len(input_ids):
  895. break
  896. start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
  897. cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
  898. chunk_ids = input_ids[start_idx:cur_idx]
  899. return splits
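# Illustrative sketch (not part of the original source): with
# tokens_per_chunk=4 and chunk_overlap=1, ten token ids are decoded in the
# overlapping windows [0:4], [3:7], [6:10]; the stride is
# tokens_per_chunk - chunk_overlap = 3, so adjacent chunks share one token.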
  900. class TokenTextSplitter(TextSplitter):
  901. """Splitting text to tokens using model tokenizer."""
  902. def __init__(
  903. self,
  904. encoding_name: str = "gpt2",
  905. model: Optional[str] = None,
  906. allowed_special: Literal["all"] | AbstractSet[str] = set(),
  907. disallowed_special: Literal["all"] | Collection[str] = "all",
  908. **kwargs: Any,
  909. ) -> None:
  910. """Create a new TextSplitter."""
  911. super().__init__(**kwargs)
  912. try:
  913. import tiktoken
  914. except ImportError:
  915. raise ImportError(
  916. "Could not import tiktoken python package. "
  917. "This is needed in order to for TokenTextSplitter. "
  918. "Please install it with `pip install tiktoken`."
  919. ) from None
  920. if model is not None:
  921. enc = tiktoken.encoding_for_model(model)
  922. else:
  923. enc = tiktoken.get_encoding(encoding_name)
  924. self._tokenizer = enc
  925. self._allowed_special = allowed_special
  926. self._disallowed_special = disallowed_special
  927. def split_text(self, text: str) -> list[str]:
  928. def _encode(_text: str) -> list[int]:
  929. return self._tokenizer.encode(
  930. _text,
  931. allowed_special=self._allowed_special,
  932. disallowed_special=self._disallowed_special,
  933. )
  934. tokenizer = Tokenizer(
  935. chunk_overlap=self._chunk_overlap,
  936. tokens_per_chunk=self._chunk_size,
  937. decode=self._tokenizer.decode,
  938. encode=_encode,
  939. )
  940. return split_text_on_tokens(text=text, tokenizer=tokenizer)
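# Example usage (sketch, assumes the optional tiktoken package is installed;
# `long_text` is a placeholder string):
#
#     token_splitter = TokenTextSplitter(
#         model="gpt-3.5-turbo", chunk_size=256, chunk_overlap=32
#     )
#     chunks = token_splitter.split_text(long_text)  # sizes counted in tokens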
  941. class SentenceTransformersTokenTextSplitter(TextSplitter):
  942. """Splitting text to tokens using sentence model tokenizer."""
  943. def __init__(
  944. self,
  945. chunk_overlap: int = 50,
  946. model: str = "sentence-transformers/all-mpnet-base-v2",
  947. tokens_per_chunk: Optional[int] = None,
  948. **kwargs: Any,
  949. ) -> None:
  950. """Create a new TextSplitter."""
  951. super().__init__(**kwargs, chunk_overlap=chunk_overlap)
  952. try:
  953. from sentence_transformers import SentenceTransformer
  954. except ImportError:
  955. raise ImportError(
  956. """Could not import sentence_transformer python package.
  957. This is needed in order to for
  958. SentenceTransformersTokenTextSplitter.
  959. Please install it with `pip install sentence-transformers`.
  960. """
  961. ) from None
  962. self.model = model
  963. self._model = SentenceTransformer(self.model, trust_remote_code=True)
  964. self.tokenizer = self._model.tokenizer
  965. self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
  966. def _initialize_chunk_configuration(
  967. self, *, tokens_per_chunk: Optional[int]
  968. ) -> None:
  969. self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
  970. if tokens_per_chunk is None:
  971. self.tokens_per_chunk = self.maximum_tokens_per_chunk
  972. else:
  973. self.tokens_per_chunk = tokens_per_chunk
  974. if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
  975. raise ValueError(
  976. f"The token limit of the models '{self.model}'"
  977. f" is: {self.maximum_tokens_per_chunk}."
  978. f" Argument tokens_per_chunk={self.tokens_per_chunk}"
  979. f" > maximum token limit."
  980. )
  981. def split_text(self, text: str) -> list[str]:
  982. def encode_strip_start_and_stop_token_ids(text: str) -> list[int]:
  983. return self._encode(text)[1:-1]
  984. tokenizer = Tokenizer(
  985. chunk_overlap=self._chunk_overlap,
  986. tokens_per_chunk=self.tokens_per_chunk,
  987. decode=self.tokenizer.decode,
  988. encode=encode_strip_start_and_stop_token_ids,
  989. )
  990. return split_text_on_tokens(text=text, tokenizer=tokenizer)
  991. def count_tokens(self, *, text: str) -> int:
  992. return len(self._encode(text))
  993. _max_length_equal_32_bit_integer: int = 2**32
  994. def _encode(self, text: str) -> list[int]:
  995. token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
  996. text,
  997. max_length=self._max_length_equal_32_bit_integer,
  998. truncation="do_not_truncate",
  999. )
  1000. return token_ids_with_start_and_end_token_ids
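# Example usage (sketch, assumes the optional sentence-transformers package is
# installed and the model can be downloaded):
#
#     st_splitter = SentenceTransformersTokenTextSplitter(
#         model="sentence-transformers/all-mpnet-base-v2", chunk_overlap=50
#     )
#     st_splitter.count_tokens(text="hello world")  # includes special tokens
#     chunks = st_splitter.split_text("some long passage")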
  1001. class Language(str, Enum):
  1002. """Enum of the programming languages."""
  1003. CPP = "cpp"
  1004. GO = "go"
  1005. JAVA = "java"
  1006. KOTLIN = "kotlin"
  1007. JS = "js"
  1008. TS = "ts"
  1009. PHP = "php"
  1010. PROTO = "proto"
  1011. PYTHON = "python"
  1012. RST = "rst"
  1013. RUBY = "ruby"
  1014. RUST = "rust"
  1015. SCALA = "scala"
  1016. SWIFT = "swift"
  1017. MARKDOWN = "markdown"
  1018. LATEX = "latex"
  1019. HTML = "html"
  1020. SOL = "sol"
  1021. CSHARP = "csharp"
  1022. COBOL = "cobol"
  1023. C = "c"
  1024. LUA = "lua"
  1025. PERL = "perl"
  1026. class RecursiveCharacterTextSplitter(TextSplitter):
  1027. """Splitting text by recursively look at characters.
  1028. Recursively tries to split by different characters to find one that works.
  1029. """
  1030. def __init__(
  1031. self,
  1032. separators: Optional[list[str]] = None,
  1033. keep_separator: bool = True,
  1034. is_separator_regex: bool = False,
  1035. chunk_size: int = 4000,
  1036. chunk_overlap: int = 200,
  1037. **kwargs: Any,
  1038. ) -> None:
  1039. """Create a new TextSplitter."""
  1040. super().__init__(
  1041. chunk_size=chunk_size,
  1042. chunk_overlap=chunk_overlap,
  1043. keep_separator=keep_separator,
  1044. **kwargs,
  1045. )
  1046. self._separators = separators or ["\n\n", "\n", " ", ""]
  1047. self._is_separator_regex = is_separator_regex
  1048. self.chunk_size = chunk_size
  1049. self.chunk_overlap = chunk_overlap
  1050. def _split_text(self, text: str, separators: list[str]) -> list[str]:
  1051. """Split incoming text and return chunks."""
  1052. final_chunks = []
  1053. # Get appropriate separator to use
  1054. separator = separators[-1]
  1055. new_separators = []
  1056. for i, _s in enumerate(separators):
  1057. _separator = _s if self._is_separator_regex else re.escape(_s)
  1058. if _s == "":
  1059. separator = _s
  1060. break
  1061. if re.search(_separator, text):
  1062. separator = _s
  1063. new_separators = separators[i + 1 :]
  1064. break
  1065. _separator = (
  1066. separator if self._is_separator_regex else re.escape(separator)
  1067. )
  1068. splits = _split_text_with_regex(text, _separator, self._keep_separator)
  1069. # Now go merging things, recursively splitting longer texts.
  1070. _good_splits = []
  1071. _separator = "" if self._keep_separator else separator
  1072. for s in splits:
  1073. if self._length_function(s) < self._chunk_size:
  1074. _good_splits.append(s)
  1075. else:
  1076. if _good_splits:
  1077. merged_text = self._merge_splits(_good_splits, _separator)
  1078. final_chunks.extend(merged_text)
  1079. _good_splits = []
  1080. if not new_separators:
  1081. final_chunks.append(s)
  1082. else:
  1083. other_info = self._split_text(s, new_separators)
  1084. final_chunks.extend(other_info)
  1085. if _good_splits:
  1086. merged_text = self._merge_splits(_good_splits, _separator)
  1087. final_chunks.extend(merged_text)
  1088. return final_chunks
  1089. def split_text(self, text: str) -> list[str]:
  1090. return self._split_text(text, self._separators)
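# Example usage (sketch, not part of the original source; `long_text` is a
# placeholder string):
#
#     rc_splitter = RecursiveCharacterTextSplitter(
#         chunk_size=1000, chunk_overlap=100
#     )
#     chunks = rc_splitter.split_text(long_text)
#     # Tries "\n\n" first, then "\n", then " ", and finally falls back to
#     # character-level splitting for any piece still longer than chunk_size.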
  1091. @classmethod
  1092. def from_language(
  1093. cls, language: Language, **kwargs: Any
  1094. ) -> RecursiveCharacterTextSplitter:
  1095. separators = cls.get_separators_for_language(language)
  1096. return cls(separators=separators, is_separator_regex=True, **kwargs)
  1097. @staticmethod
  1098. def get_separators_for_language(language: Language) -> list[str]:
  1099. if language == Language.CPP:
  1100. return [
  1101. # Split along class definitions
  1102. "\nclass ",
  1103. # Split along function definitions
  1104. "\nvoid ",
  1105. "\nint ",
  1106. "\nfloat ",
  1107. "\ndouble ",
  1108. # Split along control flow statements
  1109. "\nif ",
  1110. "\nfor ",
  1111. "\nwhile ",
  1112. "\nswitch ",
  1113. "\ncase ",
  1114. # Split by the normal type of lines
  1115. "\n\n",
  1116. "\n",
  1117. " ",
  1118. "",
  1119. ]
  1120. elif language == Language.GO:
  1121. return [
  1122. # Split along function definitions
  1123. "\nfunc ",
  1124. "\nvar ",
  1125. "\nconst ",
  1126. "\ntype ",
  1127. # Split along control flow statements
  1128. "\nif ",
  1129. "\nfor ",
  1130. "\nswitch ",
  1131. "\ncase ",
  1132. # Split by the normal type of lines
  1133. "\n\n",
  1134. "\n",
  1135. " ",
  1136. "",
  1137. ]
  1138. elif language == Language.JAVA:
  1139. return [
  1140. # Split along class definitions
  1141. "\nclass ",
  1142. # Split along method definitions
  1143. "\npublic ",
  1144. "\nprotected ",
  1145. "\nprivate ",
  1146. "\nstatic ",
  1147. # Split along control flow statements
  1148. "\nif ",
  1149. "\nfor ",
  1150. "\nwhile ",
  1151. "\nswitch ",
  1152. "\ncase ",
  1153. # Split by the normal type of lines
  1154. "\n\n",
  1155. "\n",
  1156. " ",
  1157. "",
  1158. ]
  1159. elif language == Language.KOTLIN:
  1160. return [
  1161. # Split along class definitions
  1162. "\nclass ",
  1163. # Split along method definitions
  1164. "\npublic ",
  1165. "\nprotected ",
  1166. "\nprivate ",
  1167. "\ninternal ",
  1168. "\ncompanion ",
  1169. "\nfun ",
  1170. "\nval ",
  1171. "\nvar ",
  1172. # Split along control flow statements
  1173. "\nif ",
  1174. "\nfor ",
  1175. "\nwhile ",
  1176. "\nwhen ",
  1177. "\ncase ",
  1178. "\nelse ",
  1179. # Split by the normal type of lines
  1180. "\n\n",
  1181. "\n",
  1182. " ",
  1183. "",
  1184. ]
  1185. elif language == Language.JS:
  1186. return [
  1187. # Split along function definitions
  1188. "\nfunction ",
  1189. "\nconst ",
  1190. "\nlet ",
  1191. "\nvar ",
  1192. "\nclass ",
  1193. # Split along control flow statements
  1194. "\nif ",
  1195. "\nfor ",
  1196. "\nwhile ",
  1197. "\nswitch ",
  1198. "\ncase ",
  1199. "\ndefault ",
  1200. # Split by the normal type of lines
  1201. "\n\n",
  1202. "\n",
  1203. " ",
  1204. "",
  1205. ]
  1206. elif language == Language.TS:
  1207. return [
  1208. "\nenum ",
  1209. "\ninterface ",
  1210. "\nnamespace ",
  1211. "\ntype ",
  1212. # Split along class definitions
  1213. "\nclass ",
  1214. # Split along function definitions
  1215. "\nfunction ",
  1216. "\nconst ",
  1217. "\nlet ",
  1218. "\nvar ",
  1219. # Split along control flow statements
  1220. "\nif ",
  1221. "\nfor ",
  1222. "\nwhile ",
  1223. "\nswitch ",
  1224. "\ncase ",
  1225. "\ndefault ",
  1226. # Split by the normal type of lines
  1227. "\n\n",
  1228. "\n",
  1229. " ",
  1230. "",
  1231. ]
  1232. elif language == Language.PHP:
  1233. return [
  1234. # Split along function definitions
  1235. "\nfunction ",
  1236. # Split along class definitions
  1237. "\nclass ",
  1238. # Split along control flow statements
  1239. "\nif ",
  1240. "\nforeach ",
  1241. "\nwhile ",
  1242. "\ndo ",
  1243. "\nswitch ",
  1244. "\ncase ",
  1245. # Split by the normal type of lines
  1246. "\n\n",
  1247. "\n",
  1248. " ",
  1249. "",
  1250. ]
  1251. elif language == Language.PROTO:
  1252. return [
  1253. # Split along message definitions
  1254. "\nmessage ",
  1255. # Split along service definitions
  1256. "\nservice ",
  1257. # Split along enum definitions
  1258. "\nenum ",
  1259. # Split along option definitions
  1260. "\noption ",
  1261. # Split along import statements
  1262. "\nimport ",
  1263. # Split along syntax declarations
  1264. "\nsyntax ",
  1265. # Split by the normal type of lines
  1266. "\n\n",
  1267. "\n",
  1268. " ",
  1269. "",
  1270. ]
  1271. elif language == Language.PYTHON:
  1272. return [
  1273. # First, try to split along class definitions
  1274. "\nclass ",
  1275. "\ndef ",
  1276. "\n\tdef ",
  1277. # Now split by the normal type of lines
  1278. "\n\n",
  1279. "\n",
  1280. " ",
  1281. "",
  1282. ]
  1283. elif language == Language.RST:
  1284. return [
  1285. # Split along section titles
  1286. "\n=+\n",
  1287. "\n-+\n",
  1288. "\n\\*+\n",
  1289. # Split along directive markers
  1290. "\n\n.. *\n\n",
  1291. # Split by the normal type of lines
  1292. "\n\n",
  1293. "\n",
  1294. " ",
  1295. "",
  1296. ]
  1297. elif language == Language.RUBY:
  1298. return [
  1299. # Split along method definitions
  1300. "\ndef ",
  1301. "\nclass ",
  1302. # Split along control flow statements
  1303. "\nif ",
  1304. "\nunless ",
  1305. "\nwhile ",
  1306. "\nfor ",
  1307. "\ndo ",
  1308. "\nbegin ",
  1309. "\nrescue ",
  1310. # Split by the normal type of lines
  1311. "\n\n",
  1312. "\n",
  1313. " ",
  1314. "",
  1315. ]
  1316. elif language == Language.RUST:
  1317. return [
  1318. # Split along function definitions
  1319. "\nfn ",
  1320. "\nconst ",
  1321. "\nlet ",
  1322. # Split along control flow statements
  1323. "\nif ",
  1324. "\nwhile ",
  1325. "\nfor ",
  1326. "\nloop ",
  1327. "\nmatch ",
  1328. "\nconst ",
  1329. # Split by the normal type of lines
  1330. "\n\n",
  1331. "\n",
  1332. " ",
  1333. "",
  1334. ]
  1335. elif language == Language.SCALA:
  1336. return [
  1337. # Split along class definitions
  1338. "\nclass ",
  1339. "\nobject ",
  1340. # Split along method definitions
  1341. "\ndef ",
  1342. "\nval ",
  1343. "\nvar ",
  1344. # Split along control flow statements
  1345. "\nif ",
  1346. "\nfor ",
  1347. "\nwhile ",
  1348. "\nmatch ",
  1349. "\ncase ",
  1350. # Split by the normal type of lines
  1351. "\n\n",
  1352. "\n",
  1353. " ",
  1354. "",
  1355. ]
  1356. elif language == Language.SWIFT:
  1357. return [
  1358. # Split along function definitions
  1359. "\nfunc ",
  1360. # Split along class definitions
  1361. "\nclass ",
  1362. "\nstruct ",
  1363. "\nenum ",
  1364. # Split along control flow statements
  1365. "\nif ",
  1366. "\nfor ",
  1367. "\nwhile ",
  1368. "\ndo ",
  1369. "\nswitch ",
  1370. "\ncase ",
  1371. # Split by the normal type of lines
  1372. "\n\n",
  1373. "\n",
  1374. " ",
  1375. "",
  1376. ]
  1377. elif language == Language.MARKDOWN:
  1378. return [
1379. # First, try to split along Markdown headings
1380. # (levels 1 through 6)
  1381. "\n#{1,6} ",
  1382. # Note the alternative syntax for headings (below)
  1383. # is not handled here
  1384. # Heading level 2
  1385. # ---------------
  1386. # End of code block
  1387. "```\n",
  1388. # Horizontal lines
  1389. "\n\\*\\*\\*+\n",
  1390. "\n---+\n",
  1391. "\n___+\n",
1392. # Note that the patterns above only match horizontal
1393. # rules written as an unbroken run of three or more
1394. # *, -, or _ on their own line; spaced variants
1395. # (e.g. "* * *") are not handled
  1396. "\n\n",
  1397. "\n",
  1398. " ",
  1399. "",
  1400. ]
  1401. elif language == Language.LATEX:
  1402. return [
  1403. # First, try to split along Latex sections
  1404. "\n\\\\chapter{",
  1405. "\n\\\\section{",
  1406. "\n\\\\subsection{",
  1407. "\n\\\\subsubsection{",
  1408. # Now split by environments
  1409. "\n\\\\begin{enumerate}",
  1410. "\n\\\\begin{itemize}",
  1411. "\n\\\\begin{description}",
  1412. "\n\\\\begin{list}",
  1413. "\n\\\\begin{quote}",
  1414. "\n\\\\begin{quotation}",
  1415. "\n\\\\begin{verse}",
  1416. "\n\\\\begin{verbatim}",
  1417. # Now split by math environments
  1418. "\n\\\begin{align}",
  1419. "$$",
  1420. "$",
  1421. # Now split by the normal type of lines
  1422. " ",
  1423. "",
  1424. ]
  1425. elif language == Language.HTML:
  1426. return [
  1427. # First, try to split along HTML tags
  1428. "<body",
  1429. "<div",
  1430. "<p",
  1431. "<br",
  1432. "<li",
  1433. "<h1",
  1434. "<h2",
  1435. "<h3",
  1436. "<h4",
  1437. "<h5",
  1438. "<h6",
  1439. "<span",
  1440. "<table",
  1441. "<tr",
  1442. "<td",
  1443. "<th",
  1444. "<ul",
  1445. "<ol",
  1446. "<header",
  1447. "<footer",
  1448. "<nav",
  1449. # Head
  1450. "<head",
  1451. "<style",
  1452. "<script",
  1453. "<meta",
  1454. "<title",
  1455. "",
  1456. ]
  1457. elif language == Language.CSHARP:
  1458. return [
  1459. "\ninterface ",
  1460. "\nenum ",
  1461. "\nimplements ",
  1462. "\ndelegate ",
  1463. "\nevent ",
  1464. # Split along class definitions
  1465. "\nclass ",
  1466. "\nabstract ",
  1467. # Split along method definitions
  1468. "\npublic ",
  1469. "\nprotected ",
  1470. "\nprivate ",
  1471. "\nstatic ",
  1472. "\nreturn ",
  1473. # Split along control flow statements
  1474. "\nif ",
  1475. "\ncontinue ",
  1476. "\nfor ",
  1477. "\nforeach ",
  1478. "\nwhile ",
  1479. "\nswitch ",
  1480. "\nbreak ",
  1481. "\ncase ",
  1482. "\nelse ",
  1483. # Split by exceptions
  1484. "\ntry ",
  1485. "\nthrow ",
  1486. "\nfinally ",
  1487. "\ncatch ",
  1488. # Split by the normal type of lines
  1489. "\n\n",
  1490. "\n",
  1491. " ",
  1492. "",
  1493. ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ",
                "\nusing ",
                # Split along contract definitions
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                # Split along method definitions
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.COBOL:
            return [
                # Split along divisions
                "\nIDENTIFICATION DIVISION.",
                "\nENVIRONMENT DIVISION.",
                "\nDATA DIVISION.",
                "\nPROCEDURE DIVISION.",
                # Split along sections within DATA DIVISION
                "\nWORKING-STORAGE SECTION.",
                "\nLINKAGE SECTION.",
                "\nFILE SECTION.",
                # Split along sections within ENVIRONMENT DIVISION
                "\nINPUT-OUTPUT SECTION.",
                # Split along paragraphs and common statements
                "\nOPEN ",
                "\nCLOSE ",
                "\nREAD ",
                "\nWRITE ",
                "\nIF ",
                "\nELSE ",
                "\nMOVE ",
                "\nPERFORM ",
                "\nUNTIL ",
                "\nVARYING ",
                "\nACCEPT ",
                "\nDISPLAY ",
                "\nSTOP RUN.",
                # Split by the normal type of lines
                "\n",
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
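
# Illustrative sketch (kept as comments so nothing runs at import time) of how
# the language-specific separators above can be fed back into
# RecursiveCharacterTextSplitter. The chunk_size/chunk_overlap keyword
# arguments are assumed to be accepted by the base TextSplitter, as in
# comparable splitter implementations; adjust to this module's actual
# constructor signature.
#
#     separators = RecursiveCharacterTextSplitter.get_separators_for_language(
#         Language.PYTHON
#     )
#     code_splitter = RecursiveCharacterTextSplitter(
#         separators=separators, chunk_size=400, chunk_overlap=40
#     )
#     code_chunks = code_splitter.split_text(python_source_code)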


class NLTKTextSplitter(TextSplitter):
    """Splitting text using NLTK package."""

    def __init__(
        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
    ) -> None:
        """Initialize the NLTK splitter."""
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize

            self._tokenizer = sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with "
                "`pip install nltk`."
            ) from None
        self._separator = separator
        self._language = language

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = self._tokenizer(text, language=self._language)
        return self._merge_splits(splits, self._separator)
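
# Illustrative usage sketch for NLTKTextSplitter (comments only). It assumes
# NLTK's sentence-tokenizer data (e.g. "punkt") is already downloaded and that
# chunk_size/chunk_overlap are accepted by the base TextSplitter constructor
# (an assumption, not verified here).
#
#     # import nltk; nltk.download("punkt")
#     nltk_splitter = NLTKTextSplitter(chunk_size=1000, chunk_overlap=100)
#     sentence_chunks = nltk_splitter.split_text(document_text)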


class SpacyTextSplitter(TextSplitter):
    """Splitting text using Spacy package.

    By default, Spacy's `en_core_web_sm` pipeline is used, with its default
    `max_length` of 1,000,000 characters (the maximum number of characters the
    pipeline accepts in one call, which can be raised for large files). For
    faster but potentially less accurate splitting, you can use
    `pipe="sentencizer"`.
    """

    def __init__(
        self,
        separator: str = "\n\n",
        pipe: str = "en_core_web_sm",
        max_length: int = 1_000_000,
        **kwargs: Any,
    ) -> None:
        """Initialize the spacy text splitter."""
        super().__init__(**kwargs)
        self._tokenizer = _make_spacy_pipe_for_splitting(
            pipe, max_length=max_length
        )
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        splits = (s.text for s in self._tokenizer(text).sents)
        return self._merge_splits(splits, self._separator)
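
# Illustrative usage sketch for SpacyTextSplitter (comments only). Using
# pipe="sentencizer" avoids loading a full model and is typically faster,
# while "en_core_web_sm" must be installed separately
# (`python -m spacy download en_core_web_sm`). chunk_size is assumed to be a
# base-class keyword argument.
#
#     fast_splitter = SpacyTextSplitter(pipe="sentencizer", chunk_size=1000)
#     sentence_chunks = fast_splitter.split_text(document_text)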


class KonlpyTextSplitter(TextSplitter):
    """Splitting text using Konlpy package.

    It is well suited to splitting Korean text.
    """

    def __init__(
        self,
        separator: str = "\n\n",
        **kwargs: Any,
    ) -> None:
        """Initialize the Konlpy text splitter."""
        super().__init__(**kwargs)
        self._separator = separator
        try:
            from konlpy.tag import Kkma
        except ImportError:
            raise ImportError(
                "Konlpy is not installed, please install it with "
                "`pip install konlpy`."
            ) from None
        self.kkma = Kkma()

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        splits = self.kkma.sentences(text)
        return self._merge_splits(splits, self._separator)
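
# Illustrative usage sketch for KonlpyTextSplitter (comments only). Konlpy's
# Kkma tagger requires a Java runtime; korean_text is a placeholder variable,
# and chunk_size/chunk_overlap are assumed base-class keyword arguments.
#
#     korean_splitter = KonlpyTextSplitter(chunk_size=1000, chunk_overlap=100)
#     korean_chunks = korean_splitter.split_text(korean_text)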


# For backwards compatibility
class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Python syntax."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a PythonCodeTextSplitter."""
        separators = self.get_separators_for_language(Language.PYTHON)
        super().__init__(separators=separators, **kwargs)


class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Markdown-formatted headings."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a MarkdownTextSplitter."""
        separators = self.get_separators_for_language(Language.MARKDOWN)
        super().__init__(separators=separators, **kwargs)


class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Latex-formatted layout elements."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a LatexTextSplitter."""
        separators = self.get_separators_for_language(Language.LATEX)
        super().__init__(separators=separators, **kwargs)


class RecursiveJsonSplitter:
    """Splits nested JSON data into smaller JSON chunks that stay below a
    maximum serialized size, while preserving the nested key structure."""

    def __init__(
        self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
    ):
        super().__init__()
        self.max_chunk_size = max_chunk_size
        self.min_chunk_size = (
            min_chunk_size
            if min_chunk_size is not None
            else max(max_chunk_size - 200, 50)
        )

    @staticmethod
    def _json_size(data: dict) -> int:
        """Calculate the size of the serialized JSON object."""
        return len(json.dumps(data))

    @staticmethod
    def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
        """Set a value in a nested dictionary based on the given path."""
        for key in path[:-1]:
            d = d.setdefault(key, {})
        d[path[-1]] = value

    def _list_to_dict_preprocessing(self, data: Any) -> Any:
        """Recursively convert lists to dictionaries with index-based keys."""
        if isinstance(data, dict):
            # Process each key-value pair in the dictionary
            return {
                k: self._list_to_dict_preprocessing(v) for k, v in data.items()
            }
        elif isinstance(data, list):
            # Convert the list to a dictionary with index-based keys
            return {
                str(i): self._list_to_dict_preprocessing(item)
                for i, item in enumerate(data)
            }
        else:
            # The item is neither a dict nor a list, return unchanged
            return data
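
    # Worked example (comment only) of what _list_to_dict_preprocessing does:
    #     {"a": [10, 20], "b": "x"}  ->  {"a": {"0": 10, "1": 20}, "b": "x"}
    # Lists become dictionaries keyed by stringified indices so that
    # _json_split can treat every container uniformly as a nested dict.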

    def _json_split(
        self,
        data: dict[str, Any],
        current_path: list[str] | None = None,
        chunks: list[dict] | None = None,
    ) -> list[dict]:
        """Split json into maximum size dictionaries while preserving
        structure."""
        if current_path is None:
            current_path = []
        if chunks is None:
            chunks = [{}]
        if isinstance(data, dict):
            for key, value in data.items():
                new_path = current_path + [key]
                chunk_size = self._json_size(chunks[-1])
                size = self._json_size({key: value})
                remaining = self.max_chunk_size - chunk_size
                if size < remaining:
                    # Add item to current chunk
                    self._set_nested_dict(chunks[-1], new_path, value)
                else:
                    if chunk_size >= self.min_chunk_size:
                        # Chunk is big enough, start a new chunk
                        chunks.append({})
                    # Iterate
                    self._json_split(value, new_path, chunks)
        else:
            # Handle a single non-dict item
            self._set_nested_dict(chunks[-1], current_path, data)
        return chunks

    def split_json(
        self,
        json_data: dict[str, Any],
        convert_lists: bool = False,
    ) -> list[dict]:
        """Splits JSON into a list of JSON chunks."""
        if convert_lists:
            chunks = self._json_split(
                self._list_to_dict_preprocessing(json_data)
            )
        else:
            chunks = self._json_split(json_data)

        # Remove the last chunk if it's empty
        if not chunks[-1]:
            chunks.pop()
        return chunks

    def split_text(
        self, json_data: dict[str, Any], convert_lists: bool = False
    ) -> list[str]:
        """Splits JSON into a list of JSON formatted strings."""
        chunks = self.split_json(
            json_data=json_data, convert_lists=convert_lists
        )
        # Convert to string
        return [json.dumps(chunk) for chunk in chunks]

    def create_documents(
        self,
        texts: list[dict],
        convert_lists: bool = False,
        metadatas: Optional[list[dict]] = None,
    ) -> list[SplitterDocument]:
        """Create documents from a list of json objects (dict)."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(
                json_data=text, convert_lists=convert_lists
            ):
                metadata = copy.deepcopy(_metadatas[i])
                new_doc = SplitterDocument(
                    page_content=chunk, metadata=metadata
                )
                documents.append(new_doc)
        return documents
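
# Illustrative usage sketch for RecursiveJsonSplitter (comments only);
# nested_api_response is a placeholder for any nested dict. split_json returns
# dict chunks, split_text returns JSON-encoded strings, and create_documents
# wraps each chunk in a SplitterDocument.
#
#     json_splitter = RecursiveJsonSplitter(max_chunk_size=300)
#     dict_chunks = json_splitter.split_json(
#         nested_api_response, convert_lists=True
#     )
#     text_chunks = json_splitter.split_text(nested_api_response)
#     docs = json_splitter.create_documents(texts=[nested_api_response])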