# img_parser.py
  1. import base64
  2. import logging
  3. from typing import AsyncGenerator
  4. from core.base.abstractions import GenerationConfig
  5. from core.base.parsers.base_parser import AsyncParser
  6. from core.base.providers import (
  7. CompletionProvider,
  8. DatabaseProvider,
  9. IngestionConfig,
  10. )
  11. logger = logging.getLogger()
  12. class ImageParser(AsyncParser[str | bytes]):
  13. """A parser for image data using vision models."""
  14. def __init__(
  15. self,
  16. config: IngestionConfig,
  17. database_provider: DatabaseProvider,
  18. llm_provider: CompletionProvider,
  19. ):
  20. self.database_provider = database_provider
  21. self.llm_provider = llm_provider
  22. self.config = config
  23. self.vision_prompt_text = None
  24. try:
  25. from litellm import supports_vision
  26. self.supports_vision = supports_vision
  27. except ImportError:
  28. logger.error("Failed to import LiteLLM vision support")
  29. raise ImportError(
  30. "Please install the `litellm` package to use the ImageParser."
  31. )
  32. async def ingest( # type: ignore
  33. self, data: str | bytes, **kwargs
  34. ) -> AsyncGenerator[str, None]:
  35. """
  36. Ingest image data and yield a description using vision model.
  37. Args:
  38. data: Image data (bytes or base64 string)
  39. *args, **kwargs: Additional arguments passed to the completion call
  40. Yields:
  41. Chunks of image description text
  42. """
  43. if not self.vision_prompt_text:
  44. self.vision_prompt_text = await self.database_provider.prompts_handler.get_cached_prompt( # type: ignore
  45. prompt_name=self.config.vision_img_prompt_name
  46. )
  47. try:
  48. # Verify model supports vision
  49. if not self.supports_vision(model=self.config.vision_img_model):
  50. raise ValueError(
  51. f"Model {self.config.vision_img_model} does not support vision"
  52. )
  53. # Encode image data if needed
  54. if isinstance(data, bytes):
  55. image_data = base64.b64encode(data).decode("utf-8")
  56. else:
  57. image_data = data
  58. # Configure the generation parameters
  59. generation_config = GenerationConfig(
  60. model=self.config.vision_img_model,
  61. stream=False,
  62. )
  63. # Prepare message with image
  64. messages = [
  65. {
  66. "role": "user",
  67. "content": [
  68. {"type": "text", "text": self.vision_prompt_text},
  69. {
  70. "type": "image_url",
  71. "image_url": {
  72. "url": f"data:image/jpeg;base64,{image_data}"
  73. },
  74. },
  75. ],
  76. }
  77. ]
  78. # Get completion from LiteLLM provider
  79. response = await self.llm_provider.aget_completion(
  80. messages=messages, generation_config=generation_config
  81. )
  82. # Extract description from response
  83. if response.choices and response.choices[0].message:
  84. content = response.choices[0].message.content
  85. if not content:
  86. raise ValueError("No content in response")
  87. yield content
  88. else:
  89. raise ValueError("No response content")
  90. except Exception as e:
  91. logger.error(f"Error processing image with vision model: {str(e)}")
  92. raise