text.py

  1. # Source - LangChain
  2. # URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851
  3. """**Text Splitters** are classes for splitting text.
  4. **Class hierarchy:**
  5. .. code-block::
  6. BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter
  7. RecursiveCharacterTextSplitter --> <name>TextSplitter
  8. Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.
  9. **Main helpers:**
  10. .. code-block::
  11. Document, Tokenizer, Language, LineType, HeaderType
  12. """ # noqa: E501
  13. from __future__ import annotations
  14. import copy
  15. import json
  16. import logging
  17. import pathlib
  18. import re
  19. from abc import ABC, abstractmethod
  20. from dataclasses import dataclass
  21. from enum import Enum
  22. from io import BytesIO, StringIO
  23. from typing import (
  24. AbstractSet,
  25. Any,
  26. Callable,
  27. Collection,
  28. Dict,
  29. Iterable,
  30. List,
  31. Literal,
  32. Optional,
  33. Sequence,
  34. Tuple,
  35. Type,
  36. TypedDict,
  37. TypeVar,
  38. Union,
  39. cast,
  40. )
  41. import requests
  42. from pydantic import BaseModel, Field, PrivateAttr
  43. from typing_extensions import NotRequired
  44. logger = logging.getLogger(__name__)
  45. TS = TypeVar("TS", bound="TextSplitter")
  46. class BaseSerialized(TypedDict):
  47. """Base class for serialized objects."""
  48. lc: int
  49. id: List[str]
  50. name: NotRequired[str]
  51. graph: NotRequired[Dict[str, Any]]
  52. class SerializedConstructor(BaseSerialized):
  53. """Serialized constructor."""
  54. type: Literal["constructor"]
  55. kwargs: Dict[str, Any]
  56. class SerializedSecret(BaseSerialized):
  57. """Serialized secret."""
  58. type: Literal["secret"]
  59. class SerializedNotImplemented(BaseSerialized):
  60. """Serialized not implemented."""
  61. type: Literal["not_implemented"]
  62. repr: Optional[str]
  63. def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
  64. """Try to determine if a value is different from the default.
  65. Args:
  66. value: The value.
  67. key: The key.
  68. model: The model.
  69. Returns:
  70. Whether the value is different from the default.
  71. """
  72. try:
  73. return model.__fields__[key].get_default() != value
  74. except Exception:
  75. return True
  76. class Serializable(BaseModel, ABC):
  77. """Serializable base class."""
  78. @classmethod
  79. def is_lc_serializable(cls) -> bool:
  80. """Is this class serializable?"""
  81. return False
  82. @classmethod
  83. def get_lc_namespace(cls) -> List[str]:
  84. """Get the namespace of the langchain object.
  85. For example, if the class is `langchain.llms.openai.OpenAI`, then the
  86. namespace is ["langchain", "llms", "openai"]
  87. """
  88. return cls.__module__.split(".")
  89. @property
  90. def lc_secrets(self) -> Dict[str, str]:
  91. """A map of constructor argument names to secret ids.
  92. For example,
  93. {"openai_api_key": "OPENAI_API_KEY"}
  94. """
  95. return dict()
  96. @property
  97. def lc_attributes(self) -> Dict:
  98. """List of attribute names that should be included in the serialized kwargs.
  99. These attributes must be accepted by the constructor.
  100. """
  101. return {}
  102. @classmethod
  103. def lc_id(cls) -> List[str]:
  104. """A unique identifier for this class for serialization purposes.
  105. The unique identifier is a list of strings that describes the path
  106. to the object.
  107. """
  108. return [*cls.get_lc_namespace(), cls.__name__]
  109. class Config:
  110. extra = "ignore"
  111. def __repr_args__(self) -> Any:
  112. return [
  113. (k, v)
  114. for k, v in super().__repr_args__()
  115. if (k not in self.__fields__ or try_neq_default(v, k, self))
  116. ]
  117. _lc_kwargs = PrivateAttr(default_factory=dict)
  118. def __init__(self, **kwargs: Any) -> None:
  119. super().__init__(**kwargs)
  120. self._lc_kwargs = kwargs
  121. def to_json(
  122. self,
  123. ) -> Union[SerializedConstructor, SerializedNotImplemented]:
  124. if not self.is_lc_serializable():
  125. return self.to_json_not_implemented()
  126. secrets = dict()
  127. # Get latest values for kwargs if there is an attribute with same name
  128. lc_kwargs = {
  129. k: getattr(self, k, v)
  130. for k, v in self._lc_kwargs.items()
  131. if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
  132. }
  133. # Merge the lc_secrets and lc_attributes from every class in the MRO
  134. for cls in [None, *self.__class__.mro()]:
  135. # Once we get to Serializable, we're done
  136. if cls is Serializable:
  137. break
  138. if cls:
  139. deprecated_attributes = [
  140. "lc_namespace",
  141. "lc_serializable",
  142. ]
  143. for attr in deprecated_attributes:
  144. if hasattr(cls, attr):
  145. raise ValueError(
  146. f"Class {self.__class__} has a deprecated "
  147. f"attribute {attr}. Please use the corresponding "
  148. f"classmethod instead."
  149. )
  150. # Get a reference to self bound to each class in the MRO
  151. this = cast(
  152. Serializable, self if cls is None else super(cls, self)
  153. )
  154. secrets.update(this.lc_secrets)
  155. # Now also add the aliases for the secrets
  156. # This ensures known secret aliases are hidden.
  157. # Note: this does NOT hide any other extra kwargs
  158. # that are not present in the fields.
  159. for key in list(secrets):
  160. value = secrets[key]
  161. if key in this.__fields__:
  162. secrets[this.__fields__[key].alias] = value # type: ignore
  163. lc_kwargs.update(this.lc_attributes)
  164. # include all secrets, even if not specified in kwargs
  165. # as these secrets may be passed as an environment variable instead
  166. for key in secrets.keys():
  167. secret_value = getattr(self, key, None) or lc_kwargs.get(key)
  168. if secret_value is not None:
  169. lc_kwargs.update({key: secret_value})
  170. return {
  171. "lc": 1,
  172. "type": "constructor",
  173. "id": self.lc_id(),
  174. "kwargs": (
  175. lc_kwargs
  176. if not secrets
  177. else _replace_secrets(lc_kwargs, secrets)
  178. ),
  179. }
  180. def to_json_not_implemented(self) -> SerializedNotImplemented:
  181. return to_json_not_implemented(self)
  182. def _replace_secrets(
  183. root: Dict[Any, Any], secrets_map: Dict[str, str]
  184. ) -> Dict[Any, Any]:
  185. result = root.copy()
  186. for path, secret_id in secrets_map.items():
  187. [*parts, last] = path.split(".")
  188. current = result
  189. for part in parts:
  190. if part not in current:
  191. break
  192. current[part] = current[part].copy()
  193. current = current[part]
  194. if last in current:
  195. current[last] = {
  196. "lc": 1,
  197. "type": "secret",
  198. "id": [secret_id],
  199. }
  200. return result
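
# Illustrative sketch (not part of the original module; the helper name below is
# hypothetical): _replace_secrets swaps a dotted path in a kwargs dict for a
# serialized secret marker while leaving every other key untouched.
def _example_replace_secrets() -> Dict[Any, Any]:
    kwargs = {"openai_api_key": "sk-raw-value", "temperature": 0.0}
    # Expected shape of the result:
    # {"openai_api_key": {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]},
    #  "temperature": 0.0}
    return _replace_secrets(kwargs, {"openai_api_key": "OPENAI_API_KEY"})
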
  201. def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
  202. """Serialize a "not implemented" object.
  203. Args:
  204. obj: object to serialize
  205. Returns:
  206. SerializedNotImplemented
  207. """
  208. _id: List[str] = []
  209. try:
  210. if hasattr(obj, "__name__"):
  211. _id = [*obj.__module__.split("."), obj.__name__]
  212. elif hasattr(obj, "__class__"):
  213. _id = [
  214. *obj.__class__.__module__.split("."),
  215. obj.__class__.__name__,
  216. ]
  217. except Exception:
  218. pass
  219. result: SerializedNotImplemented = {
  220. "lc": 1,
  221. "type": "not_implemented",
  222. "id": _id,
  223. "repr": None,
  224. }
  225. try:
  226. result["repr"] = repr(obj)
  227. except Exception:
  228. pass
  229. return result
  230. class SplitterDocument(Serializable):
  231. """Class for storing a piece of text and associated metadata."""
  232. page_content: str
  233. """String text."""
  234. metadata: dict = Field(default_factory=dict)
  235. """Arbitrary metadata about the page content (e.g., source, relationships to other
  236. documents, etc.).
  237. """
  238. type: Literal["Document"] = "Document"
  239. def __init__(self, page_content: str, **kwargs: Any) -> None:
  240. """Pass page_content in as positional or named arg."""
  241. super().__init__(page_content=page_content, **kwargs)
  242. @classmethod
  243. def is_lc_serializable(cls) -> bool:
  244. """Return whether this class is serializable."""
  245. return True
  246. @classmethod
  247. def get_lc_namespace(cls) -> List[str]:
  248. """Get the namespace of the langchain object."""
  249. return ["langchain", "schema", "document"]
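
# A minimal sketch (illustrative only; the function name is hypothetical) of how
# a SplitterDocument carries text plus metadata and round-trips through to_json().
def _example_splitter_document() -> Union[SerializedConstructor, SerializedNotImplemented]:
    doc = SplitterDocument(
        page_content="LangChain splits text into chunks.",
        metadata={"source": "example.txt"},
    )
    # Because the class is lc-serializable, this should serialize roughly as
    # {"lc": 1, "type": "constructor",
    #  "id": ["langchain", "schema", "document", "SplitterDocument"],
    #  "kwargs": {"page_content": "...", "metadata": {"source": "example.txt"}}}
    return doc.to_json()
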
  250. class BaseDocumentTransformer(ABC):
  251. """Abstract base class for document transformation systems.
  252. A document transformation system takes a sequence of Documents and returns a
  253. sequence of transformed Documents.
  254. Example:
  255. .. code-block:: python
  256. class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
  257. embeddings: Embeddings
  258. similarity_fn: Callable = cosine_similarity
  259. similarity_threshold: float = 0.95
  260. class Config:
  261. arbitrary_types_allowed = True
  262. def transform_documents(
  263. self, documents: Sequence[Document], **kwargs: Any
  264. ) -> Sequence[Document]:
  265. stateful_documents = get_stateful_documents(documents)
  266. embedded_documents = _get_embeddings_from_stateful_docs(
  267. self.embeddings, stateful_documents
  268. )
  269. included_idxs = _filter_similar_embeddings(
  270. embedded_documents, self.similarity_fn, self.similarity_threshold
  271. )
  272. return [stateful_documents[i] for i in sorted(included_idxs)]
  273. async def atransform_documents(
  274. self, documents: Sequence[Document], **kwargs: Any
  275. ) -> Sequence[Document]:
  276. raise NotImplementedError
  277. """ # noqa: E501
  278. @abstractmethod
  279. def transform_documents(
  280. self, documents: Sequence[SplitterDocument], **kwargs: Any
  281. ) -> Sequence[SplitterDocument]:
  282. """Transform a list of documents.
  283. Args:
  284. documents: A sequence of Documents to be transformed.
  285. Returns:
  286. A list of transformed Documents.
  287. """
  288. async def atransform_documents(
  289. self, documents: Sequence[SplitterDocument], **kwargs: Any
  290. ) -> Sequence[SplitterDocument]:
  291. """Asynchronously transform a list of documents.
  292. Args:
  293. documents: A sequence of Documents to be transformed.
  294. Returns:
  295. A list of transformed Documents.
  296. """
  297. raise NotImplementedError("This method is not implemented.")
  298. # return await langchain_core.runnables.config.run_in_executor(
  299. # None, self.transform_documents, documents, **kwargs
  300. # )
  301. def _make_spacy_pipe_for_splitting(
  302. pipe: str, *, max_length: int = 1_000_000
  303. ) -> Any: # avoid importing spacy
  304. try:
  305. import spacy
  306. except ImportError:
  307. raise ImportError(
  308. "Spacy is not installed, please install it with `pip install spacy`."
  309. )
  310. if pipe == "sentencizer":
  311. from spacy.lang.en import English
  312. sentencizer = English()
  313. sentencizer.add_pipe("sentencizer")
  314. else:
  315. sentencizer = spacy.load(pipe, exclude=["ner", "tagger"])
  316. sentencizer.max_length = max_length
  317. return sentencizer
  318. def _split_text_with_regex(
  319. text: str, separator: str, keep_separator: bool
  320. ) -> List[str]:
  321. # Now that we have the separator, split the text
  322. if separator:
  323. if keep_separator:
  324. # The parentheses in the pattern keep the delimiters in the result.
  325. _splits = re.split(f"({separator})", text)
  326. splits = [
  327. _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
  328. ]
  329. if len(_splits) % 2 == 0:
  330. splits += _splits[-1:]
  331. splits = [_splits[0]] + splits
  332. else:
  333. splits = re.split(separator, text)
  334. else:
  335. splits = list(text)
  336. return [s for s in splits if s != ""]
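
# A small illustrative sketch (the helper name is hypothetical): with
# keep_separator=True the regex delimiter is re-attached to the piece that
# follows it, so the chunks can later be rejoined with an empty separator.
def _example_split_text_with_regex() -> None:
    # ['a', '.b', '.c'] -- separator kept on the following piece
    print(_split_text_with_regex("a.b.c", r"\.", keep_separator=True))
    # ['a', 'b', 'c'] -- separator dropped entirely
    print(_split_text_with_regex("a.b.c", r"\.", keep_separator=False))
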
  337. class TextSplitter(BaseDocumentTransformer, ABC):
  338. """Interface for splitting text into chunks."""
  339. def __init__(
  340. self,
  341. chunk_size: int = 4000,
  342. chunk_overlap: int = 200,
  343. length_function: Callable[[str], int] = len,
  344. keep_separator: bool = False,
  345. add_start_index: bool = False,
  346. strip_whitespace: bool = True,
  347. ) -> None:
  348. """Create a new TextSplitter.
  349. Args:
  350. chunk_size: Maximum size of chunks to return
  351. chunk_overlap: Overlap in characters between chunks
  352. length_function: Function that measures the length of given chunks
  353. keep_separator: Whether to keep the separator in the chunks
  354. add_start_index: If `True`, includes chunk's start index in metadata
  355. strip_whitespace: If `True`, strips whitespace from the start and end of
  356. every document
  357. """
  358. if chunk_overlap > chunk_size:
  359. raise ValueError(
  360. f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
  361. f"({chunk_size}), should be smaller."
  362. )
  363. self._chunk_size = chunk_size
  364. self._chunk_overlap = chunk_overlap
  365. self._length_function = length_function
  366. self._keep_separator = keep_separator
  367. self._add_start_index = add_start_index
  368. self._strip_whitespace = strip_whitespace
  369. @abstractmethod
  370. def split_text(self, text: str) -> List[str]:
  371. """Split text into multiple components."""
  372. def create_documents(
  373. self, texts: List[str], metadatas: Optional[List[dict]] = None
  374. ) -> List[SplitterDocument]:
  375. """Create documents from a list of texts."""
  376. _metadatas = metadatas or [{}] * len(texts)
  377. documents = []
  378. for i, text in enumerate(texts):
  379. index = 0
  380. previous_chunk_len = 0
  381. for chunk in self.split_text(text):
  382. metadata = copy.deepcopy(_metadatas[i])
  383. if self._add_start_index:
  384. offset = index + previous_chunk_len - self._chunk_overlap
  385. index = text.find(chunk, max(0, offset))
  386. metadata["start_index"] = index
  387. previous_chunk_len = len(chunk)
  388. new_doc = SplitterDocument(
  389. page_content=chunk, metadata=metadata
  390. )
  391. documents.append(new_doc)
  392. return documents
  393. def split_documents(
  394. self, documents: Iterable[SplitterDocument]
  395. ) -> List[SplitterDocument]:
  396. """Split documents."""
  397. texts, metadatas = [], []
  398. for doc in documents:
  399. texts.append(doc.page_content)
  400. metadatas.append(doc.metadata)
  401. return self.create_documents(texts, metadatas=metadatas)
  402. def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
  403. text = separator.join(docs)
  404. if self._strip_whitespace:
  405. text = text.strip()
  406. if text == "":
  407. return None
  408. else:
  409. return text
  410. def _merge_splits(
  411. self, splits: Iterable[str], separator: str
  412. ) -> List[str]:
  413. # We now want to combine these smaller pieces into medium size
  414. # chunks to send to the LLM.
  415. separator_len = self._length_function(separator)
  416. docs = []
  417. current_doc: List[str] = []
  418. total = 0
  419. for d in splits:
  420. _len = self._length_function(d)
  421. if (
  422. total + _len + (separator_len if len(current_doc) > 0 else 0)
  423. > self._chunk_size
  424. ):
  425. if total > self._chunk_size:
  426. logger.warning(
  427. f"Created a chunk of size {total}, "
  428. f"which is longer than the specified {self._chunk_size}"
  429. )
  430. if len(current_doc) > 0:
  431. doc = self._join_docs(current_doc, separator)
  432. if doc is not None:
  433. docs.append(doc)
  434. # Keep on popping if:
  435. # - we have a larger chunk than in the chunk overlap
  436. # - or if we still have any chunks and the length is long
  437. while total > self._chunk_overlap or (
  438. total
  439. + _len
  440. + (separator_len if len(current_doc) > 0 else 0)
  441. > self._chunk_size
  442. and total > 0
  443. ):
  444. total -= self._length_function(current_doc[0]) + (
  445. separator_len if len(current_doc) > 1 else 0
  446. )
  447. current_doc = current_doc[1:]
  448. current_doc.append(d)
  449. total += _len + (separator_len if len(current_doc) > 1 else 0)
  450. doc = self._join_docs(current_doc, separator)
  451. if doc is not None:
  452. docs.append(doc)
  453. return docs
  454. @classmethod
  455. def from_huggingface_tokenizer(
  456. cls, tokenizer: Any, **kwargs: Any
  457. ) -> TextSplitter:
  458. """Text splitter that uses HuggingFace tokenizer to count length."""
  459. try:
  460. from transformers import PreTrainedTokenizerBase
  461. if not isinstance(tokenizer, PreTrainedTokenizerBase):
  462. raise ValueError(
  463. "Tokenizer received was not an instance of PreTrainedTokenizerBase"
  464. )
  465. def _huggingface_tokenizer_length(text: str) -> int:
  466. return len(tokenizer.encode(text))
  467. except ImportError:
  468. raise ValueError(
  469. "Could not import transformers python package. "
  470. "Please install it with `pip install transformers`."
  471. )
  472. return cls(length_function=_huggingface_tokenizer_length, **kwargs)
  473. @classmethod
  474. def from_tiktoken_encoder(
  475. cls: Type[TS],
  476. encoding_name: str = "gpt2",
  477. model: Optional[str] = None,
  478. allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
  479. disallowed_special: Union[Literal["all"], Collection[str]] = "all",
  480. **kwargs: Any,
  481. ) -> TS:
  482. """Text splitter that uses tiktoken encoder to count length."""
  483. try:
  484. import tiktoken
  485. except ImportError:
  486. raise ImportError(
  487. "Could not import tiktoken python package. "
  488. "This is needed in order to calculate max_tokens_for_prompt. "
  489. "Please install it with `pip install tiktoken`."
  490. )
  491. if model is not None:
  492. enc = tiktoken.encoding_for_model(model)
  493. else:
  494. enc = tiktoken.get_encoding(encoding_name)
  495. def _tiktoken_encoder(text: str) -> int:
  496. return len(
  497. enc.encode(
  498. text,
  499. allowed_special=allowed_special,
  500. disallowed_special=disallowed_special,
  501. )
  502. )
  503. if issubclass(cls, TokenTextSplitter):
  504. extra_kwargs = {
  505. "encoding_name": encoding_name,
  506. "model": model,
  507. "allowed_special": allowed_special,
  508. "disallowed_special": disallowed_special,
  509. }
  510. kwargs = {**kwargs, **extra_kwargs}
  511. return cls(length_function=_tiktoken_encoder, **kwargs)
  512. def transform_documents(
  513. self, documents: Sequence[SplitterDocument], **kwargs: Any
  514. ) -> Sequence[SplitterDocument]:
  515. """Transform sequence of documents by splitting them."""
  516. return self.split_documents(list(documents))
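
# A minimal sketch of subclassing TextSplitter (illustrative only; the class
# name below is hypothetical). A subclass only has to provide split_text();
# _merge_splits(), create_documents() and transform_documents() are inherited.
class _ExampleSentenceSplitter(TextSplitter):
    """Naive splitter that treats '. ' as the sentence boundary."""

    def split_text(self, text: str) -> List[str]:
        sentences = [s for s in text.split(". ") if s]
        # Re-pack the sentences into chunks of at most self._chunk_size
        # characters, honoring self._chunk_overlap.
        return self._merge_splits(sentences, ". ")
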
  517. class CharacterTextSplitter(TextSplitter):
  518. """Splitting text that looks at characters."""
  519. DEFAULT_SEPARATOR: str = "\n\n"
  520. def __init__(
  521. self,
  522. separator: str = DEFAULT_SEPARATOR,
  523. is_separator_regex: bool = False,
  524. **kwargs: Any,
  525. ) -> None:
  526. """Create a new TextSplitter."""
  527. super().__init__(**kwargs)
  528. self._separator = separator
  529. self._is_separator_regex = is_separator_regex
  530. def split_text(self, text: str) -> List[str]:
  531. """Split incoming text and return chunks."""
  532. # First we naively split the large input into a bunch of smaller ones.
  533. separator = (
  534. self._separator
  535. if self._is_separator_regex
  536. else re.escape(self._separator)
  537. )
  538. splits = _split_text_with_regex(text, separator, self._keep_separator)
  539. _separator = "" if self._keep_separator else self._separator
  540. return self._merge_splits(splits, _separator)
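
# Usage sketch (illustrative; the helper name is hypothetical): split on blank
# lines into chunks of at most 100 characters with a 20-character overlap, then
# wrap the chunks in SplitterDocuments.
def _example_character_split() -> List[SplitterDocument]:
    splitter = CharacterTextSplitter(
        separator="\n\n", chunk_size=100, chunk_overlap=20
    )
    text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
    return splitter.create_documents([text], metadatas=[{"source": "demo"}])
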
  541. class LineType(TypedDict):
  542. """Line type as typed dict."""
  543. metadata: Dict[str, str]
  544. content: str
  545. class HeaderType(TypedDict):
  546. """Header type as typed dict."""
  547. level: int
  548. name: str
  549. data: str
  550. class MarkdownHeaderTextSplitter:
  551. """Splitting markdown files based on specified headers."""
  552. def __init__(
  553. self,
  554. headers_to_split_on: List[Tuple[str, str]],
  555. return_each_line: bool = False,
  556. strip_headers: bool = True,
  557. ):
  558. """Create a new MarkdownHeaderTextSplitter.
  559. Args:
  560. headers_to_split_on: Headers we want to track
  561. return_each_line: Return each line w/ associated headers
  562. strip_headers: Strip split headers from the content of the chunk
  563. """
  564. # Output line-by-line or aggregated into chunks w/ common headers
  565. self.return_each_line = return_each_line
  566. # Given the headers we want to split on,
  567. # (e.g., "#, ##, etc") order by length
  568. self.headers_to_split_on = sorted(
  569. headers_to_split_on, key=lambda split: len(split[0]), reverse=True
  570. )
 571. # Strip split headers from the content of the chunk
  572. self.strip_headers = strip_headers
  573. def aggregate_lines_to_chunks(
  574. self, lines: List[LineType]
  575. ) -> List[SplitterDocument]:
  576. """Combine lines with common metadata into chunks
  577. Args:
  578. lines: Line of text / associated header metadata
  579. """
  580. aggregated_chunks: List[LineType] = []
  581. for line in lines:
  582. if (
  583. aggregated_chunks
  584. and aggregated_chunks[-1]["metadata"] == line["metadata"]
  585. ):
  586. # If the last line in the aggregated list
  587. # has the same metadata as the current line,
 588. # append the current content to the last line's content
  589. aggregated_chunks[-1]["content"] += " \n" + line["content"]
  590. elif (
  591. aggregated_chunks
  592. and aggregated_chunks[-1]["metadata"] != line["metadata"]
  593. # may be issues if other metadata is present
  594. and len(aggregated_chunks[-1]["metadata"])
  595. < len(line["metadata"])
  596. and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#"
  597. and not self.strip_headers
  598. ):
  599. # If the last line in the aggregated list
 600. # has different metadata than the current line,
  601. # and has shallower header level than the current line,
  602. # and the last line is a header,
  603. # and we are not stripping headers,
  604. # append the current content to the last line's content
  605. aggregated_chunks[-1]["content"] += " \n" + line["content"]
  606. # and update the last line's metadata
  607. aggregated_chunks[-1]["metadata"] = line["metadata"]
  608. else:
  609. # Otherwise, append the current line to the aggregated list
  610. aggregated_chunks.append(line)
  611. return [
  612. SplitterDocument(
  613. page_content=chunk["content"], metadata=chunk["metadata"]
  614. )
  615. for chunk in aggregated_chunks
  616. ]
  617. def split_text(self, text: str) -> List[SplitterDocument]:
  618. """Split markdown file
  619. Args:
  620. text: Markdown file"""
  621. # Split the input text by newline character ("\n").
  622. lines = text.split("\n")
  623. # Final output
  624. lines_with_metadata: List[LineType] = []
  625. # Content and metadata of the chunk currently being processed
  626. current_content: List[str] = []
  627. current_metadata: Dict[str, str] = {}
  628. # Keep track of the nested header structure
  629. # header_stack: List[Dict[str, Union[int, str]]] = []
  630. header_stack: List[HeaderType] = []
  631. initial_metadata: Dict[str, str] = {}
  632. in_code_block = False
  633. opening_fence = ""
  634. for line in lines:
  635. stripped_line = line.strip()
  636. if not in_code_block:
  637. # Exclude inline code spans
  638. if (
  639. stripped_line.startswith("```")
  640. and stripped_line.count("```") == 1
  641. ):
  642. in_code_block = True
  643. opening_fence = "```"
  644. elif stripped_line.startswith("~~~"):
  645. in_code_block = True
  646. opening_fence = "~~~"
  647. else:
  648. if stripped_line.startswith(opening_fence):
  649. in_code_block = False
  650. opening_fence = ""
  651. if in_code_block:
  652. current_content.append(stripped_line)
  653. continue
  654. # Check each line against each of the header types (e.g., #, ##)
  655. for sep, name in self.headers_to_split_on:
  656. # Check if line starts with a header that we intend to split on
  657. if stripped_line.startswith(sep) and (
  658. # Header with no text OR header is followed by space
 659. # Both are valid conditions that sep is being used as a header
  660. len(stripped_line) == len(sep)
  661. or stripped_line[len(sep)] == " "
  662. ):
  663. # Ensure we are tracking the header as metadata
  664. if name is not None:
  665. # Get the current header level
  666. current_header_level = sep.count("#")
  667. # Pop out headers of lower or same level from the stack
  668. while (
  669. header_stack
  670. and header_stack[-1]["level"]
  671. >= current_header_level
  672. ):
  673. # We have encountered a new header
  674. # at the same or higher level
  675. popped_header = header_stack.pop()
  676. # Clear the metadata for the
  677. # popped header in initial_metadata
  678. if popped_header["name"] in initial_metadata:
  679. initial_metadata.pop(popped_header["name"])
  680. # Push the current header to the stack
  681. header: HeaderType = {
  682. "level": current_header_level,
  683. "name": name,
  684. "data": stripped_line[len(sep) :].strip(),
  685. }
  686. header_stack.append(header)
  687. # Update initial_metadata with the current header
  688. initial_metadata[name] = header["data"]
  689. # Add the previous line to the lines_with_metadata
  690. # only if current_content is not empty
  691. if current_content:
  692. lines_with_metadata.append(
  693. {
  694. "content": "\n".join(current_content),
  695. "metadata": current_metadata.copy(),
  696. }
  697. )
  698. current_content.clear()
  699. if not self.strip_headers:
  700. current_content.append(stripped_line)
  701. break
  702. else:
  703. if stripped_line:
  704. current_content.append(stripped_line)
  705. elif current_content:
  706. lines_with_metadata.append(
  707. {
  708. "content": "\n".join(current_content),
  709. "metadata": current_metadata.copy(),
  710. }
  711. )
  712. current_content.clear()
  713. current_metadata = initial_metadata.copy()
  714. if current_content:
  715. lines_with_metadata.append(
  716. {
  717. "content": "\n".join(current_content),
  718. "metadata": current_metadata,
  719. }
  720. )
  721. # lines_with_metadata has each line with associated header metadata
  722. # aggregate these into chunks based on common metadata
  723. if not self.return_each_line:
  724. return self.aggregate_lines_to_chunks(lines_with_metadata)
  725. else:
  726. return [
  727. SplitterDocument(
  728. page_content=chunk["content"], metadata=chunk["metadata"]
  729. )
  730. for chunk in lines_with_metadata
  731. ]
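
# Usage sketch (illustrative; the function name is hypothetical):
# headers_to_split_on maps markdown header prefixes to metadata keys, and each
# returned SplitterDocument carries the headers it sits under as metadata.
def _example_markdown_header_split() -> List[SplitterDocument]:
    splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
    )
    md = "# Title\n\nIntro text.\n\n## Section A\n\nBody of section A."
    # e.g. SplitterDocument(page_content="Body of section A.",
    #                       metadata={"Header 1": "Title", "Header 2": "Section A"})
    return splitter.split_text(md)
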
  732. class ElementType(TypedDict):
  733. """Element type as typed dict."""
  734. url: str
  735. xpath: str
  736. content: str
  737. metadata: Dict[str, str]
  738. class HTMLHeaderTextSplitter:
  739. """
  740. Splitting HTML files based on specified headers.
  741. Requires lxml package.
  742. """
  743. def __init__(
  744. self,
  745. headers_to_split_on: List[Tuple[str, str]],
  746. return_each_element: bool = False,
  747. ):
  748. """Create a new HTMLHeaderTextSplitter.
  749. Args:
  750. headers_to_split_on: list of tuples of headers we want to track mapped to
  751. (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
 752. h5, h6, e.g. [("h1", "Header 1"), ("h2", "Header 2")].
  753. return_each_element: Return each element w/ associated headers.
  754. """
  755. # Output element-by-element or aggregated into chunks w/ common headers
  756. self.return_each_element = return_each_element
  757. self.headers_to_split_on = sorted(headers_to_split_on)
  758. def aggregate_elements_to_chunks(
  759. self, elements: List[ElementType]
  760. ) -> List[SplitterDocument]:
  761. """Combine elements with common metadata into chunks
  762. Args:
  763. elements: HTML element content with associated identifying info and metadata
  764. """
  765. aggregated_chunks: List[ElementType] = []
  766. for element in elements:
  767. if (
  768. aggregated_chunks
  769. and aggregated_chunks[-1]["metadata"] == element["metadata"]
  770. ):
  771. # If the last element in the aggregated list
  772. # has the same metadata as the current element,
  773. # append the current content to the last element's content
  774. aggregated_chunks[-1]["content"] += " \n" + element["content"]
  775. else:
  776. # Otherwise, append the current element to the aggregated list
  777. aggregated_chunks.append(element)
  778. return [
  779. SplitterDocument(
  780. page_content=chunk["content"], metadata=chunk["metadata"]
  781. )
  782. for chunk in aggregated_chunks
  783. ]
  784. def split_text_from_url(self, url: str) -> List[SplitterDocument]:
  785. """Split HTML from web URL
  786. Args:
  787. url: web URL
  788. """
  789. r = requests.get(url)
  790. return self.split_text_from_file(BytesIO(r.content))
  791. def split_text(self, text: str) -> List[SplitterDocument]:
  792. """Split HTML text string
  793. Args:
  794. text: HTML text
  795. """
  796. return self.split_text_from_file(StringIO(text))
  797. def split_text_from_file(self, file: Any) -> List[SplitterDocument]:
  798. """Split HTML file
  799. Args:
  800. file: HTML file
  801. """
  802. try:
  803. from lxml import etree
  804. except ImportError as e:
  805. raise ImportError(
  806. "Unable to import lxml, please install with `pip install lxml`."
  807. ) from e
  808. # use lxml library to parse html document and return xml ElementTree
  809. # Explicitly encoding in utf-8 allows non-English
  810. # html files to be processed without garbled characters
  811. parser = etree.HTMLParser(encoding="utf-8")
  812. tree = etree.parse(file, parser)
  813. # document transformation for "structure-aware" chunking is handled with xsl.
  814. # see comments in html_chunks_with_headers.xslt for more detailed information.
  815. xslt_path = (
  816. pathlib.Path(__file__).parent
  817. / "document_transformers/xsl/html_chunks_with_headers.xslt"
  818. )
  819. xslt_tree = etree.parse(xslt_path)
  820. transform = etree.XSLT(xslt_tree)
  821. result = transform(tree)
  822. result_dom = etree.fromstring(str(result))
  823. # create filter and mapping for header metadata
  824. header_filter = [header[0] for header in self.headers_to_split_on]
  825. header_mapping = dict(self.headers_to_split_on)
  826. # map xhtml namespace prefix
  827. ns_map = {"h": "http://www.w3.org/1999/xhtml"}
  828. # build list of elements from DOM
  829. elements = []
  830. for element in result_dom.findall("*//*", ns_map):
  831. if element.findall("*[@class='headers']") or element.findall(
  832. "*[@class='chunk']"
  833. ):
  834. elements.append(
  835. ElementType(
  836. url=file,
  837. xpath="".join(
  838. [
  839. node.text
  840. for node in element.findall(
  841. "*[@class='xpath']", ns_map
  842. )
  843. ]
  844. ),
  845. content="".join(
  846. [
  847. node.text
  848. for node in element.findall(
  849. "*[@class='chunk']", ns_map
  850. )
  851. ]
  852. ),
  853. metadata={
  854. # Add text of specified headers to metadata using header
  855. # mapping.
  856. header_mapping[node.tag]: node.text
  857. for node in filter(
  858. lambda x: x.tag in header_filter,
  859. element.findall(
  860. "*[@class='headers']/*", ns_map
  861. ),
  862. )
  863. },
  864. )
  865. )
  866. if not self.return_each_element:
  867. return self.aggregate_elements_to_chunks(elements)
  868. else:
  869. return [
  870. SplitterDocument(
  871. page_content=chunk["content"], metadata=chunk["metadata"]
  872. )
  873. for chunk in elements
  874. ]
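
# Usage sketch (illustrative; the function name is hypothetical). Assumes lxml
# is installed and the bundled document_transformers/xsl/html_chunks_with_headers.xslt
# file ships next to this module, as split_text_from_file expects.
def _example_html_header_split() -> List[SplitterDocument]:
    splitter = HTMLHeaderTextSplitter(
        headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
    )
    html = "<html><body><h1>Title</h1><p>Intro.</p><h2>Section</h2><p>Body.</p></body></html>"
    return splitter.split_text(html)
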
 875. # In newer Python versions (3.11+), this could be:
  876. # @dataclass(frozen=True, kw_only=True, slots=True)
  877. @dataclass(frozen=True)
  878. class Tokenizer:
  879. """Tokenizer data class."""
  880. chunk_overlap: int
  881. """Overlap in tokens between chunks"""
  882. tokens_per_chunk: int
  883. """Maximum number of tokens per chunk"""
  884. decode: Callable[[List[int]], str]
  885. """ Function to decode a list of token ids to a string"""
  886. encode: Callable[[str], List[int]]
  887. """ Function to encode a string to a list of token ids"""
  888. def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
  889. """Split incoming text and return chunks using tokenizer."""
  890. splits: List[str] = []
  891. input_ids = tokenizer.encode(text)
  892. start_idx = 0
  893. cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
  894. chunk_ids = input_ids[start_idx:cur_idx]
  895. while start_idx < len(input_ids):
  896. splits.append(tokenizer.decode(chunk_ids))
  897. if cur_idx == len(input_ids):
  898. break
  899. start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
  900. cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
  901. chunk_ids = input_ids[start_idx:cur_idx]
  902. return splits
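
# Illustrative sketch (not part of the original module; names are hypothetical):
# split_text_on_tokens only needs encode/decode callables, so a toy word-level
# "tokenizer" is enough to see the sliding window with overlap at work.
def _example_split_on_tokens() -> List[str]:
    words = "one two three four five six".split()
    vocab = {w: i for i, w in enumerate(words)}
    toy = Tokenizer(
        chunk_overlap=1,
        tokens_per_chunk=3,
        decode=lambda ids: " ".join(words[i] for i in ids),
        encode=lambda s: [vocab[w] for w in s.split()],
    )
    # Windows of 3 tokens stepping forward by 2 (3 minus an overlap of 1):
    # ["one two three", "three four five", "five six"]
    return split_text_on_tokens(text="one two three four five six", tokenizer=toy)
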
  903. class TokenTextSplitter(TextSplitter):
  904. """Splitting text to tokens using model tokenizer."""
  905. def __init__(
  906. self,
  907. encoding_name: str = "gpt2",
  908. model: Optional[str] = None,
  909. allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
  910. disallowed_special: Union[Literal["all"], Collection[str]] = "all",
  911. **kwargs: Any,
  912. ) -> None:
  913. """Create a new TextSplitter."""
  914. super().__init__(**kwargs)
  915. try:
  916. import tiktoken
  917. except ImportError:
  918. raise ImportError(
  919. "Could not import tiktoken python package. "
  920. "This is needed in order to for TokenTextSplitter. "
  921. "Please install it with `pip install tiktoken`."
  922. )
  923. if model is not None:
  924. enc = tiktoken.encoding_for_model(model)
  925. else:
  926. enc = tiktoken.get_encoding(encoding_name)
  927. self._tokenizer = enc
  928. self._allowed_special = allowed_special
  929. self._disallowed_special = disallowed_special
  930. def split_text(self, text: str) -> List[str]:
  931. def _encode(_text: str) -> List[int]:
  932. return self._tokenizer.encode(
  933. _text,
  934. allowed_special=self._allowed_special,
  935. disallowed_special=self._disallowed_special,
  936. )
  937. tokenizer = Tokenizer(
  938. chunk_overlap=self._chunk_overlap,
  939. tokens_per_chunk=self._chunk_size,
  940. decode=self._tokenizer.decode,
  941. encode=_encode,
  942. )
  943. return split_text_on_tokens(text=text, tokenizer=tokenizer)
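
# Usage sketch (illustrative; assumes tiktoken is installed and the function
# name is hypothetical). Chunk size and overlap are measured in tokens of the
# chosen encoding, not in characters.
def _example_token_split() -> List[str]:
    splitter = TokenTextSplitter(
        encoding_name="cl100k_base", chunk_size=256, chunk_overlap=32
    )
    return splitter.split_text("A long document to be split by token count.")
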
  944. class SentenceTransformersTokenTextSplitter(TextSplitter):
  945. """Splitting text to tokens using sentence model tokenizer."""
  946. def __init__(
  947. self,
  948. chunk_overlap: int = 50,
  949. model: str = "sentence-transformers/all-mpnet-base-v2",
  950. tokens_per_chunk: Optional[int] = None,
  951. **kwargs: Any,
  952. ) -> None:
  953. """Create a new TextSplitter."""
  954. super().__init__(**kwargs, chunk_overlap=chunk_overlap)
  955. try:
  956. from sentence_transformers import SentenceTransformer
  957. except ImportError:
  958. raise ImportError(
  959. "Could not import sentence_transformer python package. "
  960. "This is needed in order to for SentenceTransformersTokenTextSplitter. "
  961. "Please install it with `pip install sentence-transformers`."
  962. )
  963. self.model = model
  964. self._model = SentenceTransformer(self.model, trust_remote_code=True)
  965. self.tokenizer = self._model.tokenizer
  966. self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
  967. def _initialize_chunk_configuration(
  968. self, *, tokens_per_chunk: Optional[int]
  969. ) -> None:
  970. self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
  971. if tokens_per_chunk is None:
  972. self.tokens_per_chunk = self.maximum_tokens_per_chunk
  973. else:
  974. self.tokens_per_chunk = tokens_per_chunk
  975. if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
  976. raise ValueError(
  977. f"The token limit of the models '{self.model}'"
  978. f" is: {self.maximum_tokens_per_chunk}."
  979. f" Argument tokens_per_chunk={self.tokens_per_chunk}"
  980. f" > maximum token limit."
  981. )
  982. def split_text(self, text: str) -> List[str]:
  983. def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
  984. return self._encode(text)[1:-1]
  985. tokenizer = Tokenizer(
  986. chunk_overlap=self._chunk_overlap,
  987. tokens_per_chunk=self.tokens_per_chunk,
  988. decode=self.tokenizer.decode,
  989. encode=encode_strip_start_and_stop_token_ids,
  990. )
  991. return split_text_on_tokens(text=text, tokenizer=tokenizer)
  992. def count_tokens(self, *, text: str) -> int:
  993. return len(self._encode(text))
  994. _max_length_equal_32_bit_integer: int = 2**32
  995. def _encode(self, text: str) -> List[int]:
  996. token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
  997. text,
  998. max_length=self._max_length_equal_32_bit_integer,
  999. truncation="do_not_truncate",
  1000. )
  1001. return token_ids_with_start_and_end_token_ids
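
# Usage sketch (illustrative; assumes sentence-transformers is installed and the
# model can be downloaded; the function name is hypothetical). When
# tokens_per_chunk is left as None it defaults to the model's max_seq_length.
def _example_sentence_transformers_split() -> List[str]:
    splitter = SentenceTransformersTokenTextSplitter(
        model="sentence-transformers/all-mpnet-base-v2", chunk_overlap=20
    )
    return splitter.split_text("Text that should be chunked by the model tokenizer.")
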
  1002. class Language(str, Enum):
  1003. """Enum of the programming languages."""
  1004. CPP = "cpp"
  1005. GO = "go"
  1006. JAVA = "java"
  1007. KOTLIN = "kotlin"
  1008. JS = "js"
  1009. TS = "ts"
  1010. PHP = "php"
  1011. PROTO = "proto"
  1012. PYTHON = "python"
  1013. RST = "rst"
  1014. RUBY = "ruby"
  1015. RUST = "rust"
  1016. SCALA = "scala"
  1017. SWIFT = "swift"
  1018. MARKDOWN = "markdown"
  1019. LATEX = "latex"
  1020. HTML = "html"
  1021. SOL = "sol"
  1022. CSHARP = "csharp"
  1023. COBOL = "cobol"
  1024. C = "c"
  1025. LUA = "lua"
  1026. PERL = "perl"
  1027. class RecursiveCharacterTextSplitter(TextSplitter):
  1028. """Splitting text by recursively look at characters.
  1029. Recursively tries to split by different characters to find one
  1030. that works.
  1031. """
  1032. def __init__(
  1033. self,
  1034. separators: Optional[List[str]] = None,
  1035. keep_separator: bool = True,
  1036. is_separator_regex: bool = False,
  1037. chunk_size: int = 4000,
  1038. chunk_overlap: int = 200,
  1039. **kwargs: Any,
  1040. ) -> None:
  1041. """Create a new TextSplitter."""
  1042. super().__init__(
  1043. chunk_size=chunk_size,
  1044. chunk_overlap=chunk_overlap,
  1045. keep_separator=keep_separator,
  1046. **kwargs,
  1047. )
  1048. self._separators = separators or ["\n\n", "\n", " ", ""]
  1049. self._is_separator_regex = is_separator_regex
  1050. self.chunk_size = chunk_size
  1051. self.chunk_overlap = chunk_overlap
  1052. def _split_text(self, text: str, separators: List[str]) -> List[str]:
  1053. """Split incoming text and return chunks."""
  1054. final_chunks = []
  1055. # Get appropriate separator to use
  1056. separator = separators[-1]
  1057. new_separators = []
  1058. for i, _s in enumerate(separators):
  1059. _separator = _s if self._is_separator_regex else re.escape(_s)
  1060. if _s == "":
  1061. separator = _s
  1062. break
  1063. if re.search(_separator, text):
  1064. separator = _s
  1065. new_separators = separators[i + 1 :]
  1066. break
  1067. _separator = (
  1068. separator if self._is_separator_regex else re.escape(separator)
  1069. )
  1070. splits = _split_text_with_regex(text, _separator, self._keep_separator)
  1071. # Now go merging things, recursively splitting longer texts.
  1072. _good_splits = []
  1073. _separator = "" if self._keep_separator else separator
  1074. for s in splits:
  1075. if self._length_function(s) < self._chunk_size:
  1076. _good_splits.append(s)
  1077. else:
  1078. if _good_splits:
  1079. merged_text = self._merge_splits(_good_splits, _separator)
  1080. final_chunks.extend(merged_text)
  1081. _good_splits = []
  1082. if not new_separators:
  1083. final_chunks.append(s)
  1084. else:
  1085. other_info = self._split_text(s, new_separators)
  1086. final_chunks.extend(other_info)
  1087. if _good_splits:
  1088. merged_text = self._merge_splits(_good_splits, _separator)
  1089. final_chunks.extend(merged_text)
  1090. return final_chunks
  1091. def split_text(self, text: str) -> List[str]:
  1092. return self._split_text(text, self._separators)
  1093. @classmethod
  1094. def from_language(
  1095. cls, language: Language, **kwargs: Any
  1096. ) -> RecursiveCharacterTextSplitter:
  1097. separators = cls.get_separators_for_language(language)
  1098. return cls(separators=separators, is_separator_regex=True, **kwargs)
  1099. @staticmethod
  1100. def get_separators_for_language(language: Language) -> List[str]:
  1101. if language == Language.CPP:
  1102. return [
  1103. # Split along class definitions
  1104. "\nclass ",
  1105. # Split along function definitions
  1106. "\nvoid ",
  1107. "\nint ",
  1108. "\nfloat ",
  1109. "\ndouble ",
  1110. # Split along control flow statements
  1111. "\nif ",
  1112. "\nfor ",
  1113. "\nwhile ",
  1114. "\nswitch ",
  1115. "\ncase ",
  1116. # Split by the normal type of lines
  1117. "\n\n",
  1118. "\n",
  1119. " ",
  1120. "",
  1121. ]
  1122. elif language == Language.GO:
  1123. return [
  1124. # Split along function definitions
  1125. "\nfunc ",
  1126. "\nvar ",
  1127. "\nconst ",
  1128. "\ntype ",
  1129. # Split along control flow statements
  1130. "\nif ",
  1131. "\nfor ",
  1132. "\nswitch ",
  1133. "\ncase ",
  1134. # Split by the normal type of lines
  1135. "\n\n",
  1136. "\n",
  1137. " ",
  1138. "",
  1139. ]
  1140. elif language == Language.JAVA:
  1141. return [
  1142. # Split along class definitions
  1143. "\nclass ",
  1144. # Split along method definitions
  1145. "\npublic ",
  1146. "\nprotected ",
  1147. "\nprivate ",
  1148. "\nstatic ",
  1149. # Split along control flow statements
  1150. "\nif ",
  1151. "\nfor ",
  1152. "\nwhile ",
  1153. "\nswitch ",
  1154. "\ncase ",
  1155. # Split by the normal type of lines
  1156. "\n\n",
  1157. "\n",
  1158. " ",
  1159. "",
  1160. ]
  1161. elif language == Language.KOTLIN:
  1162. return [
  1163. # Split along class definitions
  1164. "\nclass ",
  1165. # Split along method definitions
  1166. "\npublic ",
  1167. "\nprotected ",
  1168. "\nprivate ",
  1169. "\ninternal ",
  1170. "\ncompanion ",
  1171. "\nfun ",
  1172. "\nval ",
  1173. "\nvar ",
  1174. # Split along control flow statements
  1175. "\nif ",
  1176. "\nfor ",
  1177. "\nwhile ",
  1178. "\nwhen ",
  1179. "\ncase ",
  1180. "\nelse ",
  1181. # Split by the normal type of lines
  1182. "\n\n",
  1183. "\n",
  1184. " ",
  1185. "",
  1186. ]
  1187. elif language == Language.JS:
  1188. return [
  1189. # Split along function definitions
  1190. "\nfunction ",
  1191. "\nconst ",
  1192. "\nlet ",
  1193. "\nvar ",
  1194. "\nclass ",
  1195. # Split along control flow statements
  1196. "\nif ",
  1197. "\nfor ",
  1198. "\nwhile ",
  1199. "\nswitch ",
  1200. "\ncase ",
  1201. "\ndefault ",
  1202. # Split by the normal type of lines
  1203. "\n\n",
  1204. "\n",
  1205. " ",
  1206. "",
  1207. ]
  1208. elif language == Language.TS:
  1209. return [
  1210. "\nenum ",
  1211. "\ninterface ",
  1212. "\nnamespace ",
  1213. "\ntype ",
  1214. # Split along class definitions
  1215. "\nclass ",
  1216. # Split along function definitions
  1217. "\nfunction ",
  1218. "\nconst ",
  1219. "\nlet ",
  1220. "\nvar ",
  1221. # Split along control flow statements
  1222. "\nif ",
  1223. "\nfor ",
  1224. "\nwhile ",
  1225. "\nswitch ",
  1226. "\ncase ",
  1227. "\ndefault ",
  1228. # Split by the normal type of lines
  1229. "\n\n",
  1230. "\n",
  1231. " ",
  1232. "",
  1233. ]
  1234. elif language == Language.PHP:
  1235. return [
  1236. # Split along function definitions
  1237. "\nfunction ",
  1238. # Split along class definitions
  1239. "\nclass ",
  1240. # Split along control flow statements
  1241. "\nif ",
  1242. "\nforeach ",
  1243. "\nwhile ",
  1244. "\ndo ",
  1245. "\nswitch ",
  1246. "\ncase ",
  1247. # Split by the normal type of lines
  1248. "\n\n",
  1249. "\n",
  1250. " ",
  1251. "",
  1252. ]
  1253. elif language == Language.PROTO:
  1254. return [
  1255. # Split along message definitions
  1256. "\nmessage ",
  1257. # Split along service definitions
  1258. "\nservice ",
  1259. # Split along enum definitions
  1260. "\nenum ",
  1261. # Split along option definitions
  1262. "\noption ",
  1263. # Split along import statements
  1264. "\nimport ",
  1265. # Split along syntax declarations
  1266. "\nsyntax ",
  1267. # Split by the normal type of lines
  1268. "\n\n",
  1269. "\n",
  1270. " ",
  1271. "",
  1272. ]
  1273. elif language == Language.PYTHON:
  1274. return [
  1275. # First, try to split along class definitions
  1276. "\nclass ",
  1277. "\ndef ",
  1278. "\n\tdef ",
  1279. # Now split by the normal type of lines
  1280. "\n\n",
  1281. "\n",
  1282. " ",
  1283. "",
  1284. ]
  1285. elif language == Language.RST:
  1286. return [
  1287. # Split along section titles
  1288. "\n=+\n",
  1289. "\n-+\n",
  1290. "\n\\*+\n",
  1291. # Split along directive markers
  1292. "\n\n.. *\n\n",
  1293. # Split by the normal type of lines
  1294. "\n\n",
  1295. "\n",
  1296. " ",
  1297. "",
  1298. ]
  1299. elif language == Language.RUBY:
  1300. return [
  1301. # Split along method definitions
  1302. "\ndef ",
  1303. "\nclass ",
  1304. # Split along control flow statements
  1305. "\nif ",
  1306. "\nunless ",
  1307. "\nwhile ",
  1308. "\nfor ",
  1309. "\ndo ",
  1310. "\nbegin ",
  1311. "\nrescue ",
  1312. # Split by the normal type of lines
  1313. "\n\n",
  1314. "\n",
  1315. " ",
  1316. "",
  1317. ]
  1318. elif language == Language.RUST:
  1319. return [
  1320. # Split along function definitions
  1321. "\nfn ",
  1322. "\nconst ",
  1323. "\nlet ",
  1324. # Split along control flow statements
  1325. "\nif ",
  1326. "\nwhile ",
  1327. "\nfor ",
  1328. "\nloop ",
  1329. "\nmatch ",
  1330. "\nconst ",
  1331. # Split by the normal type of lines
  1332. "\n\n",
  1333. "\n",
  1334. " ",
  1335. "",
  1336. ]
  1337. elif language == Language.SCALA:
  1338. return [
  1339. # Split along class definitions
  1340. "\nclass ",
  1341. "\nobject ",
  1342. # Split along method definitions
  1343. "\ndef ",
  1344. "\nval ",
  1345. "\nvar ",
  1346. # Split along control flow statements
  1347. "\nif ",
  1348. "\nfor ",
  1349. "\nwhile ",
  1350. "\nmatch ",
  1351. "\ncase ",
  1352. # Split by the normal type of lines
  1353. "\n\n",
  1354. "\n",
  1355. " ",
  1356. "",
  1357. ]
  1358. elif language == Language.SWIFT:
  1359. return [
  1360. # Split along function definitions
  1361. "\nfunc ",
  1362. # Split along class definitions
  1363. "\nclass ",
  1364. "\nstruct ",
  1365. "\nenum ",
  1366. # Split along control flow statements
  1367. "\nif ",
  1368. "\nfor ",
  1369. "\nwhile ",
  1370. "\ndo ",
  1371. "\nswitch ",
  1372. "\ncase ",
  1373. # Split by the normal type of lines
  1374. "\n\n",
  1375. "\n",
  1376. " ",
  1377. "",
  1378. ]
  1379. elif language == Language.MARKDOWN:
  1380. return [
  1381. # First, try to split along Markdown headings (starting with level 2)
  1382. "\n#{1,6} ",
  1383. # Note the alternative syntax for headings (below) is not handled here
  1384. # Heading level 2
  1385. # ---------------
  1386. # End of code block
  1387. "```\n",
  1388. # Horizontal lines
  1389. "\n\\*\\*\\*+\n",
  1390. "\n---+\n",
  1391. "\n___+\n",
 1392. # Note that horizontal rules written with spaces between the characters
 1393. # (e.g. "* * *") are not matched by the patterns above
  1394. "\n\n",
  1395. "\n",
  1396. " ",
  1397. "",
  1398. ]
  1399. elif language == Language.LATEX:
  1400. return [
  1401. # First, try to split along Latex sections
  1402. "\n\\\\chapter{",
  1403. "\n\\\\section{",
  1404. "\n\\\\subsection{",
  1405. "\n\\\\subsubsection{",
  1406. # Now split by environments
  1407. "\n\\\\begin{enumerate}",
  1408. "\n\\\\begin{itemize}",
  1409. "\n\\\\begin{description}",
  1410. "\n\\\\begin{list}",
  1411. "\n\\\\begin{quote}",
  1412. "\n\\\\begin{quotation}",
  1413. "\n\\\\begin{verse}",
  1414. "\n\\\\begin{verbatim}",
  1415. # Now split by math environments
  1416. "\n\\\begin{align}",
  1417. "$$",
  1418. "$",
  1419. # Now split by the normal type of lines
  1420. " ",
  1421. "",
  1422. ]
  1423. elif language == Language.HTML:
  1424. return [
  1425. # First, try to split along HTML tags
  1426. "<body",
  1427. "<div",
  1428. "<p",
  1429. "<br",
  1430. "<li",
  1431. "<h1",
  1432. "<h2",
  1433. "<h3",
  1434. "<h4",
  1435. "<h5",
  1436. "<h6",
  1437. "<span",
  1438. "<table",
  1439. "<tr",
  1440. "<td",
  1441. "<th",
  1442. "<ul",
  1443. "<ol",
  1444. "<header",
  1445. "<footer",
  1446. "<nav",
  1447. # Head
  1448. "<head",
  1449. "<style",
  1450. "<script",
  1451. "<meta",
  1452. "<title",
  1453. "",
  1454. ]
  1455. elif language == Language.CSHARP:
  1456. return [
  1457. "\ninterface ",
  1458. "\nenum ",
  1459. "\nimplements ",
  1460. "\ndelegate ",
  1461. "\nevent ",
  1462. # Split along class definitions
  1463. "\nclass ",
  1464. "\nabstract ",
  1465. # Split along method definitions
  1466. "\npublic ",
  1467. "\nprotected ",
  1468. "\nprivate ",
  1469. "\nstatic ",
  1470. "\nreturn ",
  1471. # Split along control flow statements
  1472. "\nif ",
  1473. "\ncontinue ",
  1474. "\nfor ",
  1475. "\nforeach ",
  1476. "\nwhile ",
  1477. "\nswitch ",
  1478. "\nbreak ",
  1479. "\ncase ",
  1480. "\nelse ",
  1481. # Split by exceptions
  1482. "\ntry ",
  1483. "\nthrow ",
  1484. "\nfinally ",
  1485. "\ncatch ",
  1486. # Split by the normal type of lines
  1487. "\n\n",
  1488. "\n",
  1489. " ",
  1490. "",
  1491. ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ",
                "\nusing ",
                # Split along contract definitions
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                # Split along method definitions
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.COBOL:
            return [
                # Split along divisions
                "\nIDENTIFICATION DIVISION.",
                "\nENVIRONMENT DIVISION.",
                "\nDATA DIVISION.",
                "\nPROCEDURE DIVISION.",
                # Split along sections within DATA DIVISION
                "\nWORKING-STORAGE SECTION.",
                "\nLINKAGE SECTION.",
                "\nFILE SECTION.",
                # Split along sections within ENVIRONMENT DIVISION
                "\nINPUT-OUTPUT SECTION.",
                # Split along paragraphs and common statements
                "\nOPEN ",
                "\nCLOSE ",
                "\nREAD ",
                "\nWRITE ",
                "\nIF ",
                "\nELSE ",
                "\nMOVE ",
                "\nPERFORM ",
                "\nUNTIL ",
                "\nVARYING ",
                "\nACCEPT ",
                "\nDISPLAY ",
                "\nSTOP RUN.",
                # Split by the normal type of lines
                "\n",
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
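

# Illustrative usage sketch (not part of the library API): feeding one of the
# language-specific separator lists above into RecursiveCharacterTextSplitter.
# Assumptions: get_separators_for_language is callable on the class (i.e. a
# static/class method), and chunk_size/chunk_overlap are accepted by the base
# TextSplitter; the sizes below are arbitrary example values.
def _example_split_source_code(source: str, language: Language) -> List[str]:
    separators = RecursiveCharacterTextSplitter.get_separators_for_language(
        language
    )
    splitter = RecursiveCharacterTextSplitter(
        separators=separators,
        chunk_size=256,
        chunk_overlap=32,
    )
    return splitter.split_text(source)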


class NLTKTextSplitter(TextSplitter):
    """Splitting text using NLTK package."""

    def __init__(
        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
    ) -> None:
        """Initialize the NLTK splitter."""
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize

            self._tokenizer = sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            )
        self._separator = separator
        self._language = language

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = self._tokenizer(text, language=self._language)
        return self._merge_splits(splits, self._separator)
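

# Illustrative usage sketch (not part of the library API): sentence-aware
# splitting with NLTK. Assumes the `nltk` package and its "punkt" tokenizer
# data are installed; chunk_size/chunk_overlap are assumed base-class kwargs
# and the values below are arbitrary.
def _example_nltk_split(text: str) -> List[str]:
    splitter = NLTKTextSplitter(
        separator="\n\n", language="english", chunk_size=512, chunk_overlap=0
    )
    return splitter.split_text(text)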


class SpacyTextSplitter(TextSplitter):
    """Splitting text using the spaCy package.

    By default, spaCy's `en_core_web_sm` pipeline is used with a
    `max_length` of 1,000,000 (the maximum number of characters the
    pipeline accepts; increase it for larger files). For faster but
    potentially less accurate splitting, you can use `pipe="sentencizer"`.
    """

    def __init__(
        self,
        separator: str = "\n\n",
        pipe: str = "en_core_web_sm",
        max_length: int = 1_000_000,
        **kwargs: Any,
    ) -> None:
        """Initialize the spacy text splitter."""
        super().__init__(**kwargs)
        self._tokenizer = _make_spacy_pipe_for_splitting(
            pipe, max_length=max_length
        )
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        splits = (s.text for s in self._tokenizer(text).sents)
        return self._merge_splits(splits, self._separator)
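

# Illustrative usage sketch (not part of the library API): the lightweight
# "sentencizer" pipe mentioned in the docstring above avoids loading the full
# en_core_web_sm model. Assumes spaCy is installed; chunk_size/chunk_overlap
# are assumed base-class kwargs and the values below are arbitrary.
def _example_spacy_split(text: str) -> List[str]:
    splitter = SpacyTextSplitter(
        pipe="sentencizer", chunk_size=512, chunk_overlap=0
    )
    return splitter.split_text(text)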


class KonlpyTextSplitter(TextSplitter):
    """Splitting text using Konlpy package.

    It is good for splitting Korean text.
    """

    def __init__(
        self,
        separator: str = "\n\n",
        **kwargs: Any,
    ) -> None:
        """Initialize the Konlpy text splitter."""
        super().__init__(**kwargs)
        self._separator = separator
        try:
            from konlpy.tag import Kkma
        except ImportError:
            raise ImportError(
                """
                Konlpy is not installed, please install it with
                `pip install konlpy`
                """
            )
        self.kkma = Kkma()

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        splits = self.kkma.sentences(text)
        return self._merge_splits(splits, self._separator)
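

# Illustrative usage sketch (not part of the library API): Korean sentence
# splitting via Konlpy's Kkma analyzer. Assumes `konlpy` (and the Java runtime
# it requires) is installed; chunk sizes are arbitrary example values.
def _example_konlpy_split(korean_text: str) -> List[str]:
    splitter = KonlpyTextSplitter(chunk_size=256, chunk_overlap=0)
    return splitter.split_text(korean_text)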


# For backwards compatibility
class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Python syntax."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a PythonCodeTextSplitter."""
        separators = self.get_separators_for_language(Language.PYTHON)
        super().__init__(separators=separators, **kwargs)


class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Markdown-formatted headings."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a MarkdownTextSplitter."""
        separators = self.get_separators_for_language(Language.MARKDOWN)
        super().__init__(separators=separators, **kwargs)


class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Latex-formatted layout elements."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a LatexTextSplitter."""
        separators = self.get_separators_for_language(Language.LATEX)
        super().__init__(separators=separators, **kwargs)
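

# Illustrative usage sketch (not part of the library API): the three wrappers
# above are thin shims that simply preload a separator list, so splitting
# Markdown with MarkdownTextSplitter behaves like configuring
# RecursiveCharacterTextSplitter with the MARKDOWN separators yourself.
# chunk_size/chunk_overlap values below are arbitrary examples.
def _example_markdown_split(markdown_text: str) -> List[str]:
    splitter = MarkdownTextSplitter(chunk_size=400, chunk_overlap=40)
    return splitter.split_text(markdown_text)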


class RecursiveJsonSplitter:
    """Recursively split JSON data into size-bounded JSON chunks."""

    def __init__(
        self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
    ):
        super().__init__()
        self.max_chunk_size = max_chunk_size
        self.min_chunk_size = (
            min_chunk_size
            if min_chunk_size is not None
            else max(max_chunk_size - 200, 50)
        )

    @staticmethod
    def _json_size(data: Dict) -> int:
        """Calculate the size of the serialized JSON object."""
        return len(json.dumps(data))

    @staticmethod
    def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None:
        """Set a value in a nested dictionary based on the given path."""
        for key in path[:-1]:
            d = d.setdefault(key, {})
        d[path[-1]] = value

    def _list_to_dict_preprocessing(self, data: Any) -> Any:
        if isinstance(data, dict):
            # Process each key-value pair in the dictionary
            return {
                k: self._list_to_dict_preprocessing(v) for k, v in data.items()
            }
        elif isinstance(data, list):
            # Convert the list to a dictionary with index-based keys
            return {
                str(i): self._list_to_dict_preprocessing(item)
                for i, item in enumerate(data)
            }
        else:
            # Base case: the item is neither a dict nor a list, so return it unchanged
            return data

    def _json_split(
        self,
        data: Dict[str, Any],
        current_path: Optional[List[str]] = None,
        chunks: Optional[List[Dict]] = None,
    ) -> List[Dict]:
        """Split json into maximum size dictionaries while preserving structure."""
        # Avoid mutable default arguments: create fresh containers per call.
        current_path = current_path if current_path is not None else []
        chunks = chunks if chunks is not None else [{}]
        if isinstance(data, dict):
            for key, value in data.items():
                new_path = current_path + [key]
                chunk_size = self._json_size(chunks[-1])
                size = self._json_size({key: value})
                remaining = self.max_chunk_size - chunk_size
                if size < remaining:
                    # Add item to current chunk
                    self._set_nested_dict(chunks[-1], new_path, value)
                else:
                    if chunk_size >= self.min_chunk_size:
                        # Chunk is big enough, start a new chunk
                        chunks.append({})
                    # Iterate
                    self._json_split(value, new_path, chunks)
        else:
            # handle single item
            self._set_nested_dict(chunks[-1], current_path, data)
        return chunks

    def split_json(
        self,
        json_data: Dict[str, Any],
        convert_lists: bool = False,
    ) -> List[Dict]:
        """Splits JSON into a list of JSON chunks"""
        if convert_lists:
            chunks = self._json_split(
                self._list_to_dict_preprocessing(json_data)
            )
        else:
            chunks = self._json_split(json_data)
        # Remove the last chunk if it's empty
        if not chunks[-1]:
            chunks.pop()
        return chunks

    def split_text(
        self, json_data: Dict[str, Any], convert_lists: bool = False
    ) -> List[str]:
        """Splits JSON into a list of JSON formatted strings"""
        chunks = self.split_json(
            json_data=json_data, convert_lists=convert_lists
        )
        # Convert to string
        return [json.dumps(chunk) for chunk in chunks]

    def create_documents(
        self,
        texts: List[Dict],
        convert_lists: bool = False,
        metadatas: Optional[List[dict]] = None,
    ) -> List[SplitterDocument]:
        """Create documents from a list of json objects (Dict)."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(
                json_data=text, convert_lists=convert_lists
            ):
                metadata = copy.deepcopy(_metadatas[i])
                new_doc = SplitterDocument(
                    page_content=chunk, metadata=metadata
                )
                documents.append(new_doc)
        return documents
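

# Illustrative usage sketch (not part of the library API): chunking a nested
# JSON document. With convert_lists=True, lists are first rewritten as
# index-keyed dicts (see _list_to_dict_preprocessing) so they can be split
# across chunks; max_chunk_size is measured in characters of the serialized
# JSON. The data, field names, and sizes below are made up for the example.
def _example_json_split() -> List[SplitterDocument]:
    data = {
        "users": [
            {"name": "Ada", "bio": "mathematician"},
            {"name": "Alan", "bio": "computer scientist"},
        ],
        "meta": {"source": "example"},
    }
    splitter = RecursiveJsonSplitter(max_chunk_size=120)
    return splitter.create_documents(
        texts=[data],
        convert_lists=True,
        metadatas=[{"origin": "example"}],
    )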