# ingestion_service.py

import asyncio
import json
import logging
from datetime import datetime
from typing import Any, AsyncGenerator, Optional, Sequence
from uuid import UUID

from fastapi import HTTPException

from core.base import (
    Document,
    DocumentChunk,
    DocumentResponse,
    DocumentType,
    GenerationConfig,
    IngestionStatus,
    R2RException,
    RawChunk,
    UnprocessedChunk,
    Vector,
    VectorEntry,
    VectorType,
    generate_id,
)
from core.base.abstractions import (
    ChunkEnrichmentSettings,
    IndexMeasure,
    IndexMethod,
    R2RDocumentProcessingError,
    VectorTableName,
)
from core.base.api.models import User
from shared.abstractions import PDFParsingError, PopplerNotFoundError

from ..abstractions import R2RProviders
from ..config import R2RConfig

logger = logging.getLogger()

STARTING_VERSION = "v0"


class IngestionService:
    """A refactored IngestionService that inlines all pipe logic for parsing,
    embedding, and vector storage directly in its methods."""

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ) -> None:
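        """Store the shared R2R configuration and provider bundle used by the
        ingestion methods below."""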
        self.config = config
        self.providers = providers

    async def ingest_file_ingress(
        self,
        file_data: dict,
        user: User,
        document_id: UUID,
        size_in_bytes: int,
        metadata: Optional[dict] = None,
        version: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Pre-ingests a file by creating or validating the DocumentResponse
        entry.

        Does not actually parse/ingest the content. (See parse_file() for that
        step.)
        """
        try:
            if not file_data:
                raise R2RException(
                    status_code=400, message="No files provided for ingestion."
                )
            if not file_data.get("filename"):
                raise R2RException(
                    status_code=400, message="File name not provided."
                )

            metadata = metadata or {}
            version = version or STARTING_VERSION

            document_info = self.create_document_info_from_file(
                document_id,
                user,
                file_data["filename"],
                metadata,
                version,
                size_in_bytes,
            )

            existing_document_info = (
                await self.providers.database.documents_handler.get_documents_overview(
                    offset=0,
                    limit=100,
                    filter_user_ids=[user.id],
                    filter_document_ids=[document_id],
                )
            )["results"]

            # Validate ingestion status for re-ingestion
            if len(existing_document_info) > 0:
                existing_doc = existing_document_info[0]
                if existing_doc.ingestion_status == IngestionStatus.SUCCESS:
                    raise R2RException(
                        status_code=409,
                        message=(
                            f"Document {document_id} already exists. "
                            "Submit a DELETE request to `/documents/{document_id}` "
                            "to delete this document and allow for re-ingestion."
                        ),
                    )
                elif existing_doc.ingestion_status != IngestionStatus.FAILED:
                    raise R2RException(
                        status_code=409,
                        message=(
                            f"Document {document_id} is currently ingesting "
                            f"with status {existing_doc.ingestion_status}."
                        ),
                    )

            # Set to PARSING until we actually parse
            document_info.ingestion_status = IngestionStatus.PARSING
            await self.providers.database.documents_handler.upsert_documents_overview(
                document_info
            )

            return {
                "info": document_info,
            }
        except R2RException as e:
            logger.error(f"R2RException in ingest_file_ingress: {str(e)}")
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error during ingestion: {str(e)}"
            ) from e

    def create_document_info_from_file(
        self,
        document_id: UUID,
        user: User,
        file_name: str,
        metadata: dict,
        version: str,
        size_in_bytes: int,
    ) -> DocumentResponse:
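        """Build the initial DocumentResponse for a file-based ingestion,
        inferring the document type from the file extension and marking the
        document as PENDING."""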
        file_extension = (
            file_name.split(".")[-1].lower() if file_name != "N/A" else "txt"
        )
        if file_extension.upper() not in DocumentType.__members__:
            raise R2RException(
                status_code=415,
                message=f"'{file_extension}' is not a valid DocumentType.",
            )

        metadata = metadata or {}
        metadata["version"] = version

        return DocumentResponse(
            id=document_id,
            owner_id=user.id,
            collection_ids=metadata.get("collection_ids", []),
            document_type=DocumentType[file_extension.upper()],
            title=(
                metadata.get("title", file_name.split("/")[-1])
                if file_name != "N/A"
                else "N/A"
            ),
            metadata=metadata,
            version=version,
            size_in_bytes=size_in_bytes,
            ingestion_status=IngestionStatus.PENDING,
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )

    def _create_document_info_from_chunks(
        self,
        document_id: UUID,
        user: User,
        chunks: list[RawChunk],
        metadata: dict,
        version: str,
    ) -> DocumentResponse:
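        """Build the initial DocumentResponse for a chunk-based ingestion,
        treating the content as plain text and summing the chunks' UTF-8
        sizes."""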
        metadata = metadata or {}
        metadata["version"] = version

        return DocumentResponse(
            id=document_id,
            owner_id=user.id,
            collection_ids=metadata.get("collection_ids", []),
            document_type=DocumentType.TXT,
            title=metadata.get("title", f"Ingested Chunks - {document_id}"),
            metadata=metadata,
            version=version,
            size_in_bytes=sum(
                len(chunk.text.encode("utf-8")) for chunk in chunks
            ),
            ingestion_status=IngestionStatus.PENDING,
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )

    async def parse_file(
        self,
        document_info: DocumentResponse,
        ingestion_config: dict | None,
    ) -> AsyncGenerator[DocumentChunk, None]:
        """Reads the file content from the DB, calls the ingestion provider to
        parse, and yields DocumentChunk objects."""
        version = document_info.version or "v0"
        ingestion_config_override = ingestion_config or {}

        # The ingestion config might specify a different provider, etc.
        override_provider = ingestion_config_override.pop("provider", None)
        if (
            override_provider
            and override_provider != self.providers.ingestion.config.provider
        ):
            raise ValueError(
                f"Provider '{override_provider}' does not match ingestion provider "
                f"'{self.providers.ingestion.config.provider}'."
            )

        try:
            # Pull file from DB
            retrieved = await self.providers.file.retrieve_file(
                document_info.id
            )
            if not retrieved:
                # No file found in the DB, can't parse
                raise R2RDocumentProcessingError(
                    document_id=document_info.id,
                    error_message="No file content found in DB for this document.",
                )

            file_name, file_wrapper, file_size = retrieved

            # Read the content
            with file_wrapper as file_content_stream:
                file_content = file_content_stream.read()

            # Build a barebones Document object
            doc = Document(
                id=document_info.id,
                collection_ids=document_info.collection_ids,
                owner_id=document_info.owner_id,
                metadata={
                    "document_type": document_info.document_type.value,
                    **document_info.metadata,
                },
                document_type=document_info.document_type,
            )

            # Delegate to the ingestion provider to parse
            async for extraction in self.providers.ingestion.parse(
                file_content,  # raw bytes
                doc,
                ingestion_config_override,
            ):
                # Adjust chunk ID to incorporate version
                # or any other needed transformations
                extraction.id = generate_id(f"{extraction.id}_{version}")
                extraction.metadata["version"] = version
                yield extraction
        except (PopplerNotFoundError, PDFParsingError) as e:
            raise R2RDocumentProcessingError(
                error_message=e.message,
                document_id=document_info.id,
                status_code=e.status_code,
            ) from None
        except Exception as e:
            if isinstance(e, R2RException):
                raise
            raise R2RDocumentProcessingError(
                document_id=document_info.id,
                error_message=f"Error parsing document: {str(e)}",
            ) from e

    async def augment_document_info(
        self,
        document_info: DocumentResponse,
        chunked_documents: list[dict],
    ) -> None:
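        """Unless summaries are disabled in the ingestion config, ask the LLM
        for a document summary built from the leading chunks, then store the
        summary and its embedding on document_info."""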
        if not self.config.ingestion.skip_document_summary:
            document = f"Document Title: {document_info.title}\n"
            if document_info.metadata != {}:
                document += f"Document Metadata: {json.dumps(document_info.metadata)}\n"

            document += "Document Text:\n"
            for chunk in chunked_documents[
                : self.config.ingestion.chunks_for_document_summary
            ]:
                document += chunk["data"]

            messages = await self.providers.database.prompts_handler.get_message_payload(
                system_prompt_name=self.config.ingestion.document_summary_system_prompt,
                task_prompt_name=self.config.ingestion.document_summary_task_prompt,
                task_inputs={
                    "document": document[
                        : self.config.ingestion.document_summary_max_length
                    ]
                },
            )

            response = await self.providers.llm.aget_completion(
                messages=messages,
                generation_config=GenerationConfig(
                    model=self.config.ingestion.document_summary_model
                    or self.config.app.fast_llm
                ),
            )

            document_info.summary = response.choices[0].message.content  # type: ignore

            if not document_info.summary:
                raise ValueError("Expected a generated response.")

            embedding = await self.providers.embedding.async_get_embedding(
                text=document_info.summary,
            )
            document_info.summary_embedding = embedding
        return

    async def embed_document(
        self,
        chunked_documents: list[dict],
        embedding_batch_size: int = 8,
    ) -> AsyncGenerator[VectorEntry, None]:
        """Inline replacement for the old embedding_pipe.run(...).

        Batches the embedding calls and yields VectorEntry objects.
        """
        if not chunked_documents:
            return

        concurrency_limit = (
            self.providers.embedding.config.concurrent_request_limit or 5
        )
        extraction_batch: list[DocumentChunk] = []
        tasks: set[asyncio.Task] = set()

        async def process_batch(
            batch: list[DocumentChunk],
        ) -> list[VectorEntry]:
            # All text from the batch
            texts = [
                (
                    ex.data.decode("utf-8")
                    if isinstance(ex.data, bytes)
                    else ex.data
                )
                for ex in batch
            ]
            # Retrieve embeddings in bulk
            vectors = await self.providers.embedding.async_get_embeddings(
                texts,  # list of strings
            )
            # Zip them back together
            results = []
            for raw_vector, extraction in zip(vectors, batch, strict=False):
                results.append(
                    VectorEntry(
                        id=extraction.id,
                        document_id=extraction.document_id,
                        owner_id=extraction.owner_id,
                        collection_ids=extraction.collection_ids,
                        vector=Vector(data=raw_vector, type=VectorType.FIXED),
                        text=(
                            extraction.data.decode("utf-8")
                            if isinstance(extraction.data, bytes)
                            else str(extraction.data)
                        ),
                        metadata={**extraction.metadata},
                    )
                )
            return results

        async def run_process_batch(batch: list[DocumentChunk]):
            return await process_batch(batch)

        # Convert each chunk dict to a DocumentChunk
        for chunk_dict in chunked_documents:
            extraction = DocumentChunk.from_dict(chunk_dict)
            extraction_batch.append(extraction)

            # If we hit a batch threshold, spawn a task
            if len(extraction_batch) >= embedding_batch_size:
                tasks.add(
                    asyncio.create_task(run_process_batch(extraction_batch))
                )
                extraction_batch = []

            # If tasks are at concurrency limit, wait for the first to finish
            while len(tasks) >= concurrency_limit:
                done, tasks = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
                for t in done:
                    for vector_entry in await t:
                        yield vector_entry

        # Handle any leftover items
        if extraction_batch:
            tasks.add(asyncio.create_task(run_process_batch(extraction_batch)))

        # Gather remaining tasks
        for future_task in asyncio.as_completed(tasks):
            for vector_entry in await future_task:
                yield vector_entry

    async def store_embeddings(
        self,
        embeddings: Sequence[dict | VectorEntry],
        storage_batch_size: int = 128,
    ) -> AsyncGenerator[str, None]:
        """Inline replacement for the old vector_storage_pipe.run(...).

        Batches up the vector entries, enforces usage limits, stores them, and
        yields a success/error string (or you could yield a StorageResult).
        """
        if not embeddings:
            return

        vector_entries: list[VectorEntry] = []
        for item in embeddings:
            if isinstance(item, VectorEntry):
                vector_entries.append(item)
            else:
                vector_entries.append(VectorEntry.from_dict(item))

        vector_batch: list[VectorEntry] = []
        document_counts: dict[UUID, int] = {}

        # We'll track usage from the first user we see; if your scenario allows
        # multiple user owners in a single ingestion, you'd need to refine usage checks.
        current_usage = None
        user_id_for_usage_check: UUID | None = None
        count = 0

        for msg in vector_entries:
            # If we haven't set usage yet, do so on the first chunk
            if current_usage is None:
                user_id_for_usage_check = msg.owner_id
                usage_data = (
                    await self.providers.database.chunks_handler.list_chunks(
                        limit=1,
                        offset=0,
                        filters={"owner_id": msg.owner_id},
                    )
                )
                current_usage = usage_data["total_entries"]

            # Figure out the user's limit
            user = await self.providers.database.users_handler.get_user_by_id(
                msg.owner_id
            )
            max_chunks = (
                self.providers.database.config.app.default_max_chunks_per_user
                if self.providers.database.config.app
                else 1e10
            )
            if user.limits_overrides and "max_chunks" in user.limits_overrides:
                max_chunks = user.limits_overrides["max_chunks"]

            # Add to our local batch
            vector_batch.append(msg)
            document_counts[msg.document_id] = (
                document_counts.get(msg.document_id, 0) + 1
            )
            count += 1

            # Check usage
            if (
                current_usage is not None
                and (current_usage + len(vector_batch) + count) > max_chunks
            ):
                error_message = f"User {msg.owner_id} has exceeded the maximum number of allowed chunks: {max_chunks}"
                logger.error(error_message)
                yield error_message
                continue

            # Once we hit our batch size, store them
            if len(vector_batch) >= storage_batch_size:
                try:
                    await (
                        self.providers.database.chunks_handler.upsert_entries(
                            vector_batch
                        )
                    )
                except Exception as e:
                    logger.error(f"Failed to store vector batch: {e}")
                    yield f"Error: {e}"
                vector_batch.clear()

        # Store any leftover items
        if vector_batch:
            try:
                await self.providers.database.chunks_handler.upsert_entries(
                    vector_batch
                )
            except Exception as e:
                logger.error(f"Failed to store final vector batch: {e}")
                yield f"Error: {e}"

        # Summaries
        for doc_id, cnt in document_counts.items():
            info_msg = f"Successful ingestion for document_id: {doc_id}, with vector count: {cnt}"
            logger.info(info_msg)
            yield info_msg

    async def finalize_ingestion(
        self, document_info: DocumentResponse
    ) -> None:
        """Called at the end of a successful ingestion pipeline to set the
        document status to SUCCESS or similar final steps."""

        async def empty_generator():
            yield document_info

        await self.update_document_status(
            document_info, IngestionStatus.SUCCESS
        )
        return empty_generator()

    async def update_document_status(
        self,
        document_info: DocumentResponse,
        status: IngestionStatus,
        metadata: Optional[dict] = None,
    ) -> None:
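        """Update the in-memory ingestion status (optionally merging extra
        metadata) and persist the change to the documents overview table."""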
        document_info.ingestion_status = status
        if metadata:
            document_info.metadata = {**document_info.metadata, **metadata}
        await self._update_document_status_in_db(document_info)

    async def _update_document_status_in_db(
        self, document_info: DocumentResponse
    ):
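        """Persist the document overview row, logging rather than raising on
        failure so a status update cannot abort the calling flow."""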
        try:
            await self.providers.database.documents_handler.upsert_documents_overview(
                document_info
            )
        except Exception as e:
            logger.error(
                f"Failed to update document status: {document_info.id}. Error: {str(e)}"
            )

    async def ingest_chunks_ingress(
        self,
        document_id: UUID,
        metadata: Optional[dict],
        chunks: list[RawChunk],
        user: User,
        *args: Any,
        **kwargs: Any,
    ) -> DocumentResponse:
        """Directly ingest user-provided text chunks (rather than from a
        file)."""
        if not chunks:
            raise R2RException(
                status_code=400, message="No chunks provided for ingestion."
            )
        metadata = metadata or {}
        version = STARTING_VERSION

        document_info = self._create_document_info_from_chunks(
            document_id,
            user,
            chunks,
            metadata,
            version,
        )

        existing_document_info = (
            await self.providers.database.documents_handler.get_documents_overview(
                offset=0,
                limit=100,
                filter_user_ids=[user.id],
                filter_document_ids=[document_id],
            )
        )["results"]
        if len(existing_document_info) > 0:
            existing_doc = existing_document_info[0]
            if existing_doc.ingestion_status != IngestionStatus.FAILED:
                raise R2RException(
                    status_code=409,
                    message=(
                        f"Document {document_id} was already ingested "
                        "and is not in a failed state."
                    ),
                )

        await self.providers.database.documents_handler.upsert_documents_overview(
            document_info
        )
        return document_info

    async def update_chunk_ingress(
        self,
        document_id: UUID,
        chunk_id: UUID,
        text: str,
        user: User,
        metadata: Optional[dict] = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Update an individual chunk's text and metadata, re-embed, and
        re-store it."""
        # Verify chunk exists and user has access
        existing_chunks = (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=0,
                limit=1,
            )
        )
        if not existing_chunks["results"]:
            raise R2RException(
                status_code=404,
                message=f"Chunk with chunk_id {chunk_id} not found.",
            )

        existing_chunk = (
            await self.providers.database.chunks_handler.get_chunk(chunk_id)
        )
        if not existing_chunk:
            raise R2RException(
                status_code=404,
                message=f"Chunk with id {chunk_id} not found",
            )

        if (
            str(existing_chunk["owner_id"]) != str(user.id)
            and not user.is_superuser
        ):
            raise R2RException(
                status_code=403,
                message="You don't have permission to modify this chunk.",
            )

        # Merge metadata
        merged_metadata = {**existing_chunk["metadata"]}
        if metadata is not None:
            merged_metadata |= metadata

        # Create updated chunk
        extraction_data = {
            "id": chunk_id,
            "document_id": document_id,
            "collection_ids": kwargs.get(
                "collection_ids", existing_chunk["collection_ids"]
            ),
            "owner_id": existing_chunk["owner_id"],
            "data": text or existing_chunk["text"],
            "metadata": merged_metadata,
        }
        extraction = DocumentChunk(**extraction_data).model_dump()

        # Re-embed
        embeddings_generator = self.embed_document(
            [extraction], embedding_batch_size=1
        )
        embeddings = []
        async for embedding in embeddings_generator:
            embeddings.append(embedding)

        # Re-store
        store_gen = self.store_embeddings(embeddings, storage_batch_size=1)
        async for _ in store_gen:
            pass

        return extraction

    async def _get_enriched_chunk_text(
        self,
        chunk_idx: int,
        chunk: dict,
        document_id: UUID,
        document_summary: str | None,
        chunk_enrichment_settings: ChunkEnrichmentSettings,
        list_document_chunks: list[dict],
    ) -> VectorEntry:
        """Helper for chunk_enrichment.

        Leverages an LLM to rewrite or expand chunk text, then re-embeds it.
        """
        preceding_chunks = [
            list_document_chunks[idx]["text"]
            for idx in range(
                max(0, chunk_idx - chunk_enrichment_settings.n_chunks),
                chunk_idx,
            )
        ]
        succeeding_chunks = [
            list_document_chunks[idx]["text"]
            for idx in range(
                chunk_idx + 1,
                min(
                    len(list_document_chunks),
                    chunk_idx + chunk_enrichment_settings.n_chunks + 1,
                ),
            )
        ]
        try:
            # Obtain the updated text from the LLM
            updated_chunk_text = (
                (
                    await self.providers.llm.aget_completion(
                        messages=await self.providers.database.prompts_handler.get_message_payload(
                            task_prompt_name=chunk_enrichment_settings.chunk_enrichment_prompt,
                            task_inputs={
                                "document_summary": document_summary or "None",
                                "chunk": chunk["text"],
                                "preceding_chunks": (
                                    "\n".join(preceding_chunks)
                                    if preceding_chunks
                                    else "None"
                                ),
                                "succeeding_chunks": (
                                    "\n".join(succeeding_chunks)
                                    if succeeding_chunks
                                    else "None"
                                ),
                                "chunk_size": self.config.ingestion.chunk_size
                                or 1024,
                            },
                        ),
                        generation_config=chunk_enrichment_settings.generation_config
                        or GenerationConfig(model=self.config.app.fast_llm),
                    )
                )
                .choices[0]
                .message.content
            )
        except Exception:
            updated_chunk_text = chunk["text"]
            chunk["metadata"]["chunk_enrichment_status"] = "failed"
        else:
            chunk["metadata"]["chunk_enrichment_status"] = (
                "success" if updated_chunk_text else "failed"
            )

        if not updated_chunk_text or not isinstance(updated_chunk_text, str):
            updated_chunk_text = str(chunk["text"])
            chunk["metadata"]["chunk_enrichment_status"] = "failed"

        # Re-embed
        data = await self.providers.embedding.async_get_embedding(
            updated_chunk_text
        )
        chunk["metadata"]["original_text"] = chunk["text"]

        return VectorEntry(
            id=generate_id(str(chunk["id"])),
            vector=Vector(data=data, type=VectorType.FIXED, length=len(data)),
            document_id=document_id,
            owner_id=chunk["owner_id"],
            collection_ids=chunk["collection_ids"],
            text=updated_chunk_text,
            metadata=chunk["metadata"],
        )

    async def chunk_enrichment(
        self,
        document_id: UUID,
        document_summary: str | None,
        chunk_enrichment_settings: ChunkEnrichmentSettings,
    ) -> int:
        """Example function that modifies chunk text via an LLM then re-embeds
        and re-stores all chunks for the given document."""
        list_document_chunks = (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=0,
                limit=-1,
            )
        )["results"]

        new_vector_entries: list[VectorEntry] = []
        tasks = []
        total_completed = 0

        for chunk_idx, chunk in enumerate(list_document_chunks):
            tasks.append(
                self._get_enriched_chunk_text(
                    chunk_idx=chunk_idx,
                    chunk=chunk,
                    document_id=document_id,
                    document_summary=document_summary,
                    chunk_enrichment_settings=chunk_enrichment_settings,
                    list_document_chunks=list_document_chunks,
                )
            )

            # Process in batches of e.g. 128 concurrency
            if len(tasks) == 128:
                new_vector_entries.extend(await asyncio.gather(*tasks))
                total_completed += 128
                logger.info(
                    f"Completed {total_completed} out of {len(list_document_chunks)} chunks for document {document_id}"
                )
                tasks = []

        # Finish any remaining tasks
        new_vector_entries.extend(await asyncio.gather(*tasks))
        logger.info(
            f"Completed enrichment of {len(list_document_chunks)} chunks for document {document_id}"
        )

        # Delete old chunks from vector db
        await self.providers.database.chunks_handler.delete(
            filters={"document_id": document_id}
        )

        # Insert the newly enriched entries
        await self.providers.database.chunks_handler.upsert_entries(
            new_vector_entries
        )
        return len(new_vector_entries)

    async def list_chunks(
        self,
        offset: int,
        limit: int,
        filters: Optional[dict[str, Any]] = None,
        include_vectors: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
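        """Thin pass-through to the database chunks handler's list_chunks."""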
        return await self.providers.database.chunks_handler.list_chunks(
            offset=offset,
            limit=limit,
            filters=filters,
            include_vectors=include_vectors,
        )

    async def get_chunk(
        self,
        chunk_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
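        """Thin pass-through to the database chunks handler's get_chunk."""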
        return await self.providers.database.chunks_handler.get_chunk(chunk_id)


class IngestionServiceAdapter:
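    """Static helpers that normalize raw request payloads (dicts, sometimes
    JSON-encoded) into the keyword arguments expected by IngestionService."""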

    @staticmethod
    def _parse_user_data(user_data) -> User:
        if isinstance(user_data, str):
            try:
                user_data = json.loads(user_data)
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Invalid user data format: {user_data}"
                ) from e
        return User.from_dict(user_data)

    @staticmethod
    def parse_ingest_file_input(data: dict) -> dict:
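        """Convert a raw ingest-file payload into keyword arguments for file
        ingestion, parsing the user and coercing document_id to a UUID."""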
        return {
            "user": IngestionServiceAdapter._parse_user_data(data["user"]),
            "metadata": data["metadata"],
            "document_id": (
                UUID(data["document_id"]) if data["document_id"] else None
            ),
            "version": data.get("version"),
            "ingestion_config": data["ingestion_config"] or {},
            "file_data": data["file_data"],
            "size_in_bytes": data["size_in_bytes"],
            "collection_ids": data.get("collection_ids", []),
        }

    @staticmethod
    def parse_ingest_chunks_input(data: dict) -> dict:
        return {
            "user": IngestionServiceAdapter._parse_user_data(data["user"]),
            "metadata": data["metadata"],
            "document_id": data["document_id"],
            "chunks": [
                UnprocessedChunk.from_dict(chunk) for chunk in data["chunks"]
            ],
            "id": data.get("id"),
            "collection_ids": data.get("collection_ids", []),
        }

    @staticmethod
    def parse_update_chunk_input(data: dict) -> dict:
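        """Convert a raw chunk-update payload into keyword arguments for
        update_chunk_ingress, coercing document_id and id to UUIDs."""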
        return {
            "user": IngestionServiceAdapter._parse_user_data(data["user"]),
            "document_id": UUID(data["document_id"]),
            "id": UUID(data["id"]),
            "text": data["text"],
            "metadata": data.get("metadata"),
            "collection_ids": data.get("collection_ids", []),
        }

    @staticmethod
    def parse_create_vector_index_input(data: dict) -> dict:
        return {
            "table_name": VectorTableName(data["table_name"]),
            "index_method": IndexMethod(data["index_method"]),
            "index_measure": IndexMeasure(data["index_measure"]),
            "index_name": data["index_name"],
            "index_column": data["index_column"],
            "index_arguments": data["index_arguments"],
            "concurrently": data["concurrently"],
        }

    @staticmethod
    def parse_list_vector_indices_input(input_data: dict) -> dict:
        return {"table_name": input_data["table_name"]}

    @staticmethod
    def parse_delete_vector_index_input(input_data: dict) -> dict:
        return {
            "index_name": input_data["index_name"],
            "table_name": input_data.get("table_name"),
            "concurrently": input_data.get("concurrently", True),
        }

    @staticmethod
    def parse_select_vector_index_input(input_data: dict) -> dict:
        return {
            "index_name": input_data["index_name"],
            "table_name": input_data.get("table_name"),
        }
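

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the service API): one plausible way the
# IngestionService methods above can be chained for a single file. The
# `config`, `providers`, `file_data`, `user`, `document_id`, and
# `size_in_bytes` values are assumed to be constructed by the caller; error
# handling and failure-status updates are intentionally omitted.
# ---------------------------------------------------------------------------
async def _example_ingest_one_file(
    config: R2RConfig,
    providers: R2RProviders,
    file_data: dict,
    user: User,
    document_id: UUID,
    size_in_bytes: int,
) -> None:
    service = IngestionService(config=config, providers=providers)

    # 1. Register (or validate) the document row and mark it PARSING.
    ingress = await service.ingest_file_ingress(
        file_data=file_data,
        user=user,
        document_id=document_id,
        size_in_bytes=size_in_bytes,
    )
    document_info = ingress["info"]

    # 2. Parse the stored file into chunk dicts (DocumentChunk.model_dump()).
    chunks = [
        chunk.model_dump()
        async for chunk in service.parse_file(document_info, ingestion_config=None)
    ]

    # 3. Optionally attach an LLM summary and its embedding to the document.
    await service.augment_document_info(document_info, chunks)

    # 4. Embed the chunks and store the resulting vector entries.
    embeddings = [entry async for entry in service.embed_document(chunks)]
    async for message in service.store_embeddings(embeddings):
        logger.debug(message)

    # 5. Mark the document as successfully ingested.
    await service.finalize_ingestion(document_info)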