# TODO - cleanup type issues in this file that relate to `bytes`
import asyncio
import base64
import json
import logging
import os
import time
from copy import copy
from io import BytesIO
from typing import Any, AsyncGenerator, Optional

import httpx
from unstructured_client import UnstructuredClient
from unstructured_client.models import operations, shared

from core import parsers
from core.base import (
    AsyncParser,
    ChunkingStrategy,
    Document,
    DocumentChunk,
    DocumentType,
    RecursiveCharacterTextSplitter,
)
from core.base.abstractions import R2RSerializable
from core.base.providers.ingestion import IngestionConfig, IngestionProvider
from core.utils import generate_extraction_id

from ....database.postgres import PostgresDatabaseProvider
from ...llm import LiteLLMCompletionProvider, OpenAICompletionProvider

logger = logging.getLogger()


class FallbackElement(R2RSerializable):
    text: str
    metadata: dict[str, Any]


class UnstructuredIngestionConfig(IngestionConfig):
    combine_under_n_chars: int = 128
    max_characters: int = 500
    new_after_n_chars: int = 1500
    overlap: int = 64

    coordinates: Optional[bool] = None
    encoding: Optional[str] = None  # utf-8
    extract_image_block_types: Optional[list[str]] = None
    gz_uncompressed_content_type: Optional[str] = None
    hi_res_model_name: Optional[str] = None
    include_orig_elements: Optional[bool] = None
    include_page_breaks: Optional[bool] = None
    languages: Optional[list[str]] = None
    multipage_sections: Optional[bool] = None
    ocr_languages: Optional[list[str]] = None
    # output_format: Optional[str] = "application/json"
    overlap_all: Optional[bool] = None
    pdf_infer_table_structure: Optional[bool] = None
    similarity_threshold: Optional[float] = None
    skip_infer_table_types: Optional[list[str]] = None
    split_pdf_concurrency_level: Optional[int] = None
    split_pdf_page: Optional[bool] = None
    starting_page_number: Optional[int] = None
    strategy: Optional[str] = None
    chunking_strategy: Optional[str | ChunkingStrategy] = None
    unique_element_ids: Optional[bool] = None
    xml_keep_tags: Optional[bool] = None

    def to_ingestion_request(self):
        # Serialize the config, drop R2R bookkeeping fields that the
        # unstructured API does not accept, and omit unset options so the
        # server-side defaults apply.
        payload = json.loads(self.json())
        for key in ("extra_fields", "provider", "excluded_parsers"):
            payload.pop(key, None)
        return {k: v for k, v in payload.items() if v is not None}
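
    # Illustrative output (an assumption, not captured from a live run): with
    # a default config only the non-None fields survive the filter, e.g.
    #   {"combine_under_n_chars": 128, "max_characters": 500,
    #    "new_after_n_chars": 1500, "overlap": 64, ...}
    # plus whatever non-None fields IngestionConfig itself contributes.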


class UnstructuredIngestionProvider(IngestionProvider):
    R2R_FALLBACK_PARSERS = {
        DocumentType.GIF: [parsers.ImageParser],  # type: ignore
        DocumentType.JPEG: [parsers.ImageParser],  # type: ignore
        DocumentType.JPG: [parsers.ImageParser],  # type: ignore
        DocumentType.PNG: [parsers.ImageParser],  # type: ignore
        DocumentType.SVG: [parsers.ImageParser],  # type: ignore
        DocumentType.HEIC: [parsers.ImageParser],  # type: ignore
        DocumentType.MP3: [parsers.AudioParser],  # type: ignore
        DocumentType.JSON: [parsers.JSONParser],  # type: ignore
        DocumentType.HTML: [parsers.HTMLParser],  # type: ignore
        DocumentType.XLS: [parsers.XLSParser],  # type: ignore
        DocumentType.XLSX: [parsers.XLSXParser],  # type: ignore
        DocumentType.DOC: [parsers.DOCParser],  # type: ignore
        DocumentType.PPT: [parsers.PPTParser],  # type: ignore
    }

    EXTRA_PARSERS = {
        DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced},  # type: ignore
        DocumentType.PDF: {
            "unstructured": parsers.PDFParserUnstructured,  # type: ignore
            "zerox": parsers.VLMPDFParser,  # type: ignore
        },
        DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced},  # type: ignore
    }

    IMAGE_TYPES = {
        DocumentType.GIF,
        DocumentType.HEIC,
        DocumentType.JPG,
        DocumentType.JPEG,
        DocumentType.PNG,
        DocumentType.SVG,
    }

    def __init__(
        self,
        config: UnstructuredIngestionConfig,
        database_provider: PostgresDatabaseProvider,
        llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
    ):
        super().__init__(config, database_provider, llm_provider)
        self.config: UnstructuredIngestionConfig = config
        self.database_provider: PostgresDatabaseProvider = database_provider
        self.llm_provider: (
            LiteLLMCompletionProvider | OpenAICompletionProvider
        ) = llm_provider

        if config.provider == "unstructured_api":
            try:
                self.unstructured_api_auth = os.environ["UNSTRUCTURED_API_KEY"]
            except KeyError as e:
                raise ValueError(
                    "UNSTRUCTURED_API_KEY environment variable is not set"
                ) from e
            self.unstructured_api_url = os.environ.get(
                "UNSTRUCTURED_API_URL",
                "https://api.unstructuredapp.io/general/v0/general",
            )
            self.client = UnstructuredClient(
                api_key_auth=self.unstructured_api_auth,
                server_url=self.unstructured_api_url,
            )
            self.shared = shared
            self.operations = operations
        else:
            try:
                self.local_unstructured_url = os.environ[
                    "UNSTRUCTURED_SERVICE_URL"
                ]
            except KeyError as e:
                raise ValueError(
                    "UNSTRUCTURED_SERVICE_URL environment variable is not set"
                ) from e
            self.client = httpx.AsyncClient()

        # Keyed by DocumentType for fallback parsers and by
        # "<parser_name>_<doc_type>" strings for extra parsers.
        self.parsers: dict[DocumentType | str, AsyncParser] = {}
        self._initialize_parsers()
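
    # Environment sketch (values illustrative, not from the source):
    #   UNSTRUCTURED_API_KEY=...                  # "unstructured_api" provider
    #   UNSTRUCTURED_API_URL=https://api.unstructuredapp.io/general/v0/general
    #   UNSTRUCTURED_SERVICE_URL=http://unstructured:8000   # local provider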

    def _initialize_parsers(self):
        # NOTE: the loop variable is named `doc_parsers` (not `parsers`) to
        # avoid shadowing the `core.parsers` module imported above.
        for doc_type, doc_parsers in self.R2R_FALLBACK_PARSERS.items():
            for parser in doc_parsers:
                if (
                    doc_type not in self.config.excluded_parsers
                    and doc_type not in self.parsers
                ):
                    # will choose the first parser in the list
                    self.parsers[doc_type] = parser(
                        config=self.config,
                        database_provider=self.database_provider,
                        llm_provider=self.llm_provider,
                    )
        # TODO - Reduce code duplication between Unstructured & R2R
        for doc_type, doc_parser_name in self.config.extra_parsers.items():
            self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = (
                UnstructuredIngestionProvider.EXTRA_PARSERS[doc_type][
                    doc_parser_name
                ](
                    config=self.config,
                    database_provider=self.database_provider,
                    llm_provider=self.llm_provider,
                )
            )
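
    # Illustrative parser map (assuming extra_parsers={DocumentType.PDF: "zerox"}):
    #   self.parsers == {
    #       DocumentType.GIF: ImageParser(...),
    #       ...,
    #       "zerox_pdf": VLMPDFParser(...),  # extra parsers use string keys
    #   }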

    async def parse_fallback(
        self,
        file_content: bytes,
        ingestion_config: dict,
        parser_name: DocumentType | str,
    ) -> AsyncGenerator[FallbackElement, None]:
        logger.info(f"Fallback ingestion with config = {ingestion_config}")

        # Run the fallback parser to completion, accumulating its text output.
        context = ""
        async for text in self.parsers[parser_name].ingest(file_content, **ingestion_config):  # type: ignore
            context += text + "\n\n"

        # Re-chunk the accumulated text off the event loop.
        loop = asyncio.get_event_loop()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=ingestion_config["new_after_n_chars"],
            chunk_overlap=ingestion_config["overlap"],
        )
        chunks = await loop.run_in_executor(
            None, splitter.create_documents, [context]
        )

        for chunk_id, text_chunk in enumerate(chunks):
            yield FallbackElement(
                text=text_chunk.page_content,
                metadata={"chunk_id": chunk_id},
            )
            await asyncio.sleep(0)
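
    # Chunking sketch (defaults from UnstructuredIngestionConfig): with
    # new_after_n_chars=1500 and overlap=64, consecutive chunks are roughly
    # 1500-character windows sharing ~64 characters, though the recursive
    # splitter prefers paragraph/sentence boundaries, so exact sizes drift.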

    async def parse(
        self,
        file_content: bytes,
        document: Document,
        ingestion_config_override: dict,
    ) -> AsyncGenerator[DocumentChunk, None]:
        # Merge provider defaults with any per-request overrides.
        ingestion_config = {
            **self.config.to_ingestion_request(),
            **(ingestion_config_override or {}),
        }
        # cleanup extra fields
        ingestion_config.pop("provider", None)
        ingestion_config.pop("excluded_parsers", None)

        t0 = time.time()
        parser_overrides = (ingestion_config_override or {}).get(
            "parser_overrides", {}
        )
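        # Example override (illustrative, not from the source):
        #   ingestion_config_override = {"parser_overrides": {"pdf": "zerox"}}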

        elements = []

        # TODO - Cleanup this approach to be less hardcoded
        # TODO - Remove code duplication between Unstructured & R2R
        if document.document_type.value in parser_overrides:
            logger.info(
                f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
            )
            # NOTE: currently hardcoded to the zerox PDF parser (see TODO above).
            async for element in self.parse_fallback(
                file_content,
                ingestion_config=ingestion_config,
                parser_name=f"zerox_{DocumentType.PDF.value}",
            ):
                elements.append(element)
        elif document.document_type in self.R2R_FALLBACK_PARSERS:
            logger.info(
                f"Parsing {document.document_type}: {document.id} with fallback parser"
            )
            async for element in self.parse_fallback(
                file_content,
                ingestion_config=ingestion_config,
                parser_name=document.document_type,
            ):
                elements.append(element)
        else:
            logger.info(
                f"Parsing {document.document_type}: {document.id} with unstructured"
            )
            if isinstance(file_content, bytes):
                file_content = BytesIO(file_content)  # type: ignore
            # TODO - Include check on excluded parsers here.
            if self.config.provider == "unstructured_api":
                logger.info(f"Using API to parse document {document.id}")
                files = self.shared.Files(
                    content=file_content.read(),  # type: ignore
                    file_name=document.metadata.get("title", "unknown_file"),
                )
                # Drop fields the hosted API does not recognize.
                ingestion_config.pop("app", None)
                ingestion_config.pop("extra_parsers", None)
                req = self.operations.PartitionRequest(
                    self.shared.PartitionParameters(
                        files=files,
                        **ingestion_config,
                    )
                )
                elements = self.client.general.partition(req)  # type: ignore
                elements = list(elements.elements)  # type: ignore
            else:
                logger.info(
                    f"Using local unstructured fastapi server to parse document {document.id}"
                )
                # Base64 encode the file content for the JSON payload.
                encoded_content = base64.b64encode(file_content.read()).decode(  # type: ignore
                    "utf-8"
                )
                logger.info(
                    f"Sending a request to {self.local_unstructured_url}/partition"
                )
                response = await self.client.post(
                    f"{self.local_unstructured_url}/partition",
                    json={
                        "file_content": encoded_content,
                        "ingestion_config": ingestion_config,
                        "filename": document.metadata.get("title"),
                    },
                    timeout=3600,  # Adjust timeout as needed
                )
                if response.status_code != 200:
                    logger.error(f"Error partitioning file: {response.text}")
                    raise ValueError(
                        f"Error partitioning file: {response.text}"
                    )
                elements = response.json().get("elements", [])
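
                # Expected response shape (an assumption inferred from the
                # element handling below, not a documented contract):
                #   {"elements": [{"type": "...", "text": "...",
                #                  "metadata": {"page_number": 1, ...}}, ...]}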

        iteration = 0  # if there are no chunks
        for iteration, element in enumerate(elements):
            if isinstance(element, FallbackElement):
                text = element.text
                metadata = copy(document.metadata)
                metadata.update(element.metadata)
            else:
                element_dict = (
                    element if isinstance(element, dict) else element.to_dict()
                )
                text = element_dict.get("text", "")
                if text == "":
                    continue

                metadata = copy(document.metadata)
                # Flatten unstructured's nested metadata into prefixed keys,
                # without overwriting existing document metadata.
                for key, value in element_dict.items():
                    if key == "metadata":
                        for k, v in value.items():
                            if k not in metadata and k != "orig_elements":
                                metadata[f"unstructured_{k}"] = v

            # indicate that the document was chunked using unstructured
            # nullifies the need for chunking in the pipeline
            metadata["partitioned_by_unstructured"] = True
            metadata["chunk_order"] = iteration

            # creating the text extraction
            yield DocumentChunk(
                id=generate_extraction_id(document.id, iteration),
                document_id=document.id,
                owner_id=document.owner_id,
                collection_ids=document.collection_ids,
                data=text,
                metadata=metadata,
            )

        # TODO: explore why this is throwing inadvertently
        # if iteration == 0:
        #     raise ValueError(f"No chunks found for document {document.id}")

        logger.debug(
            f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
            f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
            f"into {iteration + 1} extractions in t={time.time() - t0:.2f} seconds."
        )

    def get_parser_for_document_type(self, doc_type: DocumentType) -> str:
        return "unstructured_local"
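

# Usage sketch (hypothetical wiring; in R2R the provider, database, and LLM
# dependencies are assembled by the ingestion factory, not constructed by hand):
#
#   config = UnstructuredIngestionConfig(provider="unstructured_local")
#   provider = UnstructuredIngestionProvider(config, database_provider, llm_provider)
#   async for chunk in provider.parse(raw_bytes, document, {}):
#       print(chunk.metadata["chunk_order"], chunk.data[:80])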