tiff_parser.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # type: ignore
  2. import base64
  3. import logging
  4. from io import BytesIO
  5. from typing import AsyncGenerator
  6. from core.base.abstractions import GenerationConfig
  7. from core.base.parsers.base_parser import AsyncParser
  8. from core.base.providers import (
  9. CompletionProvider,
  10. DatabaseProvider,
  11. IngestionConfig,
  12. )
  13. logger = logging.getLogger(__name__)
  14. class TIFFParser(AsyncParser[str | bytes]):
  15. """Parser for TIFF image files."""
  16. def __init__(
  17. self,
  18. config: IngestionConfig,
  19. database_provider: DatabaseProvider,
  20. llm_provider: CompletionProvider,
  21. ):
  22. self.database_provider = database_provider
  23. self.llm_provider = llm_provider
  24. self.config = config
  25. self.vision_prompt_text = None
  26. try:
  27. from litellm import supports_vision
  28. from PIL import Image
  29. self.supports_vision = supports_vision
  30. self.Image = Image
  31. except ImportError:
  32. raise ImportError("Required packages not available.")
  33. async def _convert_tiff_to_jpeg(self, data: bytes) -> bytes:
  34. """Convert TIFF image to JPEG format."""
  35. try:
  36. # Open TIFF image
  37. with BytesIO(data) as input_buffer:
  38. tiff_image = self.Image.open(input_buffer)
  39. # Convert to RGB if needed
  40. if tiff_image.mode not in ("RGB", "L"):
  41. tiff_image = tiff_image.convert("RGB")
  42. # Save as JPEG
  43. output_buffer = BytesIO()
  44. tiff_image.save(output_buffer, format="JPEG", quality=95)
  45. return output_buffer.getvalue()
  46. except Exception as e:
  47. raise ValueError(f"Error converting TIFF to JPEG: {str(e)}")
  48. async def ingest(
  49. self, data: str | bytes, **kwargs
  50. ) -> AsyncGenerator[str, None]:
  51. if not self.vision_prompt_text:
  52. self.vision_prompt_text = (
  53. await self.database_provider.prompts_handler.get_cached_prompt(
  54. prompt_name=self.config.vision_img_prompt_name
  55. )
  56. )
  57. try:
  58. if not self.supports_vision(model=self.config.vision_img_model):
  59. raise ValueError(
  60. f"Model {self.config.vision_img_model} does not support vision"
  61. )
  62. # Convert TIFF to JPEG
  63. if isinstance(data, bytes):
  64. jpeg_data = await self._convert_tiff_to_jpeg(data)
  65. image_data = base64.b64encode(jpeg_data).decode("utf-8")
  66. else:
  67. image_data = data
  68. # Use vision model to analyze image
  69. generation_config = GenerationConfig(
  70. model=self.config.vision_img_model,
  71. stream=False,
  72. )
  73. messages = [
  74. {
  75. "role": "user",
  76. "content": [
  77. {"type": "text", "text": self.vision_prompt_text},
  78. {
  79. "type": "image_url",
  80. "image_url": {
  81. "url": f"data:image/jpeg;base64,{image_data}"
  82. },
  83. },
  84. ],
  85. }
  86. ]
  87. response = await self.llm_provider.aget_completion(
  88. messages=messages, generation_config=generation_config
  89. )
  90. if response.choices and response.choices[0].message:
  91. content = response.choices[0].message.content
  92. if not content:
  93. raise ValueError("No content in response")
  94. yield content
  95. else:
  96. raise ValueError("No response content")
  97. except Exception as e:
  98. raise ValueError(f"Error processing TIFF file: {str(e)}")