# document.py

  1. """Abstractions for documents and their extractions."""
  2. import json
  3. import logging
  4. from datetime import datetime
  5. from enum import Enum
  6. from typing import Any, Optional
  7. from uuid import UUID, uuid4
  8. from pydantic import Field
  9. from .base import R2RSerializable
  10. logger = logging.getLogger()
class DocumentType(str, Enum):
    """Types of documents that can be stored."""

    # Audio
    MP3 = "mp3"
    # CSV
    CSV = "csv"
    # Email
    EML = "eml"
    MSG = "msg"
    P7S = "p7s"
    # EPUB
    EPUB = "epub"
    # Excel
    XLS = "xls"
    XLSX = "xlsx"
    # HTML
    HTML = "html"
    HTM = "htm"
    # Image
    BMP = "bmp"
    HEIC = "heic"
    JPEG = "jpeg"
    PNG = "png"
    TIFF = "tiff"
    JPG = "jpg"
    SVG = "svg"
    # Markdown
    MD = "md"
    # Org Mode
    ORG = "org"
    # Open Office
    ODT = "odt"
    # PDF
    PDF = "pdf"
    # Plain text
    TXT = "txt"
    JSON = "json"
    # PowerPoint
    PPT = "ppt"
    PPTX = "pptx"
    # reStructured Text
    RST = "rst"
    # Rich Text
    RTF = "rtf"
    # TSV
    TSV = "tsv"
    # Video/GIF
    GIF = "gif"
    # Word
    DOC = "doc"
    DOCX = "docx"
    # XML
    XML = "xml"


class Document(R2RSerializable):
    id: UUID = Field(default_factory=uuid4)
    collection_ids: list[UUID]
    owner_id: UUID
    document_type: DocumentType
    metadata: dict

    class Config:
        arbitrary_types_allowed = True
        ignore_extra = False
        json_encoders = {
            UUID: str,
        }
        populate_by_name = True


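# Minimal usage sketch with hypothetical values, assuming R2RSerializable
# behaves like a standard Pydantic BaseModel:
#
#     doc = Document(
#         collection_ids=[uuid4()],
#         owner_id=uuid4(),
#         document_type=DocumentType.PDF,
#         metadata={"source": "upload"},
#     )
#     # `id` is generated automatically via its default_factory.

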
class IngestionStatus(str, Enum):
    """Status of document processing."""

    PENDING = "pending"
    PARSING = "parsing"
    EXTRACTING = "extracting"
    CHUNKING = "chunking"
    EMBEDDING = "embedding"
    AUGMENTING = "augmenting"
    STORING = "storing"
    ENRICHING = "enriching"
    ENRICHED = "enriched"
    FAILED = "failed"
    SUCCESS = "success"

    def __str__(self):
        return self.value

    @classmethod
    def table_name(cls) -> str:
        return "documents"

    @classmethod
    def id_column(cls) -> str:
        return "document_id"


class KGExtractionStatus(str, Enum):
    """Status of KG creation per document."""

    PENDING = "pending"
    PROCESSING = "processing"
    SUCCESS = "success"
    ENRICHED = "enriched"
    FAILED = "failed"

    def __str__(self):
        return self.value

    @classmethod
    def table_name(cls) -> str:
        return "documents"

    @classmethod
    def id_column(cls) -> str:
        return "id"


class KGEnrichmentStatus(str, Enum):
    """Status of KG enrichment per collection."""

    PENDING = "pending"
    PROCESSING = "processing"
    OUTDATED = "outdated"
    SUCCESS = "success"
    FAILED = "failed"

    def __str__(self):
        return self.value

    @classmethod
    def table_name(cls) -> str:
        return "collections"

    @classmethod
    def id_column(cls) -> str:
        return "id"


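# The status enums double as lightweight DB metadata: table_name() and
# id_column() tell callers which table and key column a status update targets.
# A hypothetical sketch (the real query helpers live elsewhere in R2R):
#
#     query = (
#         f"UPDATE {IngestionStatus.table_name()} "
#         f"SET ingestion_status = $1 WHERE {IngestionStatus.id_column()} = $2"
#     )
#     # executed with, e.g., (IngestionStatus.SUCCESS.value, document_id)

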
class DocumentResponse(R2RSerializable):
    """Base class for document information handling."""

    id: UUID
    collection_ids: list[UUID]
    owner_id: UUID
    document_type: DocumentType
    metadata: dict
    title: Optional[str] = None
    version: str
    size_in_bytes: Optional[int]
    ingestion_status: IngestionStatus = IngestionStatus.PENDING
    extraction_status: KGExtractionStatus = KGExtractionStatus.PENDING
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
    ingestion_attempt_number: Optional[int] = None
    summary: Optional[str] = None
    summary_embedding: Optional[list[float]] = None  # optional summary embedding

    def convert_to_db_entry(self):
        """Prepare the document info for a database entry, serializing the
        metadata and formatting the summary embedding for the Postgres vector
        type."""
        now = datetime.now()

        # Format the embedding properly for the Postgres vector type.
        embedding = None
        if self.summary_embedding is not None:
            embedding = f"[{','.join(str(x) for x in self.summary_embedding)}]"

        return {
            "id": self.id,
            "collection_ids": self.collection_ids,
            "owner_id": self.owner_id,
            "document_type": self.document_type,
            "metadata": json.dumps(self.metadata),
            "title": self.title or "N/A",
            "version": self.version,
            "size_in_bytes": self.size_in_bytes,
            "ingestion_status": self.ingestion_status.value,
            "extraction_status": self.extraction_status.value,
            "created_at": self.created_at or now,
            "updated_at": self.updated_at or now,
            "ingestion_attempt_number": self.ingestion_attempt_number or 0,
            "summary": self.summary,
            "summary_embedding": embedding,
        }


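# Usage sketch with hypothetical values, assuming Pydantic-style construction;
# it illustrates the pgvector-style literal produced for the embedding:
#
#     doc = DocumentResponse(
#         id=uuid4(),
#         collection_ids=[],
#         owner_id=uuid4(),
#         document_type=DocumentType.TXT,
#         metadata={"lang": "en"},
#         version="v0",
#         size_in_bytes=1024,
#         summary_embedding=[0.1, 0.2, 0.3],
#     )
#     entry = doc.convert_to_db_entry()
#     # entry["metadata"] == '{"lang": "en"}'
#     # entry["summary_embedding"] == "[0.1,0.2,0.3]"

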
class UnprocessedChunk(R2RSerializable):
    """An extraction from a document."""

    id: Optional[UUID] = None
    document_id: Optional[UUID] = None
    collection_ids: list[UUID] = []
    metadata: dict = {}
    text: str


class UpdateChunk(R2RSerializable):
    """An update to an existing chunk."""

    id: UUID
    metadata: Optional[dict] = None
    text: str


class DocumentChunk(R2RSerializable):
    """An extraction from a document."""

    id: UUID
    document_id: UUID
    collection_ids: list[UUID]
    owner_id: UUID
    data: str | bytes
    metadata: dict


class RawChunk(R2RSerializable):
    text: str


class IngestionMode(str, Enum):
    hi_res = "hi-res"
    fast = "fast"
    custom = "custom"


class ChunkEnrichmentStrategy(str, Enum):
    SEMANTIC = "semantic"
    NEIGHBORHOOD = "neighborhood"

    def __str__(self) -> str:
        return self.value


# Imported here, mid-module, as in the original source (commonly done to avoid
# a circular import).
from .llm import GenerationConfig


class ChunkEnrichmentSettings(R2RSerializable):
    """Settings for chunk enrichment."""

    enable_chunk_enrichment: bool = Field(
        default=False,
        description="Whether to enable chunk enrichment or not",
    )
    strategies: list[ChunkEnrichmentStrategy] = Field(
        default=[],
        description="The strategies to use for chunk enrichment. The union of chunks obtained from each strategy is used as context.",
    )
    forward_chunks: int = Field(
        default=3,
        description="The number of chunks after the current chunk to include in the LLM context while enriching",
    )
    backward_chunks: int = Field(
        default=3,
        description="The number of chunks before the current chunk to include in the LLM context while enriching",
    )
    semantic_neighbors: int = Field(
        default=10, description="The number of semantic neighbors to include"
    )
    semantic_similarity_threshold: float = Field(
        default=0.7,
        description="The similarity threshold for semantic neighbors",
    )
    generation_config: GenerationConfig = Field(
        default=GenerationConfig(),
        description="The generation config to use for chunk enrichment",
    )


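# Configuration sketch with hypothetical values: enable enrichment with both
# strategies and a slightly wider neighborhood than the defaults.
#
#     settings = ChunkEnrichmentSettings(
#         enable_chunk_enrichment=True,
#         strategies=[
#             ChunkEnrichmentStrategy.SEMANTIC,
#             ChunkEnrichmentStrategy.NEIGHBORHOOD,
#         ],
#         forward_chunks=5,
#         backward_chunks=5,
#         semantic_neighbors=15,
#     )

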
## TODO - Move ingestion config
class IngestionConfig(R2RSerializable):
    provider: str = "r2r"
    excluded_parsers: list[str] = ["mp4"]
    chunking_strategy: str = "recursive"
    chunk_enrichment_settings: ChunkEnrichmentSettings = (
        ChunkEnrichmentSettings()
    )
    extra_parsers: dict[str, Any] = {}
    audio_transcription_model: str = "openai/whisper-1"
    vision_img_prompt_name: str = "vision_img"
    vision_img_model: str = "openai/gpt-4o"
    vision_pdf_prompt_name: str = "vision_pdf"
    vision_pdf_model: str = "openai/gpt-4o"
    skip_document_summary: bool = False
    document_summary_system_prompt: str = "default_system"
    document_summary_task_prompt: str = "default_summary"
    chunks_for_document_summary: int = 128
    document_summary_model: str = "openai/gpt-4o-mini"

    @property
    def supported_providers(self) -> list[str]:
        return ["r2r", "unstructured_local", "unstructured_api"]

    def validate_config(self) -> None:
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider {self.provider} is not supported.")

    @classmethod
    def get_default(cls, mode: str) -> "IngestionConfig":
        """Return the default ingestion configuration for a given mode."""
        if mode == "hi-res":
            # More thorough parsing, no skipping summaries, larger
            # `chunks_for_document_summary`.
            return cls(
                provider="r2r",
                excluded_parsers=["mp4"],
                chunk_enrichment_settings=ChunkEnrichmentSettings(),  # default
                extra_parsers={},
                audio_transcription_model="openai/whisper-1",
                vision_img_prompt_name="vision_img",
                vision_img_model="openai/gpt-4o",
                vision_pdf_prompt_name="vision_pdf",
                vision_pdf_model="openai/gpt-4o",
                skip_document_summary=False,
                document_summary_system_prompt="default_system",
                document_summary_task_prompt="default_summary",
                chunks_for_document_summary=256,  # larger for hi-res
                document_summary_model="openai/gpt-4o-mini",
            )
        elif mode == "fast":
            # Skip summaries and other enrichment steps for speed.
            return cls(
                provider="r2r",
                excluded_parsers=["mp4"],
                chunk_enrichment_settings=ChunkEnrichmentSettings(),  # default
                extra_parsers={},
                audio_transcription_model="openai/whisper-1",
                vision_img_prompt_name="vision_img",
                vision_img_model="openai/gpt-4o",
                vision_pdf_prompt_name="vision_pdf",
                vision_pdf_model="openai/gpt-4o",
                skip_document_summary=True,  # skip summaries
                document_summary_system_prompt="default_system",
                document_summary_task_prompt="default_summary",
                chunks_for_document_summary=64,
                document_summary_model="openai/gpt-4o-mini",
            )
        else:
            # For `custom` or any unrecognized mode, return a base config.
            return cls()
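

# Usage sketch: pick a mode-specific default, then validate the provider.
# Values shown follow the defaults defined above; nothing here is new API.
#
#     config = IngestionConfig.get_default("hi-res")
#     config.validate_config()  # raises ValueError for unsupported providers
#     assert config.chunks_for_document_summary == 256
#
#     fast_config = IngestionConfig.get_default("fast")
#     assert fast_config.skip_document_summary is True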