img_parser.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # type: ignore
  2. import base64
  3. import logging
  4. from io import BytesIO
  5. from typing import AsyncGenerator
  6. from core.base.abstractions import GenerationConfig
  7. from core.base.parsers.base_parser import AsyncParser
  8. from core.base.providers import (
  9. CompletionProvider,
  10. DatabaseProvider,
  11. IngestionConfig,
  12. )
# Module-wide logger. getLogger() is called without a name, so this is the
# root logger rather than a logger scoped to this module.
logger = logging.getLogger()
  14. class ImageParser(AsyncParser[str | bytes]):
  15. def __init__(
  16. self,
  17. config: IngestionConfig,
  18. database_provider: DatabaseProvider,
  19. llm_provider: CompletionProvider,
  20. ):
  21. self.database_provider = database_provider
  22. self.llm_provider = llm_provider
  23. self.config = config
  24. self.vision_prompt_text = None
  25. try:
  26. import pillow_heif # for HEIC support
  27. from litellm import supports_vision
  28. from PIL import Image
  29. self.supports_vision = supports_vision
  30. self.Image = Image
  31. self.pillow_heif = pillow_heif
  32. self.pillow_heif.register_heif_opener()
  33. except ImportError as e:
  34. logger.error(f"Failed to import required packages: {str(e)}")
  35. raise ImportError(
  36. "Please install the required packages: litellm, Pillow, pillow-heif"
  37. )
  38. def _is_heic(self, data: bytes) -> bool:
  39. """More robust HEIC detection using magic numbers and patterns."""
  40. heic_patterns = [
  41. b"ftyp",
  42. b"heic",
  43. b"heix",
  44. b"hevc",
  45. b"HEIC",
  46. b"mif1",
  47. b"msf1",
  48. b"hevc",
  49. b"hevx",
  50. ]
  51. # Check for HEIC file signature
  52. try:
  53. header = data[:32] # Get first 32 bytes
  54. return any(pattern in header for pattern in heic_patterns)
  55. except:
  56. return False
  57. async def _convert_heic_to_jpeg(self, data: bytes) -> bytes:
  58. """Convert HEIC image to JPEG format."""
  59. try:
  60. # Create BytesIO object for input
  61. input_buffer = BytesIO(data)
  62. # Load HEIC image using pillow_heif
  63. heif_file = self.pillow_heif.read_heif(input_buffer)
  64. # Get the primary image - API changed, need to get first image
  65. heif_image = heif_file[0] # Get first image in the container
  66. # Convert to PIL Image directly from the HEIF image
  67. pil_image = heif_image.to_pillow()
  68. # Convert to RGB if needed
  69. if pil_image.mode != "RGB":
  70. pil_image = pil_image.convert("RGB")
  71. # Save as JPEG
  72. output_buffer = BytesIO()
  73. pil_image.save(output_buffer, format="JPEG", quality=95)
  74. return output_buffer.getvalue()
  75. except Exception as e:
  76. logger.error(f"Error converting HEIC to JPEG: {str(e)}")
  77. raise
  78. async def ingest(
  79. self, data: str | bytes, **kwargs
  80. ) -> AsyncGenerator[str, None]:
  81. if not self.vision_prompt_text:
  82. self.vision_prompt_text = (
  83. await self.database_provider.prompts_handler.get_cached_prompt(
  84. prompt_name=self.config.vision_img_prompt_name
  85. )
  86. )
  87. try:
  88. if not self.supports_vision(model=self.config.vision_img_model):
  89. raise ValueError(
  90. f"Model {self.config.vision_img_model} does not support vision"
  91. )
  92. if isinstance(data, bytes):
  93. try:
  94. # Check if it's HEIC and convert if necessary
  95. if self._is_heic(data):
  96. logger.debug(
  97. "Detected HEIC format, converting to JPEG"
  98. )
  99. data = await self._convert_heic_to_jpeg(data)
  100. image_data = base64.b64encode(data).decode("utf-8")
  101. except Exception as e:
  102. logger.error(f"Error processing image data: {str(e)}")
  103. raise
  104. else:
  105. image_data = data
  106. generation_config = GenerationConfig(
  107. model=self.config.vision_img_model,
  108. stream=False,
  109. )
  110. messages = [
  111. {
  112. "role": "user",
  113. "content": [
  114. {"type": "text", "text": self.vision_prompt_text},
  115. {
  116. "type": "image_url",
  117. "image_url": {
  118. "url": f"data:image/jpeg;base64,{image_data}"
  119. },
  120. },
  121. ],
  122. }
  123. ]
  124. response = await self.llm_provider.aget_completion(
  125. messages=messages, generation_config=generation_config
  126. )
  127. if response.choices and response.choices[0].message:
  128. content = response.choices[0].message.content
  129. if not content:
  130. raise ValueError("No content in response")
  131. yield content
  132. else:
  133. raise ValueError("No response content")
  134. except Exception as e:
  135. logger.error(f"Error processing image with vision model: {str(e)}")
  136. raise