# ingestion.py (6.0 KB, 183 lines)
  1. import logging
  2. from abc import ABC
  3. from enum import Enum
  4. from typing import TYPE_CHECKING, Any, ClassVar
  5. from pydantic import Field
  6. from core.base.abstractions import ChunkEnrichmentSettings
  7. from .base import AppConfig, Provider, ProviderConfig
  8. from .llm import CompletionProvider
  # Module-scoped logger (rather than the root logger) so log records carry
  # this module's name and can be filtered/levelled independently.
  logger = logging.getLogger(__name__)
  10. if TYPE_CHECKING:
  11. from core.database import PostgresDatabaseProvider
  class ChunkingStrategy(str, Enum):
      """Named text-chunking strategies for ``IngestionConfig.chunking_strategy``.

      Because the enum mixes in ``str``, each member compares equal to its
      raw string value (e.g. ``ChunkingStrategy.BASIC == "basic"``).
      """

      RECURSIVE = 'recursive'
      CHARACTER = 'character'
      BASIC = 'basic'
      BY_TITLE = 'by_title'
  class IngestionMode(str, Enum):
      """Ingestion presets.

      Members are ``str`` subclasses, so ``IngestionMode.hi_res == "hi-res"``;
      that is the value ``IngestionConfig.get_default`` checks for.
      """

      hi_res = 'hi-res'
      fast = 'fast'
      custom = 'custom'
  def _ingestion_default_factory(key: str):
      """Return a ``default_factory`` for the IngestionConfig field named ``key``.

      The returned callable looks ``key`` up in ``IngestionConfig._defaults``
      at instantiation time, so overrides made through
      ``IngestionConfig.set_default`` apply to instances created afterwards.
      The value is deep-copied so mutable defaults (lists, dicts, the
      ``ChunkEnrichmentSettings`` model) are never shared between instances
      or with the class-level defaults table.
      """
      import copy

      def factory():
          return copy.deepcopy(IngestionConfig._defaults[key])

      return factory


  class IngestionConfig(ProviderConfig):
      """Configuration for ingestion providers.

      Every field's default is read from the class-level ``_defaults`` table
      (changeable at runtime via :meth:`set_default`); each instance receives
      its own deep copy of those defaults.
      """

      _defaults: ClassVar[dict] = {
          # NOTE(review): "app" and "extra_fields" have no field declared in
          # this class; presumably they are consumed by ProviderConfig — confirm.
          "app": AppConfig(),
          "provider": "r2r",
          "excluded_parsers": ["mp4"],
          "chunking_strategy": "recursive",
          "chunk_enrichment_settings": ChunkEnrichmentSettings(),
          "extra_parsers": {},
          "audio_transcription_model": "openai/whisper-1",
          "vision_img_prompt_name": "vision_img",
          "vision_img_model": "openai/gpt-4o",
          "vision_pdf_prompt_name": "vision_pdf",
          "vision_pdf_model": "openai/gpt-4o",
          "skip_document_summary": False,
          "document_summary_system_prompt": "default_system",
          "document_summary_task_prompt": "default_summary",
          "chunks_for_document_summary": 128,
          "document_summary_model": "openai/gpt-4o-mini",
          "parser_overrides": {},
          "extra_fields": {},
      }

      provider: str = Field(
          default_factory=_ingestion_default_factory("provider")
      )
      excluded_parsers: list[str] = Field(
          default_factory=_ingestion_default_factory("excluded_parsers")
      )
      chunking_strategy: str | ChunkingStrategy = Field(
          default_factory=_ingestion_default_factory("chunking_strategy")
      )
      chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
          default_factory=_ingestion_default_factory("chunk_enrichment_settings")
      )
      extra_parsers: dict[str, Any] = Field(
          default_factory=_ingestion_default_factory("extra_parsers")
      )
      audio_transcription_model: str = Field(
          default_factory=_ingestion_default_factory("audio_transcription_model")
      )
      vision_img_prompt_name: str = Field(
          default_factory=_ingestion_default_factory("vision_img_prompt_name")
      )
      vision_img_model: str = Field(
          default_factory=_ingestion_default_factory("vision_img_model")
      )
      vision_pdf_prompt_name: str = Field(
          default_factory=_ingestion_default_factory("vision_pdf_prompt_name")
      )
      vision_pdf_model: str = Field(
          default_factory=_ingestion_default_factory("vision_pdf_model")
      )
      skip_document_summary: bool = Field(
          default_factory=_ingestion_default_factory("skip_document_summary")
      )
      document_summary_system_prompt: str = Field(
          default_factory=_ingestion_default_factory(
              "document_summary_system_prompt"
          )
      )
      document_summary_task_prompt: str = Field(
          default_factory=_ingestion_default_factory(
              "document_summary_task_prompt"
          )
      )
      chunks_for_document_summary: int = Field(
          default_factory=_ingestion_default_factory(
              "chunks_for_document_summary"
          )
      )
      document_summary_model: str = Field(
          default_factory=_ingestion_default_factory("document_summary_model")
      )
      parser_overrides: dict[str, str] = Field(
          default_factory=_ingestion_default_factory("parser_overrides")
      )

      @classmethod
      def set_default(cls, **kwargs):
          """Override entries in the class-level defaults table.

          Affects instances created after the call; existing instances keep
          their values.

          Raises:
              AttributeError: if any keyword is not a known default key. All
                  keys are checked before anything is written, so a failing
                  call leaves the defaults unchanged.
          """
          unknown = [key for key in kwargs if key not in cls._defaults]
          if unknown:
              raise AttributeError(
                  f"No default attribute '{unknown[0]}' in IngestionConfig"
              )
          cls._defaults.update(kwargs)

      @property
      def supported_providers(self) -> list[str]:
          """Provider names accepted by :meth:`validate_config`."""
          return ["r2r", "unstructured_local", "unstructured_api"]

      def validate_config(self) -> None:
          """Raise ``ValueError`` if ``provider`` is not a supported provider."""
          if self.provider not in self.supported_providers:
              raise ValueError(f"Provider {self.provider} is not supported.")

      @classmethod
      def get_default(cls, mode: str, app) -> "IngestionConfig":
          """Return default ingestion configuration for a given mode.

          ``"hi-res"`` (equal to ``IngestionMode.hi_res``) routes PDFs to the
          ``zerox`` parser; every other mode gets the plain defaults.
          """
          if mode == "hi-res":
              return cls(app=app, parser_overrides={"pdf": "zerox"})
          return cls(app=app)

      class Config:
          populate_by_name = True
          # Hand-maintained copy of the built-in defaults in ``_defaults``;
          # update both together (runtime ``set_default`` changes are not
          # reflected here).
          json_schema_extra = {
              "provider": "r2r",
              "excluded_parsers": ["mp4"],
              "chunking_strategy": "recursive",
              "chunk_enrichment_settings": ChunkEnrichmentSettings().dict(),
              "extra_parsers": {},
              "audio_transcription_model": "openai/whisper-1",
              "vision_img_prompt_name": "vision_img",
              "vision_img_model": "openai/gpt-4o",
              "vision_pdf_prompt_name": "vision_pdf",
              "vision_pdf_model": "openai/gpt-4o",
              "skip_document_summary": False,
              "document_summary_system_prompt": "default_system",
              "document_summary_task_prompt": "default_summary",
              "chunks_for_document_summary": 128,
              "document_summary_model": "openai/gpt-4o-mini",
              "parser_overrides": {},
          }
  class IngestionProvider(Provider, ABC):
      """Abstract base for ingestion providers.

      Keeps the ingestion configuration together with the database provider
      and completion (LLM) provider it was constructed with.
      """

      config: IngestionConfig
      database_provider: "PostgresDatabaseProvider"
      llm_provider: CompletionProvider

      def __init__(
          self,
          config: IngestionConfig,
          database_provider: "PostgresDatabaseProvider",
          llm_provider: CompletionProvider,
      ):
          super().__init__(config)
          # Stored with the IngestionConfig annotation so type checkers see
          # the ingestion-specific config type.
          self.config: IngestionConfig = config
          self.database_provider: "PostgresDatabaseProvider" = database_provider
          self.llm_provider = llm_provider