pdf_parser.py

# type: ignore
import asyncio
import base64
import json
import logging
import string
import time
import unicodedata
from io import BytesIO
from typing import AsyncGenerator

import pdf2image
from mistralai.models import OCRResponse
from pypdf import PdfReader

from core.base.abstractions import GenerationConfig
from core.base.parsers.base_parser import AsyncParser
from core.base.providers import (
    CompletionProvider,
    DatabaseProvider,
    IngestionConfig,
    OCRProvider,
)

logger = logging.getLogger()


class OCRPDFParser(AsyncParser[str | bytes]):
    """
    A parser for PDF documents using Mistral's OCR for page processing.

    Mistral supports directly processing PDF files, so this parser is a simple
    wrapper around the Mistral OCR API.
    """

    def __init__(
        self,
        config: IngestionConfig,
        database_provider: DatabaseProvider,
        llm_provider: CompletionProvider,
        ocr_provider: OCRProvider,
    ):
        self.config = config
        self.database_provider = database_provider
        self.ocr_provider = ocr_provider

    async def ingest(
        self, data: str | bytes, **kwargs
    ) -> AsyncGenerator[dict[str, str | int], None]:
        """Ingest PDF data and yield the OCR result for each page."""
        try:
            logger.info("Starting PDF ingestion using OCRPDFParser")
            if isinstance(data, str):
                response: OCRResponse = await self.ocr_provider.process_pdf(
                    file_path=data
                )
            else:
                response = await self.ocr_provider.process_pdf(
                    file_content=data
                )

            for page in response.pages:
                yield {
                    "content": page.markdown,
                    "page_number": page.index + 1,  # Mistral is 0-indexed
                }
        except Exception as e:
            logger.error(f"Error processing PDF with Mistral OCR: {str(e)}")
            raise


class VLMPDFParser(AsyncParser[str | bytes]):
    """A parser for PDF documents using vision models for page processing."""

    def __init__(
        self,
        config: IngestionConfig,
        database_provider: DatabaseProvider,
        llm_provider: CompletionProvider,
        ocr_provider: OCRProvider,
    ):
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.config = config
        self.vision_prompt_text = None
        self.vlm_batch_size = self.config.vlm_batch_size or 5
        self.vlm_max_tokens_to_sample = (
            self.config.vlm_max_tokens_to_sample or 1024
        )
        self.max_concurrent_vlm_tasks = (
            self.config.max_concurrent_vlm_tasks or 5
        )
        self.semaphore = None
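        # The semaphore that caps concurrent vision-model requests is created
        # in ingest(), not here.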

    async def process_page(self, image, page_num: int) -> dict[str, str]:
        """Process a single PDF page using the vision model."""
        page_start = time.perf_counter()
        try:
            img_byte_arr = BytesIO()
            image.save(img_byte_arr, format="JPEG")
            image_data = img_byte_arr.getvalue()

            # Convert image bytes to base64
            image_base64 = base64.b64encode(image_data).decode("utf-8")
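            # Both request formats below embed the rendered page as a
            # base64-encoded JPEG.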

            model = self.config.vlm or self.config.app.vlm

            # Configure generation parameters
            generation_config = GenerationConfig(
                model=model,
                stream=False,
                max_tokens_to_sample=self.vlm_max_tokens_to_sample,
            )

            is_anthropic = model and "anthropic/" in model
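            # Anthropic models are prompted with a forced tool call so the page
            # text comes back as structured arguments; other models return the
            # page text as plain message content.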
            # Prepare message with image content
            if is_anthropic:
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self.vision_prompt_text},
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/jpeg",
                                    "data": image_base64,
                                },
                            },
                        ],
                    }
                ]
            else:
                # Use OpenAI format
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self.vision_prompt_text},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{image_base64}"
                                },
                            },
                        ],
                    }
                ]

            logger.debug(f"Sending page {page_num} to vision model.")

            if is_anthropic:
                response = await self.llm_provider.aget_completion(
                    messages=messages,
                    generation_config=generation_config,
                    apply_timeout=True,
                    tools=[
                        {
                            "name": "parse_pdf_page",
                            "description": "Parse text content from a PDF page",
                            "input_schema": {
                                "type": "object",
                                "properties": {
                                    "page_content": {
                                        "type": "string",
                                        "description": "Extracted text from the PDF page, transcribed into markdown",
                                    },
                                    "thoughts": {
                                        "type": "string",
                                        "description": "Any thoughts or comments on the text",
                                    },
                                },
                                "required": ["page_content"],
                            },
                        }
                    ],
                    tool_choice={"type": "tool", "name": "parse_pdf_page"},
                )

                if (
                    response.choices
                    and response.choices[0].message
                    and response.choices[0].message.tool_calls
                ):
                    tool_call = response.choices[0].message.tool_calls[0]
                    args = json.loads(tool_call.function.arguments)
                    content = args.get("page_content", "")
                    page_elapsed = time.perf_counter() - page_start
                    logger.debug(
                        f"Processed page {page_num} in {page_elapsed:.2f} seconds."
                    )
                    return {"page": str(page_num), "content": content}
                else:
                    logger.warning(
                        f"No valid tool call in response for page {page_num}; the document might be missing text."
                    )
                    return {"page": str(page_num), "content": ""}
            else:
                response = await self.llm_provider.aget_completion(
                    messages=messages,
                    generation_config=generation_config,
                    apply_timeout=True,
                )

                if response.choices and response.choices[0].message:
                    content = response.choices[0].message.content
                    page_elapsed = time.perf_counter() - page_start
                    logger.debug(
                        f"Processed page {page_num} in {page_elapsed:.2f} seconds."
                    )
                    return {"page": str(page_num), "content": content}
                else:
                    msg = f"No response content for page {page_num}"
                    logger.error(msg)
                    return {"page": str(page_num), "content": ""}
        except Exception as e:
            logger.error(
                f"Error processing page {page_num} with vision model: {str(e)}"
            )
            # Return empty content rather than raising to avoid failing the entire batch
            return {
                "page": str(page_num),
                "content": f"Error processing page: {str(e)}",
            }

    async def process_and_yield(self, image, page_num: int):
        """Process a page and return the result in the shape yielded by ingest()."""
        async with self.semaphore:
            result = await self.process_page(image, page_num)
            return {
                "content": result.get("content", "") or "",
                "page_number": page_num,
            }

    async def ingest(
        self, data: str | bytes, **kwargs
    ) -> AsyncGenerator[dict[str, str | int], None]:
        """Process PDF as images using pdf2image."""
        ingest_start = time.perf_counter()
        logger.info("Starting PDF ingestion using VLMPDFParser.")

        if not self.vision_prompt_text:
            self.vision_prompt_text = (
                await self.database_provider.prompts_handler.get_cached_prompt(
                    prompt_name="vision_pdf"
                )
            )
            logger.info("Retrieved vision prompt text from database.")
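
        # Recreate the semaphore for this run so at most
        # max_concurrent_vlm_tasks pages are processed at once.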
        self.semaphore = asyncio.Semaphore(self.max_concurrent_vlm_tasks)

        try:
            if isinstance(data, str):
                pdf_info = pdf2image.pdfinfo_from_path(data)
            else:
                pdf_bytes = BytesIO(data)
                pdf_info = pdf2image.pdfinfo_from_bytes(pdf_bytes.getvalue())

            max_pages = pdf_info["Pages"]
            logger.info(f"PDF has {max_pages} pages to process")

            # Create a task queue to process pages in order
            pending_tasks = []
            completed_tasks = []
            next_page_to_yield = 1
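            # Pages may finish out of order; completed results are buffered in
            # completed_tasks and yielded strictly in ascending page order.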

            # Process pages with a sliding window, in batches
            for batch_start in range(1, max_pages + 1, self.vlm_batch_size):
                batch_end = min(
                    batch_start + self.vlm_batch_size - 1, max_pages
                )
                logger.debug(
                    f"Preparing batch of pages {batch_start}-{batch_end}/{max_pages}"
                )

                # Convert the batch of pages to images
                if isinstance(data, str):
                    images = pdf2image.convert_from_path(
                        data,
                        dpi=150,
                        first_page=batch_start,
                        last_page=batch_end,
                    )
                else:
                    pdf_bytes = BytesIO(data)
                    images = pdf2image.convert_from_bytes(
                        pdf_bytes.getvalue(),
                        dpi=150,
                        first_page=batch_start,
                        last_page=batch_end,
                    )

                # Create tasks for each page in the batch
                for i, image in enumerate(images):
                    page_num = batch_start + i
                    task = asyncio.create_task(
                        self.process_and_yield(image, page_num)
                    )
                    task.page_num = page_num  # Store page number for sorting
                    pending_tasks.append(task)

                # Check if any tasks have completed and yield them in order
                while pending_tasks:
                    # Poll briefly for completed tasks without blocking
                    done_tasks, pending_tasks_set = await asyncio.wait(
                        pending_tasks,
                        timeout=0.01,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    if not done_tasks:
                        # Nothing finished within the poll window; move on and
                        # schedule the next batch instead of blocking here.
                        break

                    # Add completed tasks to our completed list
                    pending_tasks = list(pending_tasks_set)
                    completed_tasks.extend(iter(done_tasks))

                    # Sort completed tasks by page number
                    completed_tasks.sort(key=lambda t: t.page_num)

                    # Yield results in order
                    while (
                        completed_tasks
                        and completed_tasks[0].page_num == next_page_to_yield
                    ):
                        task = completed_tasks.pop(0)
                        yield await task
                        next_page_to_yield += 1

            # Wait for and yield any remaining tasks in order
            while pending_tasks:
                done_tasks, _ = await asyncio.wait(pending_tasks)
                completed_tasks.extend(done_tasks)
                pending_tasks = []

                # Sort and yield remaining completed tasks
                completed_tasks.sort(key=lambda t: t.page_num)

                # Yield results in order
                while (
                    completed_tasks
                    and completed_tasks[0].page_num == next_page_to_yield
                ):
                    task = completed_tasks.pop(0)
                    yield await task
                    next_page_to_yield += 1

            total_elapsed = time.perf_counter() - ingest_start
            logger.info(
                f"Completed PDF conversion in {total_elapsed:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            raise


class BasicPDFParser(AsyncParser[str | bytes]):
    """A parser for PDF data."""

    def __init__(
        self,
        config: IngestionConfig,
        database_provider: DatabaseProvider,
        llm_provider: CompletionProvider,
    ):
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.config = config
        self.PdfReader = PdfReader

    async def ingest(
        self, data: str | bytes, **kwargs
    ) -> AsyncGenerator[str, None]:
        """Ingest PDF data and yield text from each page."""
        if isinstance(data, str):
            raise ValueError("PDF data must be in bytes format.")
        pdf = self.PdfReader(BytesIO(data))
        for page in pdf.pages:
            page_text = page.extract_text()
            if page_text is not None:
                # Keep characters from common scripts and printable ASCII;
                # filter out other non-printable characters.
                page_text = "".join(
                    filter(
                        lambda x: (
                            unicodedata.category(x)
                            in [
                                "Ll",
                                "Lu",
                                "Lt",
                                "Lm",
                                "Lo",
                                "Nl",
                                "No",
                            ]  # Letters and numbers
                            or "\u4e00" <= x <= "\u9fff"  # Chinese characters
                            or "\u0600" <= x <= "\u06ff"  # Arabic characters
                            or "\u0400" <= x <= "\u04ff"  # Cyrillic letters
                            or "\u0370" <= x <= "\u03ff"  # Greek letters
                            or "\u0e00" <= x <= "\u0e7f"  # Thai
                            or "\u3040" <= x <= "\u309f"  # Japanese Hiragana
                            or "\u30a0" <= x <= "\u30ff"  # Katakana
                            or "\uff00" <= x <= "\uffef"  # Halfwidth and Fullwidth Forms
                            or x in string.printable
                        ),
                        page_text,
                    )
                )
                yield page_text


class PDFParserUnstructured(AsyncParser[str | bytes]):
    def __init__(
        self,
        config: IngestionConfig,
        database_provider: DatabaseProvider,
        llm_provider: CompletionProvider,
        ocr_provider: OCRProvider,
    ):
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.config = config
        try:
            from unstructured.partition.pdf import partition_pdf

            self.partition_pdf = partition_pdf
        except ImportError as e:
            logger.error(f"PDFParserUnstructured ImportError: {e}")
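            # If the import fails, self.partition_pdf is never set and
            # ingest() will raise an AttributeError when called.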

    async def ingest(
        self,
        data: str | bytes,
        partition_strategy: str = "hi_res",
        chunking_strategy="by_title",
    ) -> AsyncGenerator[str, None]:
        # partition the pdf
        elements = self.partition_pdf(
            file=BytesIO(data),
            partition_strategy=partition_strategy,
            chunking_strategy=chunking_strategy,
        )
        for element in elements:
            yield element.text
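

# Illustrative usage sketch (comments only; not part of the module). The
# provider objects are hypothetical stand-ins for whatever the ingestion
# pipeline normally wires in:
#
#     parser = VLMPDFParser(config, database_provider, llm_provider, ocr_provider)
#     async for page in parser.ingest(pdf_bytes):
#         print(page["page_number"], page["content"][:80])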