# img_parser.py 11 KB


# type: ignore
import asyncio
import base64
import logging
from io import BytesIO
from typing import Any, AsyncGenerator, Optional

import filetype
import pillow_heif
from PIL import Image

from core.base.abstractions import GenerationConfig
from core.base.parsers.base_parser import AsyncParser
from core.base.providers import (
    CompletionProvider,
    DatabaseProvider,
    IngestionConfig,
)
  16. logger = logging.getLogger()
  17. class ImageParser(AsyncParser[str | bytes]):
  18. # Mapping of file extensions to MIME types
  19. MIME_TYPE_MAPPING = {
  20. "bmp": "image/bmp",
  21. "gif": "image/gif",
  22. "heic": "image/heic",
  23. "jpeg": "image/jpeg",
  24. "jpg": "image/jpeg",
  25. "png": "image/png",
  26. "tiff": "image/tiff",
  27. "tif": "image/tiff",
  28. "webp": "image/webp",
  29. }
  30. def __init__(
  31. self,
  32. config: IngestionConfig,
  33. database_provider: DatabaseProvider,
  34. llm_provider: CompletionProvider,
  35. ):
  36. self.database_provider = database_provider
  37. self.llm_provider = llm_provider
  38. self.config = config
  39. self.vision_prompt_text = None
  40. self.Image = Image
  41. self.pillow_heif = pillow_heif
  42. self.pillow_heif.register_heif_opener()
  43. def _is_heic(self, data: bytes) -> bool:
  44. """Detect HEIC format using magic numbers and patterns."""
  45. heic_patterns = [
  46. b"ftyp",
  47. b"heic",
  48. b"heix",
  49. b"hevc",
  50. b"HEIC",
  51. b"mif1",
  52. b"msf1",
  53. b"hevc",
  54. b"hevx",
  55. ]
  56. try:
  57. header = data[:32] # Get first 32 bytes
  58. return any(pattern in header for pattern in heic_patterns)
  59. except Exception as e:
  60. logger.error(f"Error checking for HEIC format: {str(e)}")
  61. return False
  62. async def _convert_heic_to_jpeg(self, data: bytes) -> bytes:
  63. """Convert HEIC image to JPEG format."""
  64. try:
  65. # Create BytesIO object for input
  66. input_buffer = BytesIO(data)
  67. # Load HEIC image using pillow_heif
  68. heif_file = self.pillow_heif.read_heif(input_buffer)
  69. # Get the primary image - API changed, need to get first image
  70. heif_image = heif_file[0] # Get first image in the container
  71. # Convert to PIL Image directly from the HEIF image
  72. pil_image = heif_image.to_pillow()
  73. # Convert to RGB if needed
  74. if pil_image.mode != "RGB":
  75. pil_image = pil_image.convert("RGB")
  76. # Save as JPEG
  77. output_buffer = BytesIO()
  78. pil_image.save(output_buffer, format="JPEG", quality=95)
  79. return output_buffer.getvalue()
  80. except Exception as e:
  81. logger.error(f"Error converting HEIC to JPEG: {str(e)}")
  82. raise
  83. async def _convert_tiff_to_jpeg(self, data: bytes) -> bytes:
  84. """Convert TIFF image to JPEG format."""
  85. try:
  86. # Open TIFF image
  87. with BytesIO(data) as input_buffer:
  88. tiff_image = self.Image.open(input_buffer)
  89. # Convert to RGB if needed
  90. if tiff_image.mode not in ("RGB", "L"):
  91. tiff_image = tiff_image.convert("RGB")
  92. # Save as JPEG
  93. output_buffer = BytesIO()
  94. tiff_image.save(output_buffer, format="JPEG", quality=95)
  95. return output_buffer.getvalue()
  96. except Exception as e:
  97. raise ValueError(f"Error converting TIFF to JPEG: {str(e)}") from e
  98. def _is_jpeg(self, data: bytes) -> bool:
  99. """Detect JPEG format using magic numbers."""
  100. return len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8
  101. def _is_png(self, data: bytes) -> bool:
  102. """Detect PNG format using magic numbers."""
  103. png_signature = b"\x89PNG\r\n\x1a\n"
  104. return data.startswith(png_signature)
  105. def _is_bmp(self, data: bytes) -> bool:
  106. """Detect BMP format using magic numbers."""
  107. return data.startswith(b"BM")
  108. def _is_tiff(self, data: bytes) -> bool:
  109. """Detect TIFF format using magic numbers."""
  110. return (
  111. data.startswith(b"II*\x00") # Little-endian
  112. or data.startswith(b"MM\x00*")
  113. ) # Big-endian
  114. def _get_image_media_type(
  115. self, data: bytes, filename: Optional[str] = None
  116. ) -> str:
  117. """
  118. Determine the correct media type based on image data and/or filename.
  119. Args:
  120. data: The binary image data
  121. filename: Optional filename which may contain extension information
  122. Returns:
  123. str: The MIME type for the image
  124. """
  125. try:
  126. # First, try format-specific detection functions
  127. if self._is_heic(data):
  128. return "image/heic"
  129. if self._is_jpeg(data):
  130. return "image/jpeg"
  131. if self._is_png(data):
  132. return "image/png"
  133. if self._is_bmp(data):
  134. return "image/bmp"
  135. if self._is_tiff(data):
  136. return "image/tiff"
  137. # Try using filetype as a fallback
  138. if img_type := filetype.guess(data):
  139. # Map the detected type to a MIME type
  140. return self.MIME_TYPE_MAPPING.get(
  141. img_type, f"image/{img_type}"
  142. )
  143. # If we have a filename, try to get the type from the extension
  144. if filename:
  145. extension = filename.split(".")[-1].lower()
  146. if extension in self.MIME_TYPE_MAPPING:
  147. return self.MIME_TYPE_MAPPING[extension]
  148. # If all else fails, default to octet-stream (generic binary)
  149. logger.warning(
  150. "Could not determine image type, using application/octet-stream"
  151. )
  152. return "application/octet-stream"
  153. except Exception as e:
  154. logger.error(f"Error determining image media type: {str(e)}")
  155. return "application/octet-stream" # Default to generic binary as fallback
  156. async def ingest(
  157. self,
  158. data: str | bytes,
  159. prompt_text: str = None,
  160. prompt_name: str = None,
  161. prompt_args: dict = None,
  162. **kwargs,
  163. ) -> AsyncGenerator[str, None]:
  164. # prompt_text > prompt_name > self.vision_prompt_text
  165. if not prompt_text and not prompt_name:
  166. if not self.vision_prompt_text:
  167. prompt = await self.database_provider.prompts_handler.get_cached_prompt(
  168. prompt_name="vision_img"
  169. )
  170. self.vision_prompt_text = prompt
  171. prompt_text = self.vision_prompt_text
  172. elif not prompt_text and prompt_name:
  173. prompt = (
  174. await self.database_provider.prompts_handler.get_cached_prompt(
  175. prompt_name=prompt_name,
  176. inputs=prompt_args,
  177. )
  178. )
  179. prompt_text = prompt
  180. try:
  181. filename = kwargs.get("filename", None)
  182. # Whether to convert HEIC to JPEG (default: True for backward compatibility)
  183. convert_heic = kwargs.get("convert_heic", True)
  184. if isinstance(data, bytes):
  185. try:
  186. # First detect the original media type
  187. original_media_type = self._get_image_media_type(
  188. data, filename
  189. )
  190. logger.debug(
  191. f"Detected original image type: {original_media_type}"
  192. )
  193. # Determine if we need to convert HEIC
  194. is_heic_format = self._is_heic(data)
  195. is_tiff_format = self._is_tiff(data)
  196. # Handle HEIC images
  197. if is_heic_format and convert_heic:
  198. logger.debug(
  199. "Detected HEIC format, converting to JPEG"
  200. )
  201. data = await self._convert_heic_to_jpeg(data)
  202. media_type = "image/jpeg"
  203. elif is_tiff_format:
  204. logger.debug(
  205. "Detected TIFF format, converting to JPEG"
  206. )
  207. data = await self._convert_tiff_to_jpeg(data)
  208. media_type = "image/jpeg"
  209. else:
  210. # Keep original format and media type
  211. media_type = original_media_type
  212. # Encode the data to base64
  213. image_data = base64.b64encode(data).decode("utf-8")
  214. except Exception as e:
  215. logger.error(f"Error processing image data: {str(e)}")
  216. raise
  217. else:
  218. # If data is already a string (base64), we assume it has a reliable content type
  219. # from the source that encoded it
  220. image_data = data
  221. # Try to determine the media type from the context if available
  222. media_type = kwargs.get(
  223. "media_type", "application/octet-stream"
  224. )
  225. # Get the model from kwargs or config
  226. model = kwargs.get("vlm", None) or self.config.app.vlm
  227. generation_config = GenerationConfig(
  228. model=model,
  229. stream=False,
  230. )
  231. logger.debug(f"Using model: {model}, media_type: {media_type}")
  232. if "anthropic" in model:
  233. messages = [
  234. {
  235. "role": "user",
  236. "content": [
  237. {"type": "text", "text": prompt_text},
  238. {
  239. "type": "image",
  240. "source": {
  241. "type": "base64",
  242. "media_type": media_type,
  243. "data": image_data,
  244. },
  245. },
  246. ],
  247. }
  248. ]
  249. else:
  250. # For OpenAI-style APIs, use their format
  251. messages = [
  252. {
  253. "role": "user",
  254. "content": [
  255. {"type": "text", "text": prompt_text},
  256. {
  257. "type": "image_url",
  258. "image_url": {
  259. "url": f"data:{media_type};base64,{image_data}"
  260. },
  261. },
  262. ],
  263. }
  264. ]
  265. response = await self.llm_provider.aget_completion(
  266. messages=messages, generation_config=generation_config
  267. )
  268. if not response.choices or not response.choices[0].message:
  269. raise ValueError("No response content")
  270. if content := response.choices[0].message.content:
  271. yield content
  272. else:
  273. raise ValueError("No content in response")
  274. except Exception as e:
  275. logger.error(f"Error processing image with vision model: {str(e)}")
  276. raise