ocr.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import asyncio
  2. import logging
  3. import random
  4. import time
  5. from abc import abstractmethod
  6. from concurrent.futures import ThreadPoolExecutor
  7. from typing import Any, Optional
  8. from litellm import AuthenticationError
  9. from .base import Provider, ProviderConfig
  10. logger = logging.getLogger()
  11. class OCRConfig(ProviderConfig):
  12. provider: Optional[str] = None
  13. model: Optional[str] = None
  14. concurrent_request_limit: int = 256
  15. max_retries: int = 3
  16. initial_backoff: float = 1.0
  17. max_backoff: float = 64.0
  18. def validate_config(self) -> None:
  19. if not self.provider:
  20. raise ValueError("Provider must be set.")
  21. if self.provider not in self.supported_providers:
  22. raise ValueError(f"Provider '{self.provider}' is not supported.")
  23. @property
  24. def supported_providers(self) -> list[str]:
  25. return ["mistral"]
  26. class OCRProvider(Provider):
  27. def __init__(self, config: OCRConfig) -> None:
  28. if not isinstance(config, OCRConfig):
  29. raise ValueError(
  30. "OCRProvider must be initialized with a `OCRConfig`."
  31. )
  32. logger.info(f"Initializing OCRProvider with config: {config}")
  33. super().__init__(config)
  34. self.config: OCRConfig = config
  35. self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
  36. self.thread_pool = ThreadPoolExecutor(
  37. max_workers=config.concurrent_request_limit
  38. )
  39. async def _execute_with_backoff_async(self, task: dict[str, Any]):
  40. retries = 0
  41. backoff = self.config.initial_backoff
  42. while retries < self.config.max_retries:
  43. try:
  44. async with self.semaphore:
  45. return await self._execute_task(task)
  46. except AuthenticationError:
  47. raise
  48. except Exception as e:
  49. logger.warning(
  50. f"Request failed (attempt {retries + 1}): {str(e)}"
  51. )
  52. retries += 1
  53. if retries == self.config.max_retries:
  54. raise
  55. await asyncio.sleep(random.uniform(0, backoff))
  56. backoff = min(backoff * 2, self.config.max_backoff)
  57. def _execute_with_backoff_sync(self, task: dict[str, Any]):
  58. retries = 0
  59. backoff = self.config.initial_backoff
  60. while retries < self.config.max_retries:
  61. try:
  62. return self._execute_task_sync(task)
  63. except Exception as e:
  64. logger.warning(
  65. f"Request failed (attempt {retries + 1}): {str(e)}"
  66. )
  67. retries += 1
  68. if retries == self.config.max_retries:
  69. raise
  70. time.sleep(random.uniform(0, backoff))
  71. backoff = min(backoff * 2, self.config.max_backoff)
  72. @abstractmethod
  73. async def _execute_task(self, task: dict[str, Any]):
  74. pass
  75. @abstractmethod
  76. def _execute_task_sync(self, task: dict[str, Any]):
  77. pass
  78. @abstractmethod
  79. async def upload_file(
  80. self,
  81. file_path: str | None = None,
  82. file_content: bytes | None = None,
  83. file_name: str | None = None,
  84. ) -> Any:
  85. pass
  86. @abstractmethod
  87. async def process_file(
  88. self, file_id: str, include_image_base64: bool = False
  89. ) -> Any:
  90. pass
  91. @abstractmethod
  92. async def process_url(
  93. self,
  94. url: str,
  95. is_image: bool = False,
  96. include_image_base64: bool = False,
  97. ) -> Any:
  98. pass
  99. @abstractmethod
  100. async def process_pdf(
  101. self, file_path: str | None = None, file_content: bytes | None = None
  102. ) -> Any:
  103. pass