# TODO - cleanup type issues in this file that relate to `bytes`
import asyncio
import base64
import logging
import os
import time
from copy import copy
from io import BytesIO
from typing import Any, AsyncGenerator, Optional, Union

import httpx
from unstructured_client import UnstructuredClient
from unstructured_client.models import operations, shared

from core import parsers
from core.base import (
    AsyncParser,
    ChunkingStrategy,
    Document,
    DocumentChunk,
    DocumentType,
    RecursiveCharacterTextSplitter,
)
from core.base.abstractions import R2RSerializable
from core.base.providers.ingestion import IngestionConfig, IngestionProvider
from core.utils import generate_extraction_id

from ....database.postgres import PostgresDatabaseProvider
from ...llm import LiteLLMCompletionProvider, OpenAICompletionProvider

logger = logging.getLogger()


class FallbackElement(R2RSerializable):
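    """Plain-text chunk produced by an R2R fallback parser when a document is
    handled outside of Unstructured."""
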
    text: str
    metadata: dict[str, Any]


class UnstructuredIngestionConfig(IngestionConfig):
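    """Configuration for Unstructured-backed ingestion; fields mirror the
    parameters accepted by Unstructured's partition API."""
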
    combine_under_n_chars: int = 128
    max_characters: int = 500
    new_after_n_chars: int = 1500
    overlap: int = 64

    coordinates: Optional[bool] = None
    encoding: Optional[str] = None  # utf-8
    extract_image_block_types: Optional[list[str]] = None
    gz_uncompressed_content_type: Optional[str] = None
    hi_res_model_name: Optional[str] = None
    include_orig_elements: Optional[bool] = None
    include_page_breaks: Optional[bool] = None
    languages: Optional[list[str]] = None
    multipage_sections: Optional[bool] = None
    ocr_languages: Optional[list[str]] = None
    # output_format: Optional[str] = "application/json"
    overlap_all: Optional[bool] = None
    pdf_infer_table_structure: Optional[bool] = None
    similarity_threshold: Optional[float] = None
    skip_infer_table_types: Optional[list[str]] = None
    split_pdf_concurrency_level: Optional[int] = None
    split_pdf_page: Optional[bool] = None
    starting_page_number: Optional[int] = None
    strategy: Optional[str] = None
    chunking_strategy: Optional[ChunkingStrategy] = None
    unique_element_ids: Optional[bool] = None
    xml_keep_tags: Optional[bool] = None

    def to_ingestion_request(self):
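        """Serialize this config into keyword arguments for a partition
        request, dropping R2R-internal fields and null values."""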
        import json

        x = json.loads(self.json())
        x.pop("extra_fields", None)
        x.pop("provider", None)
        x.pop("excluded_parsers", None)

        x = {k: v for k, v in x.items() if v is not None}
        return x


class UnstructuredIngestionProvider(IngestionProvider):
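    """Ingestion provider backed by Unstructured, via either the hosted API or
    a self-hosted service, with R2R's own parsers as fallbacks for document
    types Unstructured does not handle."""
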
    R2R_FALLBACK_PARSERS = {
        DocumentType.GIF: [parsers.ImageParser],
        DocumentType.JPEG: [parsers.ImageParser],
        DocumentType.JPG: [parsers.ImageParser],
        DocumentType.PNG: [parsers.ImageParser],
        DocumentType.SVG: [parsers.ImageParser],
        DocumentType.MP3: [parsers.AudioParser],
        DocumentType.JSON: [parsers.JSONParser],  # type: ignore
        DocumentType.HTML: [parsers.HTMLParser],  # type: ignore
        DocumentType.XLSX: [parsers.XLSXParser],  # type: ignore
    }
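
    # Opt-in parsers, selected per document type through the `extra_parsers`
    # ingestion config; keyed first by document type, then by parser name.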
    EXTRA_PARSERS = {
        DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced},  # type: ignore
        DocumentType.PDF: {
            "unstructured": parsers.PDFParserUnstructured,
            "zerox": parsers.VLMPDFParser,
        },
        DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced},  # type: ignore
    }

    IMAGE_TYPES = {
        DocumentType.GIF,
        DocumentType.JPG,
        DocumentType.JPEG,
        DocumentType.PNG,
        DocumentType.SVG,
    }

    def __init__(
        self,
        config: UnstructuredIngestionConfig,
        database_provider: PostgresDatabaseProvider,
        llm_provider: Union[
            LiteLLMCompletionProvider, OpenAICompletionProvider
        ],
    ):
        super().__init__(config, database_provider, llm_provider)
        self.config: UnstructuredIngestionConfig = config
        self.database_provider: PostgresDatabaseProvider = database_provider
        self.llm_provider: Union[
            LiteLLMCompletionProvider, OpenAICompletionProvider
        ] = llm_provider
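
        # "unstructured_api" targets the hosted Unstructured API; any other
        # provider value expects a self-hosted Unstructured service at
        # UNSTRUCTURED_SERVICE_URL.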
        if config.provider == "unstructured_api":
            try:
                self.unstructured_api_auth = os.environ["UNSTRUCTURED_API_KEY"]
            except KeyError as e:
                raise ValueError(
                    "UNSTRUCTURED_API_KEY environment variable is not set"
                ) from e

            self.unstructured_api_url = os.environ.get(
                "UNSTRUCTURED_API_URL",
                "https://api.unstructuredapp.io/general/v0/general",
            )

            self.client = UnstructuredClient(
                api_key_auth=self.unstructured_api_auth,
                server_url=self.unstructured_api_url,
            )
            self.shared = shared
            self.operations = operations
        else:
            try:
                self.local_unstructured_url = os.environ[
                    "UNSTRUCTURED_SERVICE_URL"
                ]
            except KeyError as e:
                raise ValueError(
                    "UNSTRUCTURED_SERVICE_URL environment variable is not set"
                ) from e

            self.client = httpx.AsyncClient()

        # Keys are DocumentType values for fallback parsers and
        # "{name}_{doc_type}" strings for extra parsers.
        self.parsers: dict[Union[DocumentType, str], AsyncParser] = {}
        self._initialize_parsers()

    def _initialize_parsers(self):
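        """Instantiate fallback parsers and any configured extra parsers."""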
        for doc_type, fallback_parsers in self.R2R_FALLBACK_PARSERS.items():
            for parser in fallback_parsers:
                if (
                    doc_type not in self.config.excluded_parsers
                    and doc_type not in self.parsers
                ):
                    # will choose the first parser in the list
                    self.parsers[doc_type] = parser(
                        config=self.config,
                        database_provider=self.database_provider,
                        llm_provider=self.llm_provider,
                    )

        # TODO - Reduce code duplication between Unstructured & R2R
        for doc_type, doc_parser_name in self.config.extra_parsers.items():
            self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = (
                UnstructuredIngestionProvider.EXTRA_PARSERS[doc_type][
                    doc_parser_name
                ](
                    config=self.config,
                    database_provider=self.database_provider,
                    llm_provider=self.llm_provider,
                )
            )

    async def parse_fallback(
        self,
        file_content: bytes,
        ingestion_config: dict,
        parser_name: str,
    ) -> AsyncGenerator[FallbackElement, None]:
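        """Run a document through one of R2R's own parsers, then re-chunk the
        concatenated text with a recursive character splitter."""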
        context = ""
        async for text in self.parsers[parser_name].ingest(  # type: ignore
            file_content, **ingestion_config
        ):
            context += text + "\n\n"

        logger.info(f"Fallback ingestion with config = {ingestion_config}")

        loop = asyncio.get_event_loop()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=ingestion_config["new_after_n_chars"],
            chunk_overlap=ingestion_config["overlap"],
        )
        # Run the synchronous splitter off the event loop.
        chunks = await loop.run_in_executor(
            None, splitter.create_documents, [context]
        )

        for chunk_id, text_chunk in enumerate(chunks):
            yield FallbackElement(
                text=text_chunk.page_content,
                metadata={"chunk_id": chunk_id},
            )
            await asyncio.sleep(0)

    async def parse(
        self,
        file_content: bytes,
        document: Document,
        ingestion_config_override: dict,
    ) -> AsyncGenerator[DocumentChunk, None]:
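        """Partition a document into DocumentChunks, routing between parser
        overrides, R2R fallback parsers, and Unstructured."""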
        ingestion_config = copy(
            {
                **self.config.to_ingestion_request(),
                **(ingestion_config_override or {}),
            }
        )
        # cleanup extra fields
        ingestion_config.pop("provider", None)
        ingestion_config.pop("excluded_parsers", None)

        t0 = time.time()
        parser_overrides = (ingestion_config_override or {}).get(
            "parser_overrides", {}
        )
        elements = []

        # TODO - Cleanup this approach to be less hardcoded
        # TODO - Remove code duplication between Unstructured & R2R
        if document.document_type.value in parser_overrides:
            logger.info(
                f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
            )
            async for element in self.parse_fallback(
                file_content,
                ingestion_config=ingestion_config,
                parser_name=f"zerox_{DocumentType.PDF.value}",
            ):
                elements.append(element)
        elif document.document_type in self.R2R_FALLBACK_PARSERS:
            logger.info(
                f"Parsing {document.document_type}: {document.id} with fallback parser"
            )
            async for element in self.parse_fallback(
                file_content,
                ingestion_config=ingestion_config,
                parser_name=document.document_type,
            ):
                elements.append(element)
        else:
            logger.info(
                f"Parsing {document.document_type}: {document.id} with unstructured"
            )
            if isinstance(file_content, bytes):
                file_content = BytesIO(file_content)  # type: ignore

            # TODO - Include check on excluded parsers here.
            if self.config.provider == "unstructured_api":
                logger.info(f"Using API to parse document {document.id}")
                files = self.shared.Files(
                    content=file_content.read(),  # type: ignore
                    file_name=document.metadata.get("title", "unknown_file"),
                )

                ingestion_config.pop("app", None)
                ingestion_config.pop("extra_parsers", None)

                req = self.operations.PartitionRequest(
                    self.shared.PartitionParameters(
                        files=files,
                        **ingestion_config,
                    )
                )
                elements = self.client.general.partition(req)  # type: ignore
                elements = list(elements.elements)  # type: ignore
            else:
                logger.info(
                    f"Using local unstructured fastapi server to parse document {document.id}"
                )
                # Base64 encode the file content
                encoded_content = base64.b64encode(file_content.read()).decode(  # type: ignore
                    "utf-8"
                )

                logger.info(
                    f"Sending a request to {self.local_unstructured_url}/partition"
                )
                response = await self.client.post(
                    f"{self.local_unstructured_url}/partition",
                    json={
                        "file_content": encoded_content,  # Use encoded string
                        "ingestion_config": ingestion_config,
                        "filename": document.metadata.get("title", None),
                    },
                    timeout=3600,  # Adjust timeout as needed
                )

                if response.status_code != 200:
                    logger.error(f"Error partitioning file: {response.text}")
                    raise ValueError(
                        f"Error partitioning file: {response.text}"
                    )
                elements = response.json().get("elements", [])

        iteration = 0  # if there are no chunks
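        # Normalize FallbackElements and raw Unstructured elements into
        # DocumentChunks.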
        for iteration, element in enumerate(elements):
            if isinstance(element, FallbackElement):
                text = element.text
                metadata = copy(document.metadata)
                metadata.update(element.metadata)
            else:
                element_dict = (
                    element.to_dict()
                    if not isinstance(element, dict)
                    else element
                )
                text = element_dict.get("text", "")
                if text == "":
                    continue

                metadata = copy(document.metadata)
                for key, value in element_dict.items():
                    if key == "metadata":
                        for k, v in value.items():
                            if k not in metadata and k != "orig_elements":
                                metadata[f"unstructured_{k}"] = v

            # indicate that the document was chunked using unstructured
            # nullifies the need for chunking in the pipeline
            metadata["partitioned_by_unstructured"] = True
            metadata["chunk_order"] = iteration

            # creating the text extraction
            yield DocumentChunk(
                id=generate_extraction_id(document.id, iteration),
                document_id=document.id,
                owner_id=document.owner_id,
                collection_ids=document.collection_ids,
                data=text,
                metadata=metadata,
            )

        # TODO: explore why this is throwing inadvertently
        # if iteration == 0:
        #     raise ValueError(f"No chunks found for document {document.id}")

        logger.debug(
            f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
            f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
            f"into {iteration + 1} extractions in t={time.time() - t0:.2f} seconds."
        )

    def get_parser_for_document_type(self, doc_type: DocumentType) -> str:
        return "unstructured_local"