document.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. """Abstractions for documents and their extractions."""
  2. import json
  3. import logging
  4. from datetime import datetime
  5. from enum import Enum
  6. from typing import Optional
  7. from uuid import UUID, uuid4
  8. from pydantic import Field
  9. from .base import R2RSerializable
  10. logger = logging.getLogger()
  11. class DocumentType(str, Enum):
  12. """Types of documents that can be stored."""
  13. # Audio
  14. MP3 = "mp3"
  15. # CSV
  16. CSV = "csv"
  17. # Email
  18. EML = "eml"
  19. MSG = "msg"
  20. P7S = "p7s"
  21. # EPUB
  22. EPUB = "epub"
  23. # Excel
  24. XLS = "xls"
  25. XLSX = "xlsx"
  26. # HTML
  27. HTML = "html"
  28. HTM = "htm"
  29. # Image
  30. BMP = "bmp"
  31. HEIC = "heic"
  32. JPEG = "jpeg"
  33. PNG = "png"
  34. TIFF = "tiff"
  35. JPG = "jpg"
  36. SVG = "svg"
  37. WEBP = "webp"
  38. ICO = "ico"
  39. # Markdown
  40. MD = "md"
  41. # Org Mode
  42. ORG = "org"
  43. # Open Office
  44. ODT = "odt"
  45. # PDF
  46. PDF = "pdf"
  47. # Plain text
  48. TXT = "txt"
  49. JSON = "json"
  50. # PowerPoint
  51. PPT = "ppt"
  52. PPTX = "pptx"
  53. # reStructured Text
  54. RST = "rst"
  55. # Rich Text
  56. RTF = "rtf"
  57. # TSV
  58. TSV = "tsv"
  59. # Video/GIF
  60. MP4 = "mp4"
  61. GIF = "gif"
  62. # Word
  63. DOC = "doc"
  64. DOCX = "docx"
  65. # XML
  66. XML = "xml"
