mistral.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import logging
  2. import os
  3. from typing import Any
  4. from mistralai import Mistral
  5. from mistralai.models import OCRResponse
  6. from core.base.providers.ocr import OCRConfig, OCRProvider
  7. logger = logging.getLogger()
  8. class MistralOCRProvider(OCRProvider):
  9. def __init__(self, config: OCRConfig) -> None:
  10. if not isinstance(config, OCRConfig):
  11. raise ValueError(
  12. f"MistralOCRProvider must be initialized with a OCRConfig. Got: {config} with type {type(config)}"
  13. )
  14. super().__init__(config)
  15. self.config: OCRConfig = config
  16. api_key = os.environ.get("MISTRAL_API_KEY")
  17. if not api_key:
  18. logger.warning(
  19. "MISTRAL_API_KEY not set in environment, if you plan to use Mistral OCR, please set it."
  20. )
  21. self.mistral = Mistral(api_key=api_key)
  22. self.model = config.model or "mistral-ocr-latest"
  23. async def _execute_task(self, task: dict[str, Any]) -> OCRResponse:
  24. """Execute OCR task asynchronously."""
  25. document = task.get("document")
  26. include_image_base64 = task.get("include_image_base64", False)
  27. # Process through Mistral OCR API
  28. return await self.mistral.ocr.process_async(
  29. model=self.model,
  30. document=document, # type: ignore
  31. include_image_base64=include_image_base64,
  32. )
  33. def _execute_task_sync(self, task: dict[str, Any]) -> OCRResponse:
  34. """Execute OCR task synchronously."""
  35. document = task.get("document")
  36. include_image_base64 = task.get("include_image_base64", False)
  37. # Process through Mistral OCR API
  38. return self.mistral.ocr.process( # type: ignore
  39. model=self.model,
  40. document=document, # type: ignore
  41. include_image_base64=include_image_base64,
  42. )
  43. async def upload_file(
  44. self,
  45. file_path: str | None = None,
  46. file_content: bytes | None = None,
  47. file_name: str | None = None,
  48. ) -> Any:
  49. """
  50. Upload a file for OCR processing.
  51. Args:
  52. file_path: Path to the file to upload
  53. file_content: Binary content of the file
  54. file_name: Name of the file (required if file_content is provided)
  55. Returns:
  56. The uploaded file object
  57. """
  58. if file_path:
  59. file_name = os.path.basename(file_path)
  60. with open(file_path, "rb") as f:
  61. file_content = f.read()
  62. elif not file_content or not file_name:
  63. raise ValueError(
  64. "Either file_path or (file_content and file_name) must be provided"
  65. )
  66. return await self.mistral.files.upload_async(
  67. file={
  68. "file_name": file_name,
  69. "content": file_content,
  70. },
  71. purpose="ocr",
  72. )
  73. async def process_file(
  74. self, file_id: str, include_image_base64: bool = False
  75. ) -> OCRResponse:
  76. """
  77. Process a previously uploaded file using its file ID.
  78. Args:
  79. file_id: ID of the file to process
  80. include_image_base64: Whether to include image base64 in the response
  81. Returns:
  82. OCR response object
  83. """
  84. # Get the signed URL for the file
  85. signed_url = await self.mistral.files.get_signed_url_async(
  86. file_id=file_id
  87. )
  88. # Create the document data
  89. document = {
  90. "type": "document_url",
  91. "document_url": signed_url.url,
  92. }
  93. # Process the document
  94. task = {
  95. "document": document,
  96. "include_image_base64": include_image_base64,
  97. }
  98. return await self._execute_with_backoff_async(task)
  99. async def process_url(
  100. self,
  101. url: str,
  102. is_image: bool = False,
  103. include_image_base64: bool = False,
  104. ) -> OCRResponse:
  105. """
  106. Process a document or image from a URL.
  107. Args:
  108. url: URL of the document or image
  109. is_image: Whether the URL points to an image
  110. include_image_base64: Whether to include image base64 in the response
  111. Returns:
  112. OCR response object
  113. """
  114. # Create the document data
  115. document_type = "image_url" if is_image else "document_url"
  116. document = {
  117. "type": document_type,
  118. document_type: url,
  119. }
  120. # Process the document
  121. task = {
  122. "document": document,
  123. "include_image_base64": include_image_base64,
  124. }
  125. return await self._execute_with_backoff_async(task)
  126. async def process_pdf(
  127. self, file_path: str | None = None, file_content: bytes | None = None
  128. ) -> OCRResponse:
  129. """
  130. Upload and process a PDF file in one step.
  131. Args:
  132. file_path: Path to the PDF file
  133. file_content: Binary content of the PDF file
  134. Returns:
  135. OCR response object
  136. """
  137. # Upload the file
  138. if file_path:
  139. file_name = os.path.basename(file_path)
  140. with open(file_path, "rb") as f:
  141. file_content = f.read()
  142. elif not file_content:
  143. raise ValueError(
  144. "Either file_path or file_content must be provided"
  145. )
  146. file_name = file_name if file_path else "document.pdf"
  147. uploaded_file = await self.upload_file(
  148. file_name=file_name, file_content=file_content
  149. )
  150. # Process the uploaded file
  151. return await self.process_file(uploaded_file.id)