ingestion.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import logging
  2. from abc import ABC
  3. from enum import Enum
  4. from typing import TYPE_CHECKING, Any, ClassVar, Optional
  5. from pydantic import Field
  6. from core.base.abstractions import ChunkEnrichmentSettings
  7. from .base import AppConfig, Provider, ProviderConfig
  8. from .llm import CompletionProvider
  # Module-scoped logger: records are tagged with this module's name and
  # still propagate to the root logger's handlers.
  logger = logging.getLogger(__name__)
  10. if TYPE_CHECKING:
  11. from core.providers.database import PostgresDatabaseProvider
  class ChunkingStrategy(str, Enum):
      """Named chunking strategies usable for ``IngestionConfig.chunking_strategy``.

      Because the enum mixes in ``str``, each member compares equal to its
      plain string value, e.g. ``ChunkingStrategy.BY_TITLE == "by_title"``.
      """

      # Declaration order is preserved; iteration yields members in this order.
      RECURSIVE = "recursive"
      CHARACTER = "character"
      BASIC = "basic"
      BY_TITLE = "by_title"
  def _ingestion_default(key: str):
      """Build a zero-argument ``default_factory`` for the IngestionConfig field ``key``.

      The factory reads ``IngestionConfig._defaults[key]`` each time a config
      is created, so overrides made via ``IngestionConfig.set_default`` apply to
      configs created afterwards. It returns a deep copy so that mutable
      defaults (lists, dicts, ``ChunkEnrichmentSettings``) are never shared
      between instances or aliased with ``_defaults`` itself. Pydantic uses a
      ``default_factory`` result as-is, without copying it.
      """
      import copy

      def factory():
          return copy.deepcopy(IngestionConfig._defaults[key])

      return factory


  class IngestionConfig(ProviderConfig):
      """Configuration for ingestion providers.

      Every field's default is resolved from ``_defaults`` at instantiation
      time. Use ``set_default`` to change class-wide defaults and
      ``get_default`` to build a config for a named ingestion mode.
      """

      # Class-wide defaults; mutated only through ``set_default``. Each
      # instance receives a deep copy of the value for every declared field.
      _defaults: ClassVar[dict[str, Any]] = {
          "app": AppConfig(),
          "provider": "r2r",
          "excluded_parsers": [],
          "chunking_strategy": "recursive",
          "chunk_size": 1024,
          "chunk_overlap": 512,
          "chunk_enrichment_settings": ChunkEnrichmentSettings(),
          "extra_parsers": {},
          "audio_transcription_model": None,
          "vlm": None,
          "vlm_batch_size": 5,
          "vlm_max_tokens_to_sample": 1_024,
          "max_concurrent_vlm_tasks": 5,
          "vlm_ocr_one_page_per_chunk": True,
          "skip_document_summary": False,
          "document_summary_system_prompt": "system",
          "document_summary_task_prompt": "summary",
          "document_summary_max_length": 100_000,
          "chunks_for_document_summary": 128,
          "document_summary_model": None,
          "parser_overrides": {},
          "extra_fields": {},
          "automatic_extraction": False,
      }

      # Field declaration order is kept as before (it determines the order
      # of fields in the model and its serialized output).
      provider: str = Field(default_factory=_ingestion_default("provider"))
      excluded_parsers: list[str] = Field(
          default_factory=_ingestion_default("excluded_parsers")
      )
      chunking_strategy: str | ChunkingStrategy = Field(
          default_factory=_ingestion_default("chunking_strategy")
      )
      chunk_size: int = Field(default_factory=_ingestion_default("chunk_size"))
      chunk_overlap: int = Field(
          default_factory=_ingestion_default("chunk_overlap")
      )
      chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
          default_factory=_ingestion_default("chunk_enrichment_settings")
      )
      extra_parsers: dict[str, Any] = Field(
          default_factory=_ingestion_default("extra_parsers")
      )
      audio_transcription_model: Optional[str] = Field(
          default_factory=_ingestion_default("audio_transcription_model")
      )
      vlm: Optional[str] = Field(default_factory=_ingestion_default("vlm"))
      vlm_batch_size: int = Field(
          default_factory=_ingestion_default("vlm_batch_size")
      )
      vlm_max_tokens_to_sample: int = Field(
          default_factory=_ingestion_default("vlm_max_tokens_to_sample")
      )
      max_concurrent_vlm_tasks: int = Field(
          default_factory=_ingestion_default("max_concurrent_vlm_tasks")
      )
      vlm_ocr_one_page_per_chunk: bool = Field(
          default_factory=_ingestion_default("vlm_ocr_one_page_per_chunk")
      )
      skip_document_summary: bool = Field(
          default_factory=_ingestion_default("skip_document_summary")
      )
      document_summary_system_prompt: str = Field(
          default_factory=_ingestion_default("document_summary_system_prompt")
      )
      document_summary_task_prompt: str = Field(
          default_factory=_ingestion_default("document_summary_task_prompt")
      )
      chunks_for_document_summary: int = Field(
          default_factory=_ingestion_default("chunks_for_document_summary")
      )
      document_summary_model: Optional[str] = Field(
          default_factory=_ingestion_default("document_summary_model")
      )
      parser_overrides: dict[str, str] = Field(
          default_factory=_ingestion_default("parser_overrides")
      )
      automatic_extraction: bool = Field(
          default_factory=_ingestion_default("automatic_extraction")
      )
      document_summary_max_length: int = Field(
          default_factory=_ingestion_default("document_summary_max_length")
      )

      @classmethod
      def set_default(cls, **kwargs):
          """Override class-wide defaults for configs created afterwards.

          All keys are validated before anything is written, so a call with
          an unknown key leaves ``_defaults`` unchanged.

          Raises:
              AttributeError: if any key is not present in ``_defaults``; the
                  message names the first unknown key in argument order.
          """
          unknown = [key for key in kwargs if key not in cls._defaults]
          if unknown:
              raise AttributeError(
                  f"No default attribute '{unknown[0]}' in IngestionConfig"
              )
          cls._defaults.update(kwargs)

      @property
      def supported_providers(self) -> list[str]:
          """Provider names accepted by ``validate_config``."""
          return ["r2r", "unstructured_local", "unstructured_api"]

      def validate_config(self) -> None:
          """Raise ``ValueError`` if ``provider`` is not a supported provider."""
          if self.provider not in self.supported_providers:
              raise ValueError(
                  f"Provider {self.provider} is not supported, must be one of {self.supported_providers}"
              )

      @classmethod
      def get_default(cls, mode: str, app) -> "IngestionConfig":
          """Return default ingestion configuration for a given mode.

          ``"hi-res"`` routes PDFs to the ``"zerox"`` parser, ``"ocr"`` routes
          PDFs to the ``"ocr"`` parser, and ``"fast"`` skips document summaries.
          Any other mode yields the plain defaults.
          """
          if mode == "hi-res":
              return cls(app=app, parser_overrides={"pdf": "zerox"})
          elif mode == "ocr":
              return cls(app=app, parser_overrides={"pdf": "ocr"})
          elif mode == "fast":
              return cls(app=app, skip_document_summary=True)
          else:
              return cls(app=app)
  class IngestionProvider(Provider, ABC):
      """Abstract base class for ingestion providers.

      Stores the ingestion configuration plus the database and completion
      (LLM) providers so that concrete subclasses can use them.
      """

      config: IngestionConfig
      database_provider: "PostgresDatabaseProvider"
      llm_provider: CompletionProvider

      def __init__(
          self,
          config: IngestionConfig,
          database_provider: "PostgresDatabaseProvider",
          llm_provider: CompletionProvider,
      ):
          # The base Provider receives the config before any attributes are
          # bound here.
          super().__init__(config)
          self.config = config
          # Unpacking assigns left to right: llm_provider first, then
          # database_provider.
          self.llm_provider, self.database_provider = (
              llm_provider,
              database_provider,
          )