# base.py
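"""Unstructured-based ingestion provider for R2R.

Partitions documents with the Unstructured API (hosted or local service) and
falls back to R2R's native parsers for document types listed in
R2R_FALLBACK_PARSERS below.
"""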
import asyncio
import base64
import json
import logging
import os
import time
from copy import copy
from io import BytesIO
from typing import Any, AsyncGenerator

import httpx
from unstructured_client import UnstructuredClient
from unstructured_client.models import operations, shared

from core import parsers
from core.base import (
    AsyncParser,
    ChunkingStrategy,
    Document,
    DocumentChunk,
    DocumentType,
    RecursiveCharacterTextSplitter,
)
from core.base.abstractions import R2RSerializable
from core.base.providers.ingestion import IngestionConfig, IngestionProvider
from core.providers.ocr import MistralOCRProvider
from core.utils import generate_extraction_id

from ...database import PostgresDatabaseProvider
from ...llm import (
    LiteLLMCompletionProvider,
    OpenAICompletionProvider,
    R2RCompletionProvider,
)

logger = logging.getLogger()
class FallbackElement(R2RSerializable):
    """A text element produced by one of R2R's fallback parsers."""

    text: str
    metadata: dict[str, Any]
class UnstructuredIngestionConfig(IngestionConfig):
    """Configuration forwarded to Unstructured's partition endpoint.

    The first four fields control chunking; the optional fields are passed
    through as Unstructured partition parameters when set.
    """

    combine_under_n_chars: int = 128
    max_characters: int = 500
    new_after_n_chars: int = 1500
    overlap: int = 64

    coordinates: bool | None = None
    encoding: str | None = None  # utf-8
    extract_image_block_types: list[str] | None = None
    gz_uncompressed_content_type: str | None = None
    hi_res_model_name: str | None = None
    include_orig_elements: bool | None = None
    include_page_breaks: bool | None = None
    languages: list[str] | None = None
    multipage_sections: bool | None = None
    ocr_languages: list[str] | None = None
    # output_format: Optional[str] = "application/json"
    overlap_all: bool | None = None
    pdf_infer_table_structure: bool | None = None
    similarity_threshold: float | None = None
    skip_infer_table_types: list[str] | None = None
    split_pdf_concurrency_level: int | None = None
    split_pdf_page: bool | None = None
    starting_page_number: int | None = None
    strategy: str | None = None
    chunking_strategy: str | ChunkingStrategy | None = None  # type: ignore
    unique_element_ids: bool | None = None
    xml_keep_tags: bool | None = None
    def to_ingestion_request(self):
        """Serialize this config for a partition request, dropping
        provider-side bookkeeping fields and any unset (None) values."""
        payload = json.loads(self.json())
        payload.pop("extra_fields", None)
        payload.pop("provider", None)
        payload.pop("excluded_parsers", None)
        return {k: v for k, v in payload.items() if v is not None}
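    # For example, `UnstructuredIngestionConfig().to_ingestion_request()`
    # yields {"combine_under_n_chars": 128, "max_characters": 500,
    # "new_after_n_chars": 1500, "overlap": 64, ...} plus any non-None
    # fields inherited from IngestionConfig.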
class UnstructuredIngestionProvider(IngestionProvider):
    """Ingestion provider backed by Unstructured, with R2R fallbacks."""

    R2R_FALLBACK_PARSERS = {
        DocumentType.GIF: [parsers.ImageParser],  # type: ignore
        DocumentType.JPEG: [parsers.ImageParser],  # type: ignore
        DocumentType.JPG: [parsers.ImageParser],  # type: ignore
        DocumentType.PNG: [parsers.ImageParser],  # type: ignore
        DocumentType.SVG: [parsers.ImageParser],  # type: ignore
        DocumentType.HEIC: [parsers.ImageParser],  # type: ignore
        DocumentType.MP3: [parsers.AudioParser],  # type: ignore
        DocumentType.JSON: [parsers.JSONParser],  # type: ignore
        DocumentType.HTML: [parsers.HTMLParser],  # type: ignore
        DocumentType.XLS: [parsers.XLSParser],  # type: ignore
        DocumentType.XLSX: [parsers.XLSXParser],  # type: ignore
        # DocumentType.DOC: [parsers.DOCParser],  # type: ignore
        DocumentType.PPT: [parsers.PPTParser],  # type: ignore
    }

    EXTRA_PARSERS = {
        DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced},  # type: ignore
        DocumentType.PDF: {
            "ocr": parsers.OCRPDFParser,  # type: ignore
            "unstructured": parsers.VLMPDFParser,  # type: ignore
            "zerox": parsers.VLMPDFParser,  # type: ignore
        },
        DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced},  # type: ignore
    }

    IMAGE_TYPES = {
        DocumentType.GIF,
        DocumentType.HEIC,
        DocumentType.JPG,
        DocumentType.JPEG,
        DocumentType.PNG,
        DocumentType.SVG,
    }
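    # Fallback parsers cover types handled natively by R2R; extra parsers
    # are opt-in via the `extra_parsers` field of the ingestion config.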
    def __init__(
        self,
        config: UnstructuredIngestionConfig,
        database_provider: PostgresDatabaseProvider,
        llm_provider: (
            LiteLLMCompletionProvider
            | OpenAICompletionProvider
            | R2RCompletionProvider
        ),
        ocr_provider: MistralOCRProvider,
    ):
        super().__init__(config, database_provider, llm_provider)
        self.config: UnstructuredIngestionConfig = config
        self.database_provider: PostgresDatabaseProvider = database_provider
        self.llm_provider: (
            LiteLLMCompletionProvider
            | OpenAICompletionProvider
            | R2RCompletionProvider
        ) = llm_provider
        self.ocr_provider: MistralOCRProvider = ocr_provider
        self.client: UnstructuredClient | httpx.AsyncClient
        if config.provider == "unstructured_api":
            # Hosted Unstructured API: requires an API key.
            try:
                self.unstructured_api_auth = os.environ["UNSTRUCTURED_API_KEY"]
            except KeyError as e:
                raise ValueError(
                    "UNSTRUCTURED_API_KEY environment variable is not set"
                ) from e
            self.unstructured_api_url = os.environ.get(
                "UNSTRUCTURED_API_URL",
                "https://api.unstructuredapp.io/general/v0/general",
            )
            self.client = UnstructuredClient(
                api_key_auth=self.unstructured_api_auth,
                server_url=self.unstructured_api_url,
            )
            self.shared = shared
            self.operations = operations
        else:
            # Local Unstructured service reached over plain HTTP.
            try:
                self.local_unstructured_url = os.environ[
                    "UNSTRUCTURED_SERVICE_URL"
                ]
            except KeyError as e:
                raise ValueError(
                    "UNSTRUCTURED_SERVICE_URL environment variable is not set"
                ) from e
            self.client = httpx.AsyncClient()

        self.parsers: dict[DocumentType, AsyncParser] = {}
        self._initialize_parsers()
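    # Example environment for the two modes (values illustrative, not
    # authoritative):
    #   UNSTRUCTURED_API_KEY=...        # required for "unstructured_api"
    #   UNSTRUCTURED_API_URL=https://api.unstructuredapp.io/general/v0/general
    #   UNSTRUCTURED_SERVICE_URL=http://localhost:7275  # local mode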
    def _initialize_parsers(self):
        # Loop variable renamed from `parsers` to avoid shadowing the
        # `core.parsers` module imported above.
        for doc_type, parser_classes in self.R2R_FALLBACK_PARSERS.items():
            for parser_class in parser_classes:
                if (
                    doc_type not in self.config.excluded_parsers
                    and doc_type not in self.parsers
                ):
                    # The first parser in the list wins.
                    self.parsers[doc_type] = parser_class(
                        config=self.config,
                        database_provider=self.database_provider,
                        llm_provider=self.llm_provider,
                    )
        # TODO - Reduce code duplication between Unstructured & R2R
        for doc_type, parser_names in self.config.extra_parsers.items():
            if not isinstance(parser_names, list):
                parser_names = [parser_names]
            for parser_name in parser_names:
                parser_key = f"{parser_name}_{str(doc_type)}"
                try:
                    self.parsers[parser_key] = self.EXTRA_PARSERS[doc_type][
                        parser_name
                    ](
                        config=self.config,
                        database_provider=self.database_provider,
                        llm_provider=self.llm_provider,
                        ocr_provider=self.ocr_provider,
                    )
                    logger.info(
                        f"Initialized extra parser {parser_name} for {doc_type}"
                    )
                except KeyError as e:
                    logger.error(
                        f"Parser {parser_name} for document type {doc_type} not found: {e}"
                    )
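    # After initialization, `self.parsers` holds DocumentType keys for the
    # fallback parsers plus string keys of the form
    # f"{parser_name}_{doc_type}" for the extra parsers.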
    async def parse_fallback(
        self,
        file_content: bytes,
        ingestion_config: dict,
        parser_name: str,
    ) -> AsyncGenerator[FallbackElement, None]:
        contents = []
        async for chunk in self.parsers[parser_name].ingest(  # type: ignore
            file_content, **ingestion_config
        ):
            if isinstance(chunk, dict) and chunk.get("content"):
                contents.append(chunk)
            elif chunk:  # Handle string output for backward compatibility
                contents.append({"content": chunk})

        if not contents:
            logger.warning(
                "No valid text content was extracted during parsing"
            )
            return

        logger.info(f"Fallback ingestion with config = {ingestion_config}")

        vlm_ocr_one_page_per_chunk = ingestion_config.get(
            "vlm_ocr_one_page_per_chunk", True
        )

        iteration = 0
        for content_item in contents:
            text = content_item["content"]

            if vlm_ocr_one_page_per_chunk and parser_name.startswith(
                ("zerox_", "ocr_")
            ):
                # Use one page per chunk for OCR/VLM
                metadata = {"chunk_id": iteration}
                if "page_number" in content_item:
                    metadata["page_number"] = content_item["page_number"]

                yield FallbackElement(
                    text=text or "No content extracted.",
                    metadata=metadata,
                )
                iteration += 1
                await asyncio.sleep(0)
            else:
                # Use regular text splitting
                loop = asyncio.get_event_loop()
                splitter = RecursiveCharacterTextSplitter(
                    chunk_size=ingestion_config["new_after_n_chars"],
                    chunk_overlap=ingestion_config["overlap"],
                )
                chunks = await loop.run_in_executor(
                    None, splitter.create_documents, [text]
                )

                for text_chunk in chunks:
                    metadata = {"chunk_id": iteration}
                    if "page_number" in content_item:
                        metadata["page_number"] = content_item["page_number"]

                    yield FallbackElement(
                        text=text_chunk.page_content,
                        metadata=metadata,
                    )
                    iteration += 1
                    await asyncio.sleep(0)
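    # A minimal usage sketch (the `provider` instance and `pdf_bytes` are
    # hypothetical):
    #
    #   async for element in provider.parse_fallback(
    #       pdf_bytes,
    #       ingestion_config={"new_after_n_chars": 1500, "overlap": 64},
    #       parser_name=f"ocr_{DocumentType.PDF.value}",
    #   ):
    #       print(element.metadata.get("page_number"), element.text[:80])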
    async def parse(
        self,
        file_content: bytes,
        document: Document,
        ingestion_config_override: dict,
    ) -> AsyncGenerator[DocumentChunk, None]:
        # Per-request overrides take precedence over the provider config.
        ingestion_config = {
            **self.config.to_ingestion_request(),
            **(ingestion_config_override or {}),
        }
        # cleanup extra fields
        ingestion_config.pop("provider", None)
        ingestion_config.pop("excluded_parsers", None)

        t0 = time.time()
        parser_overrides = (ingestion_config_override or {}).get(
            "parser_overrides", {}
        )
        elements = []

        # TODO - Cleanup this approach to be less hardcoded
        # TODO - Remove code duplication between Unstructured & R2R
        logger.info(f"Parser overrides: {parser_overrides}")
        uses_fallback = (
            document.document_type.value in parser_overrides
            or document.document_type in self.R2R_FALLBACK_PARSERS
        )
        logger.info(f"Using R2R fallback parser: {uses_fallback}")
        if document.document_type.value in parser_overrides:
            override = parser_overrides[document.document_type.value]
            logger.info(
                f"Using parser_override '{override}' for {document.document_type}"
            )
            if override == "zerox":
                async for element in self.parse_fallback(
                    file_content,
                    ingestion_config=ingestion_config,
                    parser_name=f"zerox_{DocumentType.PDF.value}",
                ):
                    elements.append(element)
            elif override == "ocr":
                async for element in self.parse_fallback(
                    file_content,
                    ingestion_config=ingestion_config,
                    parser_name=f"ocr_{DocumentType.PDF.value}",
                ):
                    elements.append(element)
        elif document.document_type in self.R2R_FALLBACK_PARSERS.keys():
            logger.info(
                f"Parsing {document.document_type}: {document.id} with fallback parser"
            )
            async for element in self.parse_fallback(
                file_content,
                ingestion_config=ingestion_config,
                parser_name=document.document_type,
            ):
                elements.append(element)
        else:
            logger.info(
                f"Parsing {document.document_type}: {document.id} with unstructured"
            )
            file_io = BytesIO(file_content)
            # TODO - Include check on excluded parsers here.
            if self.config.provider == "unstructured_api":
                logger.info(f"Using API to parse document {document.id}")
                files = self.shared.Files(
                    content=file_io.read(),
                    file_name=document.metadata.get("title", "unknown_file"),
                )
                ingestion_config.pop("app", None)
                ingestion_config.pop("extra_parsers", None)
                req = self.operations.PartitionRequest(
                    partition_parameters=self.shared.PartitionParameters(
                        files=files,
                        **ingestion_config,
                    )
                )
                elements = await self.client.general.partition_async(  # type: ignore
                    request=req
                )
                elements = list(elements.elements)  # type: ignore
            else:
                logger.info(
                    f"Using local unstructured fastapi server to parse document {document.id}"
                )
                # Base64 encode the file content
                encoded_content = base64.b64encode(file_io.read()).decode(
                    "utf-8"
                )
                logger.info(
                    f"Sending a request to {self.local_unstructured_url}/partition"
                )
                logger.debug(f"Ingestion config: {ingestion_config}")
                response = await self.client.post(
                    f"{self.local_unstructured_url}/partition",
                    json={
                        "file_content": encoded_content,  # Use encoded string
                        "ingestion_config": ingestion_config,
                        "filename": document.metadata.get("title", None),
                    },
                    timeout=3600,  # Adjust timeout as needed
                )
                if response.status_code != 200:
                    logger.error(f"Error partitioning file: {response.text}")
                    raise ValueError(
                        f"Error partitioning file: {response.text}"
                    )
                elements = response.json().get("elements", [])
        iteration = 0  # if there are no chunks
        for iteration, element in enumerate(elements):
            if isinstance(element, FallbackElement):
                text = element.text
                metadata = copy(document.metadata)
                metadata.update(element.metadata)
            else:
                element_dict = (
                    element if isinstance(element, dict) else element.to_dict()
                )
                text = element_dict.get("text", "")
                if text == "":
                    continue

                metadata = copy(document.metadata)
                for key, value in element_dict.items():
                    if key == "metadata":
                        for k, v in value.items():
                            if k not in metadata and k != "orig_elements":
                                metadata[f"unstructured_{k}"] = v

            # indicate that the document was chunked using unstructured
            # nullifies the need for chunking in the pipeline
            metadata["partitioned_by_unstructured"] = True
            metadata["chunk_order"] = iteration

            # creating the text extraction
            yield DocumentChunk(
                id=generate_extraction_id(document.id, iteration),
                document_id=document.id,
                owner_id=document.owner_id,
                collection_ids=document.collection_ids,
                data=text,
                metadata=metadata,
            )

        logger.debug(
            f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
            f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
            f"into {iteration + 1} extractions in t={time.time() - t0:.2f} seconds."
        )
    def get_parser_for_document_type(self, doc_type: DocumentType) -> str:
        return "unstructured_local"
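# A minimal end-to-end sketch (all objects are hypothetical and come from
# R2R's application wiring in practice):
#
#   provider = UnstructuredIngestionProvider(config, db, llm, ocr)
#   async for chunk in provider.parse(file_bytes, document, {}):
#       handle(chunk)  # `handle` is a placeholder for downstream processing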