base.py

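"""Unstructured-backed ingestion provider, with R2R fallback parsers for file
types (images, audio, spreadsheets, etc.) that are parsed natively instead."""
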
import asyncio
import base64
import logging
import os
import time
from copy import copy
from io import BytesIO
from typing import Any, AsyncGenerator

import httpx
from unstructured_client import UnstructuredClient
from unstructured_client.models import operations, shared

from core import parsers
from core.base import (
    AsyncParser,
    ChunkingStrategy,
    Document,
    DocumentChunk,
    DocumentType,
    RecursiveCharacterTextSplitter,
)
from core.base.abstractions import R2RSerializable
from core.base.providers.ingestion import IngestionConfig, IngestionProvider
from core.providers.ocr import MistralOCRProvider
from core.utils import generate_extraction_id

from ...database import PostgresDatabaseProvider
from ...llm import (
    LiteLLMCompletionProvider,
    OpenAICompletionProvider,
    R2RCompletionProvider,
)

logger = logging.getLogger()


class FallbackElement(R2RSerializable):
    text: str
    metadata: dict[str, Any]


class UnstructuredIngestionConfig(IngestionConfig):
    combine_under_n_chars: int = 128
    max_characters: int = 500
    new_after_n_chars: int = 1500
    overlap: int = 64

    coordinates: bool | None = None
    encoding: str | None = None  # utf-8
    extract_image_block_types: list[str] | None = None
    gz_uncompressed_content_type: str | None = None
    hi_res_model_name: str | None = None
    include_orig_elements: bool | None = None
    include_page_breaks: bool | None = None
    languages: list[str] | None = None
    multipage_sections: bool | None = None
    ocr_languages: list[str] | None = None
    # output_format: Optional[str] = "application/json"
    overlap_all: bool | None = None
    pdf_infer_table_structure: bool | None = None
    similarity_threshold: float | None = None
    skip_infer_table_types: list[str] | None = None
    split_pdf_concurrency_level: int | None = None
    split_pdf_page: bool | None = None
    starting_page_number: int | None = None
    strategy: str | None = None
    chunking_strategy: str | ChunkingStrategy | None = None  # type: ignore
    unique_element_ids: bool | None = None
    xml_keep_tags: bool | None = None

    def to_ingestion_request(self):
        import json

        # Serialize the config, then drop R2R-internal fields and null values
        # so only real partition parameters are sent to Unstructured.
        x = json.loads(self.json())
        x.pop("extra_fields", None)
        x.pop("provider", None)
        x.pop("excluded_parsers", None)

        x = {k: v for k, v in x.items() if v is not None}
        return x
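
    # Example (hypothetical field values, shown for illustration): with
    # strategy="hi_res" set and every other optional field left as None,
    # to_ingestion_request() returns
    #   {"combine_under_n_chars": 128, "max_characters": 500,
    #    "new_after_n_chars": 1500, "overlap": 64, "strategy": "hi_res"}
    # plus any non-null fields contributed by the IngestionConfig base class.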


class UnstructuredIngestionProvider(IngestionProvider):
    R2R_FALLBACK_PARSERS = {
        DocumentType.GIF: [parsers.ImageParser],  # type: ignore
        DocumentType.JPEG: [parsers.ImageParser],  # type: ignore
        DocumentType.JPG: [parsers.ImageParser],  # type: ignore
        DocumentType.PNG: [parsers.ImageParser],  # type: ignore
        DocumentType.SVG: [parsers.ImageParser],  # type: ignore
        DocumentType.HEIC: [parsers.ImageParser],  # type: ignore
        DocumentType.MP3: [parsers.AudioParser],  # type: ignore
        DocumentType.JSON: [parsers.JSONParser],  # type: ignore
        DocumentType.HTML: [parsers.HTMLParser],  # type: ignore
        DocumentType.XLS: [parsers.XLSParser],  # type: ignore
        DocumentType.XLSX: [parsers.XLSXParser],  # type: ignore
        # DocumentType.DOC: [parsers.DOCParser],  # type: ignore
        DocumentType.PPT: [parsers.PPTParser],  # type: ignore
        DocumentType.CSV: [parsers.CSVParserAdvanced],  # type: ignore
    }

    EXTRA_PARSERS = {
        # DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced},  # type: ignore
        # DocumentType.PDF: {
        #     "ocr": parsers.OCRPDFParser,  # type: ignore
        #     "unstructured": parsers.PDFParserUnstructured,  # type: ignore
        #     "zerox": parsers.VLMPDFParser,  # type: ignore
        # },
        # DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced},  # type: ignore
    }

    IMAGE_TYPES = {
        DocumentType.GIF,
        DocumentType.HEIC,
        DocumentType.JPG,
        DocumentType.JPEG,
        DocumentType.PNG,
        DocumentType.SVG,
    }

    def __init__(
        self,
        config: UnstructuredIngestionConfig,
        database_provider: PostgresDatabaseProvider,
        llm_provider: (
            LiteLLMCompletionProvider
            | OpenAICompletionProvider
            | R2RCompletionProvider
        ),
        ocr_provider: MistralOCRProvider,
    ):
        super().__init__(config, database_provider, llm_provider)
        self.config: UnstructuredIngestionConfig = config
        self.database_provider: PostgresDatabaseProvider = database_provider
        self.llm_provider: (
            LiteLLMCompletionProvider
            | OpenAICompletionProvider
            | R2RCompletionProvider
        ) = llm_provider
        self.ocr_provider: MistralOCRProvider = ocr_provider

        self.client: UnstructuredClient | httpx.AsyncClient
        if config.provider == "unstructured_api":
            try:
                self.unstructured_api_auth = os.environ["UNSTRUCTURED_API_KEY"]
            except KeyError as e:
                raise ValueError(
                    "UNSTRUCTURED_API_KEY environment variable is not set"
                ) from e

            self.unstructured_api_url = os.environ.get(
                "UNSTRUCTURED_API_URL",
                "https://api.unstructuredapp.io/general/v0/general",
            )

            self.client = UnstructuredClient(
                api_key_auth=self.unstructured_api_auth,
                server_url=self.unstructured_api_url,
            )
            self.shared = shared
            self.operations = operations
        else:
            try:
                self.local_unstructured_url = os.environ[
                    "UNSTRUCTURED_SERVICE_URL"
                ]
            except KeyError as e:
                raise ValueError(
                    "UNSTRUCTURED_SERVICE_URL environment variable is not set"
                ) from e

            self.client = httpx.AsyncClient()

        self.parsers: dict[DocumentType, AsyncParser] = {}
        self._initialize_parsers()

    def _initialize_parsers(self):
        # `parser_list` avoids shadowing the `parsers` module imported above.
        for doc_type, parser_list in self.R2R_FALLBACK_PARSERS.items():
            for parser in parser_list:
                if (
                    doc_type not in self.config.excluded_parsers
                    and doc_type not in self.parsers
                ):
                    # will choose the first parser in the list
                    self.parsers[doc_type] = parser(
                        config=self.config,
                        database_provider=self.database_provider,
                        llm_provider=self.llm_provider,
                    )

        # TODO - Reduce code duplication between Unstructured & R2R
        for doc_type, parser_names in self.config.extra_parsers.items():
            if not isinstance(parser_names, list):
                parser_names = [parser_names]

            for parser_name in parser_names:
                parser_key = f"{parser_name}_{str(doc_type)}"
                try:
                    self.parsers[parser_key] = self.EXTRA_PARSERS[doc_type][
                        parser_name
                    ](
                        config=self.config,
                        database_provider=self.database_provider,
                        llm_provider=self.llm_provider,
                        ocr_provider=self.ocr_provider,
                    )
                    logger.info(
                        f"Initialized extra parser {parser_name} for {doc_type}"
                    )
                except KeyError as e:
                    logger.error(
                        f"Parser {parser_name} for document type {doc_type} not found: {e}"
                    )

    async def parse_fallback(
        self,
        file_content: bytes,
        ingestion_config: dict,
        parser_name: str,
    ) -> AsyncGenerator[FallbackElement, None]:
        contents = []
        async for chunk in self.parsers[parser_name].ingest(  # type: ignore
            file_content, **ingestion_config
        ):  # type: ignore
            if isinstance(chunk, dict) and chunk.get("content"):
                contents.append(chunk)
            elif chunk:  # Handle string output for backward compatibility
                contents.append({"content": chunk})

        if not contents:
            logger.warning(
                "No valid text content was extracted during parsing"
            )
            return

        logger.info(f"Fallback ingestion with config = {ingestion_config}")

        vlm_ocr_one_page_per_chunk = ingestion_config.get(
            "vlm_ocr_one_page_per_chunk", True
        )

        iteration = 0
        for content_item in contents:
            text = content_item["content"]

            if vlm_ocr_one_page_per_chunk and parser_name.startswith(
                ("zerox_", "ocr_")
            ):
                # Use one page per chunk for OCR/VLM
                metadata = {"chunk_id": iteration}
                if "page_number" in content_item:
                    metadata["page_number"] = content_item["page_number"]

                yield FallbackElement(
                    text=text or "No content extracted.",
                    metadata=metadata,
                )
                iteration += 1
                await asyncio.sleep(0)
            else:
                # Use regular text splitting
                loop = asyncio.get_event_loop()
                splitter = RecursiveCharacterTextSplitter(
                    chunk_size=ingestion_config["new_after_n_chars"],
                    chunk_overlap=ingestion_config["overlap"],
                )
                chunks = await loop.run_in_executor(
                    None, splitter.create_documents, [text]
                )

                for text_chunk in chunks:
                    metadata = {"chunk_id": iteration}
                    if "page_number" in content_item:
                        metadata["page_number"] = content_item["page_number"]

                    yield FallbackElement(
                        text=text_chunk.page_content,
                        metadata=metadata,
                    )
                    iteration += 1
                    await asyncio.sleep(0)
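
    # Sketch of the two chunking modes above (example numbers are assumptions,
    # not defaults enforced here):
    # - parser_name="ocr_pdf" (or "zerox_pdf") with vlm_ocr_one_page_per_chunk
    #   left True yields one FallbackElement per parsed page.
    # - any other parser falls through to RecursiveCharacterTextSplitter, so a
    #   4,000-character page with new_after_n_chars=1500 and overlap=64 yields
    #   roughly three overlapping chunks.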

    async def parse(
        self,
        file_content: bytes,
        document: Document,
        ingestion_config_override: dict,
    ) -> AsyncGenerator[DocumentChunk, None]:
        ingestion_config = copy(
            {
                **self.config.to_ingestion_request(),
                **(ingestion_config_override or {}),
            }
        )
        # cleanup extra fields
        ingestion_config.pop("provider", None)
        ingestion_config.pop("excluded_parsers", None)

        t0 = time.time()
        parser_overrides = ingestion_config_override.get(
            "parser_overrides", {}
        )
        elements = []

        # TODO - Cleanup this approach to be less hardcoded
        # TODO - Remove code duplication between Unstructured & R2R
        logger.info(f"Parser overrides: {parser_overrides}")
        logger.debug(
            f"Document type {document.document_type.value} has extra parsers registered: "
            f"{document.document_type.value in self.EXTRA_PARSERS}"
        )
        if document.document_type.value in parser_overrides:
            logger.info(
                f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
            )
            if parser_overrides[document.document_type.value] == "zerox":
                async for element in self.parse_fallback(
                    file_content,
                    ingestion_config=ingestion_config,
                    parser_name=f"zerox_{DocumentType.PDF.value}",
                ):
                    elements.append(element)
            elif parser_overrides[document.document_type.value] == "ocr":
                async for element in self.parse_fallback(
                    file_content,
                    ingestion_config=ingestion_config,
                    parser_name=f"ocr_{DocumentType.PDF.value}",
                ):
                    elements.append(element)
        elif document.document_type in self.R2R_FALLBACK_PARSERS:
            logger.info(
                f"Parsing {document.document_type}: {document.id} with fallback parser"
            )
            try:
                async for element in self.parse_fallback(
                    file_content,
                    ingestion_config=ingestion_config,
                    parser_name=document.document_type,
                ):
                    elements.append(element)
            except Exception as e:
                logger.error(f"Error parsing {document.document_type}: {e}")
                raise
        else:
            logger.info(
                f"Parsing {document.document_type}: {document.id} with unstructured"
            )
            file_io = BytesIO(file_content)

            # TODO - Include check on excluded parsers here.
            if self.config.provider == "unstructured_api":
                logger.info(f"Using API to parse document {document.id}")
                files = self.shared.Files(
                    content=file_io.read(),
                    file_name=document.metadata.get("title", "unknown_file"),
                )

                ingestion_config.pop("app", None)
                ingestion_config.pop("extra_parsers", None)

                req = self.operations.PartitionRequest(
                    partition_parameters=self.shared.PartitionParameters(
                        files=files,
                        **ingestion_config,
                    )
                )
                elements = await self.client.general.partition_async(  # type: ignore
                    request=req
                )
                elements = list(elements.elements)  # type: ignore
            else:
                logger.info(
                    f"Using local unstructured fastapi server to parse document {document.id}"
                )
                # Base64 encode the file content
                encoded_content = base64.b64encode(file_io.read()).decode(
                    "utf-8"
                )

                logger.info(
                    f"Sending a request to {self.local_unstructured_url}/partition"
                )
                # ingestion_config["strategy"] = "hi_res"
                logger.debug(f"Ingestion config: {ingestion_config}")

                response = await self.client.post(
                    f"{self.local_unstructured_url}/partition",
                    json={
                        "file_content": encoded_content,  # Use encoded string
                        "ingestion_config": ingestion_config,
                        "filename": document.metadata.get("title", None),
                    },
                    timeout=3600,  # Adjust timeout as needed
                )

                if response.status_code != 200:
                    logger.error(f"Error partitioning file: {response.text}")
                    raise ValueError(
                        f"Error partitioning file: {response.text}"
                    )

                elements = response.json().get("elements", [])

        iteration = 0  # if there are no chunks
        for iteration, element in enumerate(elements):
            if isinstance(element, FallbackElement):
                text = element.text
                metadata = copy(document.metadata)
                metadata.update(element.metadata)
            else:
                element_dict = (
                    element if isinstance(element, dict) else element.to_dict()
                )
                text = element_dict.get("text", "")
                if text == "":
                    continue

                metadata = copy(document.metadata)
                for key, value in element_dict.items():
                    if key == "metadata":
                        for k, v in value.items():
                            if k not in metadata and k != "orig_elements":
                                metadata[f"unstructured_{k}"] = v

            # indicate that the document was chunked using unstructured
            # nullifies the need for chunking in the pipeline
            metadata["partitioned_by_unstructured"] = True
            metadata["chunk_order"] = iteration

            # creating the text extraction
            yield DocumentChunk(
                id=generate_extraction_id(document.id, iteration),
                document_id=document.id,
                owner_id=document.owner_id,
                collection_ids=document.collection_ids,
                data=text,
                metadata=metadata,
            )

        logger.debug(
            f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
            f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
            f"into {iteration + 1} extractions in t={time.time() - t0:.2f} seconds."
        )

    def get_parser_for_document_type(self, doc_type: DocumentType) -> str:
        return "unstructured_local"
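

# Minimal usage sketch (hypothetical wiring; everything outside this module's
# own names is an assumption):
#
#   config = UnstructuredIngestionConfig(provider="unstructured_local")
#   # "unstructured_api" mode requires UNSTRUCTURED_API_KEY (UNSTRUCTURED_API_URL
#   # is optional); any other provider value requires UNSTRUCTURED_SERVICE_URL
#   # to point at the local partition server.
#   provider = UnstructuredIngestionProvider(
#       config, database_provider, llm_provider, ocr_provider
#   )
#   async for chunk in provider.parse(file_bytes, document, {}):
#       print(chunk.data)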