document.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. """Abstractions for documents and their extractions."""
  2. import json
  3. import logging
  4. from datetime import datetime
  5. from enum import Enum
  6. from typing import Any, Optional
  7. from uuid import UUID, uuid4
  8. from pydantic import Field
  9. from .base import R2RSerializable
  10. from .llm import GenerationConfig
  # Module-level logger. Named after this module (instead of grabbing the root
  # logger) so records carry their origin and the level/handlers can be tuned
  # per module; records still propagate to the root logger's handlers.
  logger = logging.getLogger(__name__)
  class DocumentType(str, Enum):
      """File formats that a stored document may have.

      Every member's value is the lowercase extension string for that format.
      """

      # --- audio ---
      MP3 = "mp3"

      # --- comma-separated values ---
      CSV = "csv"

      # --- email ---
      EML = "eml"
      MSG = "msg"
      P7S = "p7s"

      # --- e-books ---
      EPUB = "epub"

      # --- Excel spreadsheets ---
      XLS = "xls"
      XLSX = "xlsx"

      # --- HTML ---
      HTML = "html"
      HTM = "htm"

      # --- images ---
      BMP = "bmp"
      HEIC = "heic"
      JPEG = "jpeg"
      PNG = "png"
      TIFF = "tiff"
      JPG = "jpg"
      SVG = "svg"

      # --- Markdown ---
      MD = "md"

      # --- Org mode ---
      ORG = "org"

      # --- OpenOffice ---
      ODT = "odt"

      # --- PDF ---
      PDF = "pdf"

      # --- plain text ---
      TXT = "txt"
      JSON = "json"

      # --- PowerPoint ---
      PPT = "ppt"
      PPTX = "pptx"

      # --- reStructuredText ---
      RST = "rst"

      # --- rich text ---
      RTF = "rtf"

      # --- tab-separated values ---
      TSV = "tsv"

      # --- video / GIF ---
      GIF = "gif"

      # --- Word ---
      DOC = "doc"
      DOCX = "docx"

      # --- source code ---
      PY = "py"
      JS = "js"
      TS = "ts"
      CSS = "css"
  class Document(R2RSerializable):
      """A document's identity, ownership, type and free-form metadata."""

      # A random UUID is generated when the caller does not supply one.
      id: UUID = Field(default_factory=uuid4)
      collection_ids: list[UUID]
      owner_id: UUID
      document_type: DocumentType
      metadata: dict

      class Config:
          """Model configuration for `Document`."""

          populate_by_name = True
          ignore_extra = False
          arbitrary_types_allowed = True
          # UUIDs are emitted as plain strings when encoding to JSON.
          json_encoders = {UUID: str}
  class IngestionStatus(str, Enum):
      """Processing state of a document during ingestion."""

      PENDING = "pending"
      PARSING = "parsing"
      EXTRACTING = "extracting"
      CHUNKING = "chunking"
      EMBEDDING = "embedding"
      AUGMENTING = "augmenting"
      STORING = "storing"
      ENRICHING = "enriching"
      FAILED = "failed"
      SUCCESS = "success"

      def __str__(self) -> str:
          """Render the member as its bare string value (e.g. ``"pending"``)."""
          return str(self.value)

      @classmethod
      def table_name(cls) -> str:
          """Name of the table associated with this status."""
          return "documents"

      @classmethod
      def id_column(cls) -> str:
          """Name of the identifier column associated with this status.

          NOTE(review): GraphExtractionStatus pairs the same "documents" table
          with "id" -- confirm "document_id" is the intended column here.
          """
          return "document_id"
  class GraphExtractionStatus(str, Enum):
      """Per-document state of graph creation."""

      PENDING = "pending"
      PROCESSING = "processing"
      SUCCESS = "success"
      ENRICHED = "enriched"
      FAILED = "failed"

      def __str__(self) -> str:
          """Render the member as its bare string value."""
          return str(self.value)

      @classmethod
      def table_name(cls) -> str:
          """Name of the table associated with this status."""
          return "documents"

      @classmethod
      def id_column(cls) -> str:
          """Name of the identifier column associated with this status."""
          return "id"
  class GraphConstructionStatus(str, Enum):
      """Per-collection state of graph enrichment."""

      PENDING = "pending"
      PROCESSING = "processing"
      OUTDATED = "outdated"
      SUCCESS = "success"
      FAILED = "failed"

      def __str__(self) -> str:
          """Render the member as its bare string value."""
          return str(self.value)

      @classmethod
      def table_name(cls) -> str:
          """Name of the table associated with this status."""
          return "collections"

      @classmethod
      def id_column(cls) -> str:
          """Name of the identifier column associated with this status."""
          return "id"
  class DocumentResponse(R2RSerializable):
      """Document information together with its processing state."""

      id: UUID
      collection_ids: list[UUID]
      owner_id: UUID
      document_type: DocumentType
      metadata: dict
      title: Optional[str] = None
      version: str
      size_in_bytes: Optional[int]
      ingestion_status: IngestionStatus = IngestionStatus.PENDING
      extraction_status: GraphExtractionStatus = GraphExtractionStatus.PENDING
      created_at: Optional[datetime] = None
      updated_at: Optional[datetime] = None
      ingestion_attempt_number: Optional[int] = None
      summary: Optional[str] = None
      summary_embedding: Optional[list[float]] = None
      total_tokens: Optional[int] = None
      chunks: Optional[list] = None

      def convert_to_db_entry(self):
          """Build the dict of column values used to store this document.

          Metadata is serialized to a JSON string, the two status enums are
          reduced to their string values, and unset/falsy fields get
          fallbacks: title -> "N/A", timestamps -> the current time, attempt
          number and token count -> 0. `chunks` is not part of the result.
          """
          timestamp = datetime.now()

          # The summary embedding is rendered as a bracketed, comma-separated
          # literal (the form used for the Postgres vector type); it stays
          # None when there is no embedding.
          if self.summary_embedding is None:
              vector_literal = None
          else:
              vector_literal = (
                  "[" + ",".join(map(str, self.summary_embedding)) + "]"
              )

          return {
              "id": self.id,
              "collection_ids": self.collection_ids,
              "owner_id": self.owner_id,
              "document_type": self.document_type,
              "metadata": json.dumps(self.metadata),
              "title": self.title or "N/A",
              "version": self.version,
              "size_in_bytes": self.size_in_bytes,
              "ingestion_status": self.ingestion_status.value,
              "extraction_status": self.extraction_status.value,
              "created_at": self.created_at or timestamp,
              "updated_at": self.updated_at or timestamp,
              "ingestion_attempt_number": self.ingestion_attempt_number or 0,
              "summary": self.summary,
              "summary_embedding": vector_literal,
              # None is stored as 0.
              "total_tokens": self.total_tokens or 0,
          }

      class Config:
          """Model configuration: supplies the example shown in the JSON schema."""

          json_schema_extra = {
              "example": {
                  "id": "123e4567-e89b-12d3-a456-426614174000",
                  "collection_ids": ["123e4567-e89b-12d3-a456-426614174000"],
                  "owner_id": "123e4567-e89b-12d3-a456-426614174000",
                  "document_type": "pdf",
                  "metadata": {"title": "Sample Document"},
                  "title": "Sample Document",
                  "version": "1.0",
                  "size_in_bytes": 123456,
                  "ingestion_status": "pending",
                  "extraction_status": "pending",
                  "created_at": "2021-01-01T00:00:00",
                  "updated_at": "2021-01-01T00:00:00",
                  "ingestion_attempt_number": 0,
                  "summary": "A summary of the document",
                  "summary_embedding": [0.1, 0.2, 0.3],
                  "total_tokens": 1000,
              }
          }
  class UnprocessedChunk(R2RSerializable):
      """A piece of document text whose ids may not have been assigned yet."""

      text: str
      id: Optional[UUID] = None
      document_id: Optional[UUID] = None
      # Literal empty defaults; presumably copied per instance by the model
      # base class rather than shared -- confirm against R2RSerializable.
      collection_ids: list[UUID] = []
      metadata: dict = {}
  class UpdateChunk(R2RSerializable):
      """Text, and optionally metadata, for the chunk identified by `id`."""

      id: UUID
      text: str
      metadata: Optional[dict] = None
  class DocumentChunk(R2RSerializable):
      """A chunk of a document, with its owner, collections and payload."""

      id: UUID
      document_id: UUID
      collection_ids: list[UUID]
      owner_id: UUID
      # The payload may be text or raw bytes.
      data: str | bytes
      metadata: dict
  class RawChunk(R2RSerializable):
      """A chunk that carries only its text."""

      text: str
  class IngestionMode(str, Enum):
      """Named ingestion modes.

      Note the hyphen in the value ``"hi-res"`` versus the member name
      ``hi_res``.
      """

      hi_res = "hi-res"
      ocr = "ocr"
      fast = "fast"
      custom = "custom"
  class ChunkEnrichmentSettings(R2RSerializable):
      """Options controlling chunk enrichment (disabled by default)."""

      enable_chunk_enrichment: bool = Field(
          description="Whether to enable chunk enrichment or not",
          default=False,
      )
      n_chunks: int = Field(
          description="The number of preceding and succeeding chunks to include. Defaults to 2.",
          default=2,
      )
      generation_config: Optional[GenerationConfig] = Field(
          description="The generation config to use for chunk enrichment",
          default=None,
      )
      chunk_enrichment_prompt: Optional[str] = Field(
          description="The prompt to use for chunk enrichment",
          default="chunk_enrichment",
      )
  class IngestionConfig(R2RSerializable):
      """Ingestion settings: provider, parsers, chunking, VLM and summaries."""

      provider: str = "r2r"
      excluded_parsers: list[str] = []
      chunking_strategy: str = "recursive"
      chunk_enrichment_settings: ChunkEnrichmentSettings = (
          ChunkEnrichmentSettings()
      )
      extra_parsers: dict[str, Any] = {}
      audio_transcription_model: str = ""

      # Vision-language-model options.
      vlm: Optional[str] = None
      vlm_batch_size: int = 5
      vlm_max_tokens_to_sample: int = 1024
      max_concurrent_vlm_tasks: int = 5
      vlm_ocr_one_page_per_chunk: bool = True

      # Document-summary options.
      skip_document_summary: bool = False
      document_summary_system_prompt: str = "system"
      document_summary_task_prompt: str = "summary"
      chunks_for_document_summary: int = 128
      document_summary_model: str = ""

      @property
      def supported_providers(self) -> list[str]:
          """Provider names accepted by `validate_config`."""
          return ["r2r", "unstructured_local", "unstructured_api"]

      def validate_config(self) -> None:
          """Check that `provider` is one of `supported_providers`.

          Raises:
              ValueError: if the configured provider is not supported.
          """
          if self.provider in self.supported_providers:
              return
          raise ValueError(f"Provider {self.provider} is not supported.")

      @classmethod
      def get_default(cls, mode: str) -> "IngestionConfig":
          """Return the default ingestion configuration for `mode`.

          "hi-res", "ocr" and "fast" differ only in whether document
          summaries are skipped and in how many chunks feed the summary. Any
          other mode (including "custom") yields a plain `cls()`.
          """
          # (mode, skip_document_summary, chunks_for_document_summary)
          presets = (
              ("hi-res", False, 256),  # larger summary window for hi-res
              ("ocr", False, 128),
              ("fast", True, 64),  # summaries skipped for speed
          )
          # Compared with `==` (not hashed) so the match behaves exactly like
          # a chain of equality tests.
          for preset_mode, skip_summary, summary_chunks in presets:
              if mode == preset_mode:
                  return cls(
                      provider="r2r",
                      excluded_parsers=[],
                      chunk_enrichment_settings=ChunkEnrichmentSettings(),
                      extra_parsers={},
                      audio_transcription_model="",
                      skip_document_summary=skip_summary,
                      document_summary_system_prompt="system",
                      document_summary_task_prompt="summary",
                      chunks_for_document_summary=summary_chunks,
                      document_summary_model="",
                  )
          return cls()