base.py

# type: ignore
import logging
import time
from typing import Any, AsyncGenerator, Optional

from core import parsers
from core.base import (
    AsyncParser,
    ChunkingStrategy,
    Document,
    DocumentChunk,
    DocumentType,
    IngestionConfig,
    IngestionProvider,
    R2RDocumentProcessingError,
    RecursiveCharacterTextSplitter,
    TextSplitter,
)
from core.providers.database import PostgresDatabaseProvider
from core.providers.llm import (
    LiteLLMCompletionProvider,
    OpenAICompletionProvider,
    R2RCompletionProvider,
)
from core.providers.ocr import MistralOCRProvider
from core.utils import generate_extraction_id

logger = logging.getLogger()

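# Default chunking settings for the built-in R2R ingestion pipeline; individual
# requests can override them via `ingestion_config_override`.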
class R2RIngestionConfig(IngestionConfig):
    chunk_size: int = 1024
    chunk_overlap: int = 512
    chunking_strategy: ChunkingStrategy = ChunkingStrategy.RECURSIVE
    extra_fields: dict[str, Any] = {}
    separator: Optional[str] = None

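# Ingestion provider that maps each DocumentType to a parser and turns parsed
# text into DocumentChunk objects.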
class R2RIngestionProvider(IngestionProvider):
    DEFAULT_PARSERS = {
        DocumentType.BMP: parsers.BMPParser,
        DocumentType.CSV: parsers.CSVParser,
        DocumentType.DOC: parsers.DOCParser,
        DocumentType.DOCX: parsers.DOCXParser,
        DocumentType.EML: parsers.EMLParser,
        DocumentType.EPUB: parsers.EPUBParser,
        DocumentType.HTML: parsers.HTMLParser,
        DocumentType.HTM: parsers.HTMLParser,
        DocumentType.ODT: parsers.ODTParser,
        DocumentType.JSON: parsers.JSONParser,
        DocumentType.MSG: parsers.MSGParser,
        DocumentType.ORG: parsers.ORGParser,
        DocumentType.MD: parsers.MDParser,
        DocumentType.PDF: parsers.BasicPDFParser,
        DocumentType.PPT: parsers.PPTParser,
        DocumentType.PPTX: parsers.PPTXParser,
        DocumentType.TXT: parsers.TextParser,
        DocumentType.XLSX: parsers.XLSXParser,
        DocumentType.GIF: parsers.ImageParser,
        DocumentType.JPEG: parsers.ImageParser,
        DocumentType.JPG: parsers.ImageParser,
        DocumentType.TSV: parsers.TSVParser,
        DocumentType.PNG: parsers.ImageParser,
        DocumentType.HEIC: parsers.ImageParser,
        DocumentType.SVG: parsers.ImageParser,
        DocumentType.MP3: parsers.AudioParser,
        DocumentType.P7S: parsers.P7SParser,
        DocumentType.RST: parsers.RSTParser,
        DocumentType.RTF: parsers.RTFParser,
        DocumentType.TIFF: parsers.ImageParser,
        DocumentType.XLS: parsers.XLSParser,
        DocumentType.PY: parsers.PythonParser,
        DocumentType.CSS: parsers.CSSParser,
        DocumentType.JS: parsers.JSParser,
        DocumentType.TS: parsers.TSParser,
    }

    EXTRA_PARSERS = {
        DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced},
        DocumentType.PDF: {
            "ocr": parsers.OCRPDFParser,
            "unstructured": parsers.PDFParserUnstructured,
            "zerox": parsers.VLMPDFParser,
        },
        DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced},
    }

    IMAGE_TYPES = {
        DocumentType.GIF,
        DocumentType.HEIC,
        DocumentType.JPG,
        DocumentType.JPEG,
        DocumentType.PNG,
        DocumentType.SVG,
    }

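    # Stores the injected providers, builds the default text splitter, and
    # instantiates one parser per supported document type.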
    def __init__(
        self,
        config: R2RIngestionConfig,
        database_provider: PostgresDatabaseProvider,
        llm_provider: (
            LiteLLMCompletionProvider
            | OpenAICompletionProvider
            | R2RCompletionProvider
        ),
        ocr_provider: MistralOCRProvider,
    ):
        super().__init__(config, database_provider, llm_provider)
        self.config: R2RIngestionConfig = config
        self.database_provider: PostgresDatabaseProvider = database_provider
        self.llm_provider: (
            LiteLLMCompletionProvider
            | OpenAICompletionProvider
            | R2RCompletionProvider
        ) = llm_provider
        self.ocr_provider: MistralOCRProvider = ocr_provider
        self.parsers: dict[DocumentType, AsyncParser] = {}
        self.text_splitter = self._build_text_splitter()
        self._initialize_parsers()

        logger.info(
            f"R2RIngestionProvider initialized with config: {self.config}"
        )

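    # Default parsers are keyed by DocumentType; extra parsers from the config
    # are registered under "<parser_name>_<doc_type>" string keys.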
    def _initialize_parsers(self):
        for doc_type, parser in self.DEFAULT_PARSERS.items():
            # will choose the first parser in the list
            if doc_type not in self.config.excluded_parsers:
                self.parsers[doc_type] = parser(
                    config=self.config,
                    database_provider=self.database_provider,
                    llm_provider=self.llm_provider,
                )
        # FIXME: This doesn't allow for flexibility for a parser that might not
        # need an llm_provider, etc.
        for doc_type, parser_names in self.config.extra_parsers.items():
            if not isinstance(parser_names, list):
                parser_names = [parser_names]

            for parser_name in parser_names:
                parser_key = f"{parser_name}_{str(doc_type)}"
                try:
                    self.parsers[parser_key] = self.EXTRA_PARSERS[doc_type][
                        parser_name
                    ](
                        config=self.config,
                        database_provider=self.database_provider,
                        llm_provider=self.llm_provider,
                        ocr_provider=self.ocr_provider,
                    )
                    logger.info(
                        f"Initialized extra parser {parser_name} for {doc_type}"
                    )
                except KeyError as e:
                    logger.error(
                        f"Parser {parser_name} for document type {doc_type} not found: {e}"
                    )

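    # Resolves chunking_strategy, chunk_size, and chunk_overlap from the override
    # dict (falling back to the provider config) and returns a matching splitter.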
    def _build_text_splitter(
        self, ingestion_config_override: Optional[dict] = None
    ) -> TextSplitter:
        logger.info(
            f"Initializing text splitter with method: {self.config.chunking_strategy}"
        )

        if not ingestion_config_override:
            ingestion_config_override = {}

        chunking_strategy = (
            ingestion_config_override.get("chunking_strategy")
            or self.config.chunking_strategy
        )

        chunk_size = (
            ingestion_config_override.get("chunk_size")
            if ingestion_config_override.get("chunk_size") is not None
            else self.config.chunk_size
        )
        chunk_overlap = (
            ingestion_config_override.get("chunk_overlap")
            if ingestion_config_override.get("chunk_overlap") is not None
            else self.config.chunk_overlap
        )

        if chunking_strategy == ChunkingStrategy.RECURSIVE:
            return RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
        elif chunking_strategy == ChunkingStrategy.CHARACTER:
            from shared.utils.splitter.text import CharacterTextSplitter

            separator = (
                ingestion_config_override.get("separator")
                or self.config.separator
                or CharacterTextSplitter.DEFAULT_SEPARATOR
            )

            return CharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separator=separator,
                keep_separator=False,
                strip_whitespace=True,
            )
        elif chunking_strategy == ChunkingStrategy.BASIC:
            raise NotImplementedError(
                "Basic chunking method not implemented. Please use Recursive."
            )
        elif chunking_strategy == ChunkingStrategy.BY_TITLE:
            raise NotImplementedError("By title method not implemented")
        else:
            raise ValueError(f"Unsupported method type: {chunking_strategy}")

    def validate_config(self) -> bool:
        return self.config.chunk_size > 0 and self.config.chunk_overlap >= 0

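    # Splits already-parsed text (or a DocumentChunk's data) into chunk strings,
    # rebuilding the text splitter when a per-request override is provided.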
    def chunk(
        self,
        parsed_document: str | DocumentChunk,
        ingestion_config_override: dict,
    ) -> AsyncGenerator[Any, None]:
        text_splitter = self.text_splitter
        if ingestion_config_override:
            text_splitter = self._build_text_splitter(
                ingestion_config_override
            )

        if isinstance(parsed_document, DocumentChunk):
            parsed_document = parsed_document.data

        if isinstance(parsed_document, str):
            chunks = text_splitter.create_documents([parsed_document])
        else:
            # Assuming parsed_document is already a list of text chunks
            chunks = parsed_document

        for chunk in chunks:
            yield (
                chunk.page_content if hasattr(chunk, "page_content") else chunk
            )

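    # Parses raw file bytes with the parser registered for the document's type
    # and yields DocumentChunk objects; PDF "zerox"/"ocr" parser overrides are
    # collected page by page before chunking.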
    async def parse(
        self,
        file_content: bytes,
        document: Document,
        ingestion_config_override: dict,
    ) -> AsyncGenerator[DocumentChunk, None]:
        if document.document_type not in self.parsers:
            raise R2RDocumentProcessingError(
                document_id=document.id,
                error_message=f"Parser for {document.document_type} not found in `R2RIngestionProvider`.",
            )
        else:
            t0 = time.time()
            contents = []

            parser_overrides = ingestion_config_override.get(
                "parser_overrides", {}
            )
            if document.document_type.value in parser_overrides:
                logger.info(
                    f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
                )

                if parser_overrides[DocumentType.PDF.value] == "zerox":
                    # Collect content from VLMPDFParser
                    async for chunk in self.parsers[
                        f"zerox_{DocumentType.PDF.value}"
                    ].ingest(file_content, **ingestion_config_override):
                        if isinstance(chunk, dict) and chunk.get("content"):
                            contents.append(chunk)
                        elif (
                            chunk
                        ):  # Handle string output for backward compatibility
                            contents.append({"content": chunk})
                elif parser_overrides[DocumentType.PDF.value] == "ocr":
                    async for chunk in self.parsers[
                        f"ocr_{DocumentType.PDF.value}"
                    ].ingest(file_content, **ingestion_config_override):
                        if isinstance(chunk, dict) and chunk.get("content"):
                            contents.append(chunk)
                if (
                    contents
                    and document.document_type == DocumentType.PDF
                    and parser_overrides.get(DocumentType.PDF.value)
                    in ("zerox", "ocr")
                ):
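                    # Pages returned by the VLM/OCR parsers are emitted in
                    # page-number order, either one chunk per page or re-split
                    # by the text splitter.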
                    vlm_ocr_one_page_per_chunk = ingestion_config_override.get(
                        "vlm_ocr_one_page_per_chunk", True
                    )

                    if vlm_ocr_one_page_per_chunk:
                        # Use one page per chunk for OCR/VLM
                        iteration = 0
                        sorted_contents = [
                            item
                            for item in sorted(
                                contents, key=lambda x: x.get("page_number", 0)
                            )
                            if isinstance(item.get("content"), str)
                        ]
                        for content_item in sorted_contents:
                            page_num = content_item.get("page_number", 0)
                            page_content = content_item["content"]

                            # Create a document chunk directly from the page content
                            metadata = {
                                **document.metadata,
                                "chunk_order": iteration,
                                "page_number": page_num,
                            }
                            extraction = DocumentChunk(
                                id=generate_extraction_id(
                                    document.id, iteration
                                ),
                                document_id=document.id,
                                owner_id=document.owner_id,
                                collection_ids=document.collection_ids,
                                data=page_content,
                                metadata=metadata,
                            )
                            iteration += 1
                            yield extraction

                        logger.debug(
                            f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
                            f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
                            f"into {iteration} extractions in t={time.time() - t0:.2f} seconds using one-page-per-chunk."
                        )
                        return
                    else:
                        # Text splitting
                        text_splitter = self._build_text_splitter(
                            ingestion_config_override
                        )
                        iteration = 0
                        sorted_contents = [
                            item
                            for item in sorted(
                                contents, key=lambda x: x.get("page_number", 0)
                            )
                            if isinstance(item.get("content"), str)
                        ]
                        for content_item in sorted_contents:
                            page_num = content_item.get("page_number", 0)
                            page_content = content_item["content"]

                            page_chunks = text_splitter.create_documents(
                                [page_content]
                            )
                            # Create document chunks for each split piece
                            for chunk in page_chunks:
                                metadata = {
                                    **document.metadata,
                                    "chunk_order": iteration,
                                    "page_number": page_num,
                                }
                                extraction = DocumentChunk(
                                    id=generate_extraction_id(
                                        document.id, iteration
                                    ),
                                    document_id=document.id,
                                    owner_id=document.owner_id,
                                    collection_ids=document.collection_ids,
                                    data=chunk.page_content,
                                    metadata=metadata,
                                )
                                iteration += 1
                                yield extraction

                        logger.debug(
                            f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
                            f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
                            f"into {iteration} extractions in t={time.time() - t0:.2f} seconds using page-by-page splitting."
                        )
                        return
            else:
                # Standard parsing for non-override cases
                async for text in self.parsers[document.document_type].ingest(
                    file_content, **ingestion_config_override
                ):
                    if text is not None:
                        contents.append({"content": text})

            if not contents:
                logging.warning(
                    "No valid text content was extracted during parsing"
                )
                return

            iteration = 0
            for content_item in contents:
                chunk_text = content_item["content"]
                chunks = self.chunk(chunk_text, ingestion_config_override)

                for chunk in chunks:
                    metadata = {**document.metadata, "chunk_order": iteration}
                    if "page_number" in content_item:
                        metadata["page_number"] = content_item["page_number"]
                    extraction = DocumentChunk(
                        id=generate_extraction_id(document.id, iteration),
                        document_id=document.id,
                        owner_id=document.owner_id,
                        collection_ids=document.collection_ids,
                        data=chunk,
                        metadata=metadata,
                    )
                    iteration += 1
                    yield extraction

            logger.debug(
                f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
                f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
                f"into {iteration} extractions in t={time.time() - t0:.2f} seconds."
            )

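    # Returns the default parser registered for `doc_type`, or None if absent.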
    def get_parser_for_document_type(self, doc_type: DocumentType) -> Any:
        return self.parsers.get(doc_type)