class Document(R2RSerializable):
    """A document record: identity, owner, collection membership, type, and metadata."""

    # A fresh random UUID is generated per instance unless one is supplied.
    id: UUID = Field(default_factory=uuid4)
    collection_ids: list[UUID]
    owner_id: UUID
    document_type: DocumentType
    metadata: dict

    class Config:
        arbitrary_types_allowed = True
        # NOTE(review): `ignore_extra` does not appear to be a recognized
        # pydantic config key (pydantic >=1.0 controls extra fields via
        # `extra`), so this line likely has no effect. Confirm whether
        # `extra = "forbid"` was intended.
        ignore_extra = False
        # Encode UUID values with `str` when producing JSON.
        json_encoders = {
            UUID: str,
        }
        populate_by_name = True
  80. class IngestionStatus(str, Enum):
  81. """Status of document processing."""
  82. PENDING = "pending"
  83. PARSING = "parsing"
  84. EXTRACTING = "extracting"
  85. CHUNKING = "chunking"
  86. EMBEDDING = "embedding"
  87. AUGMENTING = "augmenting"
  88. STORING = "storing"
  89. ENRICHING = "enriching"
  90. ENRICHED = "enriched"
  91. FAILED = "failed"
  92. SUCCESS = "success"
  93. def __str__(self):
  94. return self.value
  95. @classmethod
  96. def table_name(cls) -> str:
  97. return "documents"
  98. @classmethod
  99. def id_column(cls) -> str:
  100. return "document_id"
  101. class KGExtractionStatus(str, Enum):
  102. """Status of KG Creation per document."""
  103. PENDING = "pending"
  104. PROCESSING = "processing"
  105. SUCCESS = "success"
  106. ENRICHED = "enriched"
  107. FAILED = "failed"
  108. def __str__(self):
  109. return self.value
  110. @classmethod
  111. def table_name(cls) -> str:
  112. return "documents"
  113. @classmethod
  114. def id_column(cls) -> str:
  115. return "id"
  116. class KGEnrichmentStatus(str, Enum):
  117. """Status of KG Enrichment per collection."""
  118. PENDING = "pending"
  119. PROCESSING = "processing"
  120. OUTDATED = "outdated"
  121. SUCCESS = "success"
  122. FAILED = "failed"
  123. def __str__(self):
  124. return self.value
  125. @classmethod
  126. def table_name(cls) -> str:
  127. return "collections"
  128. @classmethod
  129. def id_column(cls) -> str:
  130. return "id"
  131. class DocumentResponse(R2RSerializable):
  132. """Base class for document information handling."""
  133. id: UUID
  134. collection_ids: list[UUID]
  135. owner_id: UUID
  136. document_type: DocumentType
  137. metadata: dict
  138. title: Optional[str] = None
  139. version: str
  140. size_in_bytes: Optional[int]
  141. ingestion_status: IngestionStatus = IngestionStatus.PENDING
  142. extraction_status: KGExtractionStatus = KGExtractionStatus.PENDING
  143. created_at: Optional[datetime] = None
  144. updated_at: Optional[datetime] = None
  145. ingestion_attempt_number: Optional[int] = None
  146. summary: Optional[str] = None
  147. summary_embedding: Optional[list[float]] = None # Add optional embedding
  148. def convert_to_db_entry(self):
  149. """Prepare the document info for database entry, extracting certain fields from metadata."""
  150. now = datetime.now()
  151. # Format the embedding properly for Postgres vector type
  152. embedding = None
  153. if self.summary_embedding is not None:
  154. embedding = f"[{','.join(str(x) for x in self.summary_embedding)}]"
  155. return {
  156. "id": self.id,
  157. "collection_ids": self.collection_ids,
  158. "owner_id": self.owner_id,
  159. "document_type": self.document_type,
  160. "metadata": json.dumps(self.metadata),
  161. "title": self.title or "N/A",
  162. "version": self.version,
  163. "size_in_bytes": self.size_in_bytes,
  164. "ingestion_status": self.ingestion_status.value,
  165. "extraction_status": self.extraction_status.value,
  166. "created_at": self.created_at or now,
  167. "updated_at": self.updated_at or now,
  168. "ingestion_attempt_number": self.ingestion_attempt_number or 0,
  169. "summary": self.summary,
  170. "summary_embedding": embedding,
  171. }
class UnprocessedChunk(R2RSerializable):
    """Chunk text submitted by a caller, whose identifiers may not be assigned yet.

    ``id`` and ``document_id`` are optional; ``collection_ids`` and
    ``metadata`` default to empty; ``text`` is required.
    """

    id: Optional[UUID] = None
    document_id: Optional[UUID] = None
    # NOTE(review): mutable defaults; harmless if R2RSerializable is a pydantic
    # model (pydantic copies field defaults per instance). Confirm.
    collection_ids: list[UUID] = []
    metadata: dict = {}
    text: str
class UpdateChunk(R2RSerializable):
    """An update for the chunk identified by ``id``.

    ``text`` is required; ``metadata`` may be omitted (None).
    """

    id: UUID
    metadata: Optional[dict] = None
    text: str
class DocumentChunk(R2RSerializable):
    """A chunk belonging to a document, with its owner and collection membership.

    ``data`` may be either text or raw bytes.
    """

    id: UUID
    document_id: UUID
    collection_ids: list[UUID]
    owner_id: UUID
    data: str | bytes
    metadata: dict
class RawChunk(R2RSerializable):
    """Bare chunk text with no identifiers or metadata."""

    text: str
  194. class IngestionMode(str, Enum):
  195. hi_res = "hi-res"
  196. fast = "fast"
  197. custom = "custom"
  198. class ChunkEnrichmentStrategy(str, Enum):
  199. SEMANTIC = "semantic"
  200. NEIGHBORHOOD = "neighborhood"
  201. def __str__(self) -> str:
  202. return self.value
  203. from .llm import GenerationConfig
  204. class ChunkEnrichmentSettings(R2RSerializable):
  205. """
  206. Settings for chunk enrichment.
  207. """
  208. enable_chunk_enrichment: bool = Field(
  209. default=False,
  210. description="Whether to enable chunk enrichment or not",
  211. )
  212. strategies: list[ChunkEnrichmentStrategy] = Field(
  213. default=[],
  214. description="The strategies to use for chunk enrichment. Union of chunks obtained from each strategy is used as context.",
  215. )
  216. forward_chunks: int = Field(
  217. default=3,
  218. description="The number after the current chunk to include in the LLM context while enriching",
  219. )
  220. backward_chunks: int = Field(
  221. default=3,
  222. description="The number of chunks before the current chunk in the LLM context while enriching",
  223. )
  224. semantic_neighbors: int = Field(
  225. default=10, description="The number of semantic neighbors to include"
  226. )
  227. semantic_similarity_threshold: float = Field(
  228. default=0.7,
  229. description="The similarity threshold for semantic neighbors",
  230. )
  231. generation_config: GenerationConfig = Field(
  232. default=GenerationConfig(),
  233. description="The generation config to use for chunk enrichment",
  234. )
  235. ## TODO - Move ingestion config
  236. class IngestionConfig(R2RSerializable):
  237. provider: str = "r2r"
  238. excluded_parsers: list[str] = ["mp4"]
  239. chunk_enrichment_settings: ChunkEnrichmentSettings = (
  240. ChunkEnrichmentSettings()
  241. )
  242. extra_parsers: dict[str, str] = {}
  243. audio_transcription_model: str = "openai/whisper-1"
  244. vision_img_prompt_name: str = "vision_img"
  245. vision_img_model: str = "openai/gpt-4o"
  246. vision_pdf_prompt_name: str = "vision_pdf"
  247. vision_pdf_model: str = "openai/gpt-4o"
  248. skip_document_summary: bool = False
  249. document_summary_system_prompt: str = "default_system"
  250. document_summary_task_prompt: str = "default_summary"
  251. chunks_for_document_summary: int = 128
  252. document_summary_model: str = "openai/gpt-4o-mini"
  253. @property
  254. def supported_providers(self) -> list[str]:
  255. return ["r2r", "unstructured_local", "unstructured_api"]
  256. def validate_config(self) -> None:
  257. if self.provider not in self.supported_providers:
  258. raise ValueError(f"Provider {self.provider} is not supported.")
  259. @classmethod
  260. def get_default(cls, mode: str) -> "IngestionConfig":
  261. """Return default ingestion configuration for a given mode."""
  262. if mode == "hi-res":
  263. # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`.
  264. return cls(
  265. provider="r2r",
  266. excluded_parsers=["mp4"],
  267. chunk_enrichment_settings=ChunkEnrichmentSettings(), # default
  268. extra_parsers={},
  269. audio_transcription_model="openai/whisper-1",
  270. vision_img_prompt_name="vision_img",
  271. vision_img_model="openai/gpt-4o",
  272. vision_pdf_prompt_name="vision_pdf",
  273. vision_pdf_model="openai/gpt-4o",
  274. skip_document_summary=False,
  275. document_summary_system_prompt="default_system",
  276. document_summary_task_prompt="default_summary",
  277. chunks_for_document_summary=256, # larger for hi-res
  278. document_summary_model="openai/gpt-4o-mini",
  279. )
  280. elif mode == "fast":
  281. # Skip summaries and other enrichment steps for speed.
  282. return cls(
  283. provider="r2r",
  284. excluded_parsers=["mp4"],
  285. chunk_enrichment_settings=ChunkEnrichmentSettings(), # default
  286. extra_parsers={},
  287. audio_transcription_model="openai/whisper-1",
  288. vision_img_prompt_name="vision_img",
  289. vision_img_model="openai/gpt-4o",
  290. vision_pdf_prompt_name="vision_pdf",
  291. vision_pdf_model="openai/gpt-4o",
  292. skip_document_summary=True, # skip summaries
  293. document_summary_system_prompt="default_system",
  294. document_summary_task_prompt="default_summary",
  295. chunks_for_document_summary=64,
  296. document_summary_model="openai/gpt-4o-mini",
  297. )
  298. else:
  299. # For `custom` or any unrecognized mode, return a base config
  300. return cls()