# chunks.py

import copy
import json
import logging
import time
import uuid
from typing import Any, Optional, TypedDict
from uuid import UUID

import numpy as np

from core.base import (
    ChunkSearchResult,
    Handler,
    IndexArgsHNSW,
    IndexArgsIVFFlat,
    IndexMeasure,
    IndexMethod,
    R2RException,
    SearchSettings,
    VectorEntry,
    VectorQuantizationType,
    VectorTableName,
)
from core.base.utils import _decorate_vector_type

from .base import PostgresConnectionManager
from .vecs.exc import ArgError, FilterError

logger = logging.getLogger()


def psql_quote_literal(value: str) -> str:
    """
    Safely quote a string literal for PostgreSQL to prevent SQL injection.
    This is a simple implementation - in production, you should use proper parameterization
    or your database driver's quoting functions.
    """
    return "'" + value.replace("'", "''") + "'"
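
# For example (illustrative only), psql_quote_literal("O'Reilly") returns the
# string "'O''Reilly'", which can be embedded safely in a generated SQL snippet.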


def index_measure_to_ops(
    measure: IndexMeasure,
    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
):
    return _decorate_vector_type(measure.ops, quantization_type)


def quantize_vector_to_binary(
    vector: list[float] | np.ndarray,
    threshold: float = 0.0,
) -> bytes:
    """
    Quantizes a float vector to a binary vector string for PostgreSQL bit type.
    Used when quantization_type is INT1.

    Args:
        vector (list[float] | np.ndarray): Input vector of floats
        threshold (float, optional): Threshold for binarization. Defaults to 0.0.

    Returns:
        bytes: Binary string representation for the PostgreSQL bit type
    """
    # Convert input to a numpy array if it isn't already
    if not isinstance(vector, np.ndarray):
        vector = np.array(vector)

    # Convert to binary (1 where value > threshold, 0 otherwise)
    binary_vector = (vector > threshold).astype(int)

    # Convert to a string of 1s and 0s, then to ASCII bytes
    binary_string = "".join(map(str, binary_vector))
    return binary_string.encode("ascii")
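
# Worked example (illustrative only): with the default threshold of 0.0,
# quantize_vector_to_binary([0.5, -0.2, 0.0, 1.3]) yields b"1001", which can be
# bound to a bit(4) column or cast with ::bit(4) in a query.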


class HybridSearchIntermediateResult(TypedDict):
    semantic_rank: int
    full_text_rank: int
    data: ChunkSearchResult
    rrf_score: float


class PostgresChunksHandler(Handler):
    TABLE_NAME = VectorTableName.CHUNKS

    COLUMN_VARS = [
        "id",
        "document_id",
        "owner_id",
        "collection_ids",
    ]

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        dimension: int,
        quantization_type: VectorQuantizationType,
    ):
        super().__init__(project_name, connection_manager)
        self.dimension = dimension
        self.quantization_type = quantization_type

    async def create_tables(self):
        # Check for old table name first
        check_query = """
        SELECT EXISTS (
            SELECT FROM pg_tables
            WHERE schemaname = $1
            AND tablename = $2
        );
        """
        old_table_exists = await self.connection_manager.fetch_query(
            check_query, (self.project_name, self.project_name)
        )

        if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
            raise ValueError(
                f"Found old vector table '{self.project_name}.{self.project_name}'. "
                "Please run `r2r db upgrade` with the CLI, or to run manually, "
                "run in R2R/py/migrations with 'alembic upgrade head' to update "
                "your database schema to the new version."
            )

        binary_col = (
            ""
            if self.quantization_type != VectorQuantizationType.INT1
            else f"vec_binary bit({self.dimension}),"
        )

        query = f"""
        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
            id UUID PRIMARY KEY,
            document_id UUID,
            owner_id UUID,
            collection_ids UUID[],
            vec vector({self.dimension}),
            {binary_col}
            text TEXT,
            metadata JSONB,
            fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
        );
        CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
        CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
        CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
        CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
        """
        await self.connection_manager.execute_query(query)
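
    # Resulting schema sketch (illustrative, e.g. for dimension=1536 with INT1
    # quantization): a chunks table with id/document_id/owner_id UUIDs, a
    # collection_ids UUID[] array, vec vector(1536), vec_binary bit(1536),
    # text, JSONB metadata, and a stored English tsvector column (fts), plus
    # B-tree indexes on the id columns and GIN indexes for collections and
    # full-text search.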

    async def upsert(self, entry: VectorEntry) -> None:
        """
        Upsert function that handles vector quantization only when quantization_type is INT1.
        Matches the table schema where vec_binary column only exists for INT1 quantization.
        """
        # Check the quantization type to determine which columns to use
        if self.quantization_type == VectorQuantizationType.INT1:
            # For quantized vectors, use vec_binary column
            query = f"""
            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            vec_binary = EXCLUDED.vec_binary,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
            await self.connection_manager.execute_query(
                query,
                (
                    entry.id,
                    entry.document_id,
                    entry.owner_id,
                    entry.collection_ids,
                    str(entry.vector.data),
                    quantize_vector_to_binary(
                        entry.vector.data
                    ),  # Convert to binary
                    entry.text,
                    json.dumps(entry.metadata),
                ),
            )
        else:
            # For regular vectors, use vec column only
            query = f"""
            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            (id, document_id, owner_id, collection_ids, vec, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
            await self.connection_manager.execute_query(
                query,
                (
                    entry.id,
                    entry.document_id,
                    entry.owner_id,
                    entry.collection_ids,
                    str(entry.vector.data),
                    entry.text,
                    json.dumps(entry.metadata),
                ),
            )
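
    # Usage sketch (illustrative; assumes an initialized handler and a VectorEntry
    # whose fields match how upsert reads them: id, document_id, owner_id,
    # collection_ids, vector.data, text, and metadata):
    #
    #     await handler.upsert(entry)
    #
    # For INT1 quantization both the float vector and its binarized form are
    # written; otherwise only the vec column is populated.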

    async def upsert_entries(self, entries: list[VectorEntry]) -> None:
        """
        Batch upsert function that handles vector quantization only when quantization_type is INT1.
        Matches the table schema where vec_binary column only exists for INT1 quantization.
        """
        if self.quantization_type == VectorQuantizationType.INT1:
            # For quantized vectors, use vec_binary column
            query = f"""
            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            vec_binary = EXCLUDED.vec_binary,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
            bin_params = [
                (
                    entry.id,
                    entry.document_id,
                    entry.owner_id,
                    entry.collection_ids,
                    str(entry.vector.data),
                    quantize_vector_to_binary(
                        entry.vector.data
                    ),  # Convert to binary
                    entry.text,
                    json.dumps(entry.metadata),
                )
                for entry in entries
            ]
            await self.connection_manager.execute_many(query, bin_params)
        else:
            # For regular vectors, use vec column only
            query = f"""
            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            (id, document_id, owner_id, collection_ids, vec, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
            params = [
                (
                    entry.id,
                    entry.document_id,
                    entry.owner_id,
                    entry.collection_ids,
                    str(entry.vector.data),
                    entry.text,
                    json.dumps(entry.metadata),
                )
                for entry in entries
            ]
            await self.connection_manager.execute_many(query, params)

    async def semantic_search(
        self, query_vector: list[float], search_settings: SearchSettings
    ) -> list[ChunkSearchResult]:
        try:
            imeasure_obj = IndexMeasure(
                search_settings.chunk_settings.index_measure
            )
        except ValueError:
            raise ValueError("Invalid index measure")

        table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        cols = [
            f"{table_name}.id",
            f"{table_name}.document_id",
            f"{table_name}.owner_id",
            f"{table_name}.collection_ids",
            f"{table_name}.text",
        ]

        params: list[str | int | bytes] = []

        # For binary vectors (INT1), implement two-stage search
        if self.quantization_type == VectorQuantizationType.INT1:
            # Convert query vector to binary format
            binary_query = quantize_vector_to_binary(query_vector)
            # TODO - Put depth multiplier in config / settings
            extended_limit = (
                search_settings.limit * 20
            )  # Get 20x candidates for re-ranking

            if (
                imeasure_obj == IndexMeasure.hamming_distance
                or imeasure_obj == IndexMeasure.jaccard_distance
            ):
                binary_search_measure_repr = imeasure_obj.pgvector_repr
            else:
                binary_search_measure_repr = (
                    IndexMeasure.hamming_distance.pgvector_repr
                )

            # Use binary column and binary-specific distance measures for first stage
            stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit({self.dimension})"
            stage1_param = binary_query

            cols.append(
                f"{table_name}.vec"
            )  # Need original vector for re-ranking
            if search_settings.include_metadatas:
                cols.append(f"{table_name}.metadata")

            select_clause = ", ".join(cols)
            where_clause = ""
            params.append(stage1_param)
            if search_settings.filters:
                where_clause = self._build_filters(
                    search_settings.filters, params
                )
                where_clause = f"WHERE {where_clause}"

            # First stage: Get candidates using binary search
            query = f"""
            WITH candidates AS (
                SELECT {select_clause},
                    ({stage1_distance}) as binary_distance
                FROM {table_name}
                {where_clause}
                ORDER BY {stage1_distance}
                LIMIT ${len(params) + 1}
                OFFSET ${len(params) + 2}
            )
            -- Second stage: Re-rank using original vectors
            SELECT
                id,
                document_id,
                owner_id,
                collection_ids,
                text,
                {"metadata," if search_settings.include_metadatas else ""}
                (vec <=> ${len(params) + 4}::vector({self.dimension})) as distance
            FROM candidates
            ORDER BY distance
            LIMIT ${len(params) + 3}
            """

            params.extend(
                [
                    extended_limit,  # First stage limit
                    search_settings.offset,
                    search_settings.limit,  # Final limit
                    str(query_vector),  # For re-ranking
                ]
            )
        else:
            # Standard float vector handling - unchanged from original
            distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector({self.dimension})"
            query_param = str(query_vector)

            if search_settings.include_scores:
                cols.append(f"({distance_calc}) AS distance")
            if search_settings.include_metadatas:
                cols.append(f"{table_name}.metadata")

            select_clause = ", ".join(cols)
            where_clause = ""
            params.append(query_param)
            if search_settings.filters:
                where_clause = self._build_filters(
                    search_settings.filters, params
                )
                where_clause = f"WHERE {where_clause}"

            query = f"""
            SELECT {select_clause}
            FROM {table_name}
            {where_clause}
            ORDER BY {distance_calc}
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
            """
            params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            ChunkSearchResult(
                id=UUID(str(result["id"])),
                document_id=UUID(str(result["document_id"])),
                owner_id=UUID(str(result["owner_id"])),
                collection_ids=result["collection_ids"],
                text=result["text"],
                score=(
                    (1 - float(result["distance"]))
                    if "distance" in result
                    else -1
                ),
                metadata=(
                    json.loads(result["metadata"])
                    if search_settings.include_metadatas
                    else {}
                ),
            )
            for result in results
        ]
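
    # Usage sketch (illustrative; assumes an initialized handler and that
    # SearchSettings exposes limit/offset the way this method reads them):
    #
    #     results = await handler.semantic_search(
    #         query_vector=[0.1, 0.2, ...],  # length must equal self.dimension
    #         search_settings=SearchSettings(limit=10, offset=0),
    #     )
    #
    # With INT1 quantization the first stage scans vec_binary with a Hamming
    # distance over limit * 20 candidates and the second stage re-ranks them on
    # the float vec column; otherwise a single pgvector distance query is issued.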

    async def full_text_search(
        self, query_text: str, search_settings: SearchSettings
    ) -> list[ChunkSearchResult]:
        where_clauses = []
        params: list[str | int | bytes] = [query_text]

        if search_settings.filters:
            filters_clause = self._build_filters(
                search_settings.filters, params
            )
            where_clauses.append(filters_clause)

        if where_clauses:
            where_clause = (
                "WHERE "
                + " AND ".join(where_clauses)
                + " AND fts @@ websearch_to_tsquery('english', $1)"
            )
        else:
            where_clause = "WHERE fts @@ websearch_to_tsquery('english', $1)"

        query = f"""
        SELECT
            id, document_id, owner_id, collection_ids, text, metadata,
            ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        {where_clause}
        """

        query += f"""
        ORDER BY rank DESC
        OFFSET ${len(params) + 1} LIMIT ${len(params) + 2}
        """
        params.extend(
            [
                search_settings.offset,
                search_settings.hybrid_settings.full_text_limit,
            ]
        )

        results = await self.connection_manager.fetch_query(query, params)
        return [
            ChunkSearchResult(
                id=UUID(str(r["id"])),
                document_id=UUID(str(r["document_id"])),
                owner_id=UUID(str(r["owner_id"])),
                collection_ids=r["collection_ids"],
                text=r["text"],
                score=float(r["rank"]),
                metadata=json.loads(r["metadata"]),
            )
            for r in results
        ]

    async def hybrid_search(
        self,
        query_text: str,
        query_vector: list[float],
        search_settings: SearchSettings,
        *args,
        **kwargs,
    ) -> list[ChunkSearchResult]:
        if search_settings.hybrid_settings is None:
            raise ValueError(
                "Please provide a valid `hybrid_settings` in the `search_settings`."
            )
        if (
            search_settings.hybrid_settings.full_text_limit
            < search_settings.limit
        ):
            raise ValueError(
                "The `full_text_limit` must be greater than or equal to the `limit`."
            )

        semantic_settings = copy.deepcopy(search_settings)
        semantic_settings.limit += search_settings.offset

        full_text_settings = copy.deepcopy(search_settings)
        full_text_settings.hybrid_settings.full_text_limit += (
            search_settings.offset
        )

        semantic_results: list[ChunkSearchResult] = await self.semantic_search(
            query_vector, semantic_settings
        )
        full_text_results: list[ChunkSearchResult] = (
            await self.full_text_search(query_text, full_text_settings)
        )

        semantic_limit = search_settings.limit
        full_text_limit = search_settings.hybrid_settings.full_text_limit
        semantic_weight = search_settings.hybrid_settings.semantic_weight
        full_text_weight = search_settings.hybrid_settings.full_text_weight
        rrf_k = search_settings.hybrid_settings.rrf_k

        combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}

        for rank, result in enumerate(semantic_results, 1):
            combined_results[result.id] = {
                "semantic_rank": rank,
                "full_text_rank": full_text_limit,
                "data": result,
                "rrf_score": 0.0,  # Initialize with 0, will be calculated later
            }

        for rank, result in enumerate(full_text_results, 1):
            if result.id in combined_results:
                combined_results[result.id]["full_text_rank"] = rank
            else:
                combined_results[result.id] = {
                    "semantic_rank": semantic_limit,
                    "full_text_rank": rank,
                    "data": result,
                    "rrf_score": 0.0,  # Initialize with 0, will be calculated later
                }

        combined_results = {
            k: v
            for k, v in combined_results.items()
            if v["semantic_rank"] <= semantic_limit * 2
            and v["full_text_rank"] <= full_text_limit * 2
        }

        for hyb_result in combined_results.values():
            semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
            full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
            hyb_result["rrf_score"] = (
                semantic_score * semantic_weight
                + full_text_score * full_text_weight
            ) / (semantic_weight + full_text_weight)

        sorted_results = sorted(
            combined_results.values(),
            key=lambda x: x["rrf_score"],
            reverse=True,
        )
        offset_results = sorted_results[
            search_settings.offset : search_settings.offset
            + search_settings.limit
        ]

        return [
            ChunkSearchResult(
                id=result["data"].id,
                document_id=result["data"].document_id,
                owner_id=result["data"].owner_id,
                collection_ids=result["data"].collection_ids,
                text=result["data"].text,
                score=result["rrf_score"],
                metadata={
                    **result["data"].metadata,
                    "semantic_rank": result["semantic_rank"],
                    "full_text_rank": result["full_text_rank"],
                },
            )
            for result in offset_results
        ]
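
    # Weighted Reciprocal Rank Fusion, as computed above (illustrative numbers
    # only): with rrf_k = 50, a chunk ranked 1st semantically and 3rd in full
    # text scores
    #
    #     (1 / (50 + 1) * semantic_weight + 1 / (50 + 3) * full_text_weight)
    #         / (semantic_weight + full_text_weight)
    #
    # so with semantic_weight = 5.0 and full_text_weight = 1.0 the fused score
    # is roughly (0.0196 * 5 + 0.0189) / 6 ≈ 0.0195. Results missing from one
    # result set are penalized by using that set's limit as their rank.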

    async def delete(
        self, filters: dict[str, Any]
    ) -> dict[str, dict[str, str]]:
        params: list[str | int | bytes] = []
        where_clause = self._build_filters(filters, params)

        query = f"""
        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE {where_clause}
        RETURNING id, document_id, text;
        """
        results = await self.connection_manager.fetch_query(query, params)

        return {
            str(result["id"]): {
                "status": "deleted",
                "id": str(result["id"]),
                "document_id": str(result["document_id"]),
                "text": result["text"],
            }
            for result in results
        }
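
    # Usage sketch (illustrative):
    #
    #     deleted = await handler.delete({"document_id": {"$eq": str(doc_id)}})
    #
    # returns a mapping keyed by chunk id, e.g.
    #     {"<chunk-id>": {"status": "deleted", "id": ..., "document_id": ..., "text": ...}}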

    async def assign_document_chunks_to_collection(
        self, document_id: UUID, collection_id: UUID
    ) -> None:
        query = f"""
        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        SET collection_ids = array_append(collection_ids, $1)
        WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
        """
        result = await self.connection_manager.execute_query(
            query, (str(collection_id), str(document_id))
        )
        return result

    async def remove_document_from_collection_vector(
        self, document_id: UUID, collection_id: UUID
    ) -> None:
        query = f"""
        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        SET collection_ids = array_remove(collection_ids, $1)
        WHERE document_id = $2;
        """
        await self.connection_manager.execute_query(
            query, (collection_id, document_id)
        )

    async def delete_user_vector(self, owner_id: UUID) -> None:
        query = f"""
        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE owner_id = $1;
        """
        await self.connection_manager.execute_query(query, (owner_id,))

    async def delete_collection_vector(self, collection_id: UUID) -> None:
        query = f"""
        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE $1 = ANY(collection_ids)
        RETURNING collection_ids
        """
        results = await self.connection_manager.fetchrow_query(
            query, (collection_id,)
        )
        return None

    async def list_document_chunks(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        include_vectors: bool = False,
    ) -> dict[str, Any]:
        vector_select = ", vec" if include_vectors else ""
        limit_clause = f"LIMIT {limit}" if limit > -1 else ""

        query = f"""
        SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE document_id = $1
        ORDER BY (metadata->>'chunk_order')::integer
        OFFSET $2
        {limit_clause};
        """

        params = [document_id, offset]

        results = await self.connection_manager.fetch_query(query, params)

        chunks = []
        total = 0
        if results:
            total = results[0].get("total", 0)
            chunks = [
                {
                    "id": result["id"],
                    "document_id": result["document_id"],
                    "owner_id": result["owner_id"],
                    "collection_ids": result["collection_ids"],
                    "text": result["text"],
                    "metadata": json.loads(result["metadata"]),
                    "vector": (
                        json.loads(result["vec"]) if include_vectors else None
                    ),
                }
                for result in results
            ]

        return {"results": chunks, "total_entries": total}

    async def get_chunk(self, id: UUID) -> dict:
        query = f"""
        SELECT id, document_id, owner_id, collection_ids, text, metadata
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE id = $1;
        """

        result = await self.connection_manager.fetchrow_query(query, (id,))

        if result:
            return {
                "id": result["id"],
                "document_id": result["document_id"],
                "owner_id": result["owner_id"],
                "collection_ids": result["collection_ids"],
                "text": result["text"],
                "metadata": json.loads(result["metadata"]),
            }
        raise R2RException(
            message=f"Chunk with ID {id} not found", status_code=404
        )

    async def create_index(
        self,
        table_name: Optional[VectorTableName] = None,
        index_measure: IndexMeasure = IndexMeasure.cosine_distance,
        index_method: IndexMethod = IndexMethod.auto,
        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
        index_name: Optional[str] = None,
        index_column: Optional[str] = None,
        concurrently: bool = True,
    ) -> None:
        """
        Creates an index for the collection.

        Note:
            When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a
            multi-step process that enables performant indexes to be built for large
            collections with low-end database hardware.

            Those steps are:
            - Creates a new table with a different name
            - Randomly selects records from the existing table
            - Inserts the random records from the existing table into the new table
            - Creates the requested vector index on the new table
            - Upserts all data from the existing table into the new table
            - Drops the existing table
            - Renames the new table to the existing table's name

            If you create dependencies (like views) on the table that underpins
            a `vecs.Collection`, the `create_index` step may require you to drop those
            dependencies before it will succeed.

        Args:
            table_name (VectorTableName, optional): The table to index. Defaults to None.
            index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
            index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
            index_arguments (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index-type-specific arguments.
            index_name (str, optional): The name of the index to create. Defaults to None.
            index_column (str, optional): The column to index. Defaults to None.
            concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.

        Raises:
            ArgError: If an invalid table name or index method is used, or if index
                build arguments are supplied for a mismatched index method.
        """
        if table_name == VectorTableName.CHUNKS:
            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"  # TODO - Fix bug in vector table naming convention
            if index_column:
                col_name = index_column
            else:
                col_name = (
                    "vec"
                    if (
                        index_measure != IndexMeasure.hamming_distance
                        and index_measure != IndexMeasure.jaccard_distance
                    )
                    else "vec_binary"
                )
        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.GRAPHS_ENTITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.COMMUNITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
            )
            col_name = "embedding"
        else:
            raise ArgError("invalid table name")

        if index_method not in (
            IndexMethod.ivfflat,
            IndexMethod.hnsw,
            IndexMethod.auto,
        ):
            raise ArgError("invalid index method")

        if index_arguments:
            # Disallow case where user submits index arguments but uses the
            # IndexMethod.auto index (index build arguments should only be
            # used with a specific index)
            if index_method == IndexMethod.auto:
                raise ArgError(
                    "Index build parameters are not allowed when using the IndexMethod.auto index."
                )
            # Disallow case where user specifies one index type but submits
            # index build arguments for the other index type
            if (
                isinstance(index_arguments, IndexArgsHNSW)
                and index_method != IndexMethod.hnsw
            ) or (
                isinstance(index_arguments, IndexArgsIVFFlat)
                and index_method != IndexMethod.ivfflat
            ):
                raise ArgError(
                    f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
                )

        if index_method == IndexMethod.auto:
            index_method = IndexMethod.hnsw

        ops = index_measure_to_ops(
            index_measure  # , quantization_type=self.quantization_type
        )
        if ops is None:
            raise ArgError("Unknown index measure")

        concurrently_sql = "CONCURRENTLY" if concurrently else ""

        index_name = (
            index_name
            or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
        )

        create_index_sql = f"""
        CREATE INDEX {concurrently_sql} {index_name}
        ON {table_name_str}
        USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
        """

        try:
            if concurrently:
                async with (
                    self.connection_manager.pool.get_connection() as conn  # type: ignore
                ):
                    # Disable automatic transaction management
                    await conn.execute(
                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
                    )
                    await conn.execute(create_index_sql)
            else:
                # Non-concurrent index creation can use normal query execution
                await self.connection_manager.execute_query(create_index_sql)
        except Exception as e:
            raise Exception(f"Failed to create index: {e}")

        return None
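
    # Usage sketch (illustrative; assumes an initialized handler and that
    # IndexArgsHNSW accepts the keyword fields that _get_index_options reads):
    #
    #     await handler.create_index(
    #         table_name=VectorTableName.CHUNKS,
    #         index_measure=IndexMeasure.cosine_distance,
    #         index_method=IndexMethod.hnsw,
    #         index_arguments=IndexArgsHNSW(m=16, ef_construction=64),
    #     )
    #
    # which issues roughly:
    #     CREATE INDEX CONCURRENTLY ix_<ops>_hnsw__vec_<timestamp>
    #     ON <project>.<chunks table> USING hnsw (vec <ops>) WITH (m=16, ef_construction=64);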

    def _build_filters(
        self, filters: dict, parameters: list[str | int | bytes]
    ) -> str:
        def parse_condition(key: str, value: Any) -> str:  # type: ignore
            # nonlocal parameters
            if key in self.COLUMN_VARS:
                # Handle column-based filters
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(clause)
                        return f"{key} = ${len(parameters)}"
                    elif op == "$ne":
                        parameters.append(clause)
                        return f"{key} != ${len(parameters)}"
                    elif op == "$in":
                        parameters.append(clause)
                        return f"{key} = ANY(${len(parameters)})"
                    elif op == "$nin":
                        parameters.append(clause)
                        return f"{key} != ALL(${len(parameters)})"
                    elif op == "$overlap":
                        parameters.append(clause)
                        return f"{key} && ${len(parameters)}"
                    elif op == "$contains":
                        parameters.append(clause)
                        return f"{key} @> ${len(parameters)}"
                    elif op == "$any":
                        if key == "collection_ids":
                            parameters.append(f"%{clause}%")
                            return f"array_to_string({key}, ',') LIKE ${len(parameters)}"
                        parameters.append(clause)
                        return f"${len(parameters)} = ANY({key})"
                    else:
                        raise FilterError(
                            f"Unsupported operator for column {key}: {op}"
                        )
                else:
                    # Handle direct equality
                    parameters.append(value)
                    return f"{key} = ${len(parameters)}"
            else:
                # Handle JSON-based filters
                json_col = "metadata"
                if key.startswith("metadata."):
                    key = key.split("metadata.")[1]
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op not in (
                        "$eq",
                        "$ne",
                        "$lt",
                        "$lte",
                        "$gt",
                        "$gte",
                        "$in",
                        "$contains",
                    ):
                        raise FilterError("unknown operator")

                    if op == "$eq":
                        parameters.append(json.dumps(clause))
                        return (
                            f"{json_col}->'{key}' = ${len(parameters)}::jsonb"
                        )
                    elif op == "$ne":
                        parameters.append(json.dumps(clause))
                        return (
                            f"{json_col}->'{key}' != ${len(parameters)}::jsonb"
                        )
                    elif op == "$lt":
                        parameters.append(json.dumps(clause))
                        return f"({json_col}->'{key}')::float < (${len(parameters)}::jsonb)::float"
                    elif op == "$lte":
                        parameters.append(json.dumps(clause))
                        return f"({json_col}->'{key}')::float <= (${len(parameters)}::jsonb)::float"
                    elif op == "$gt":
                        parameters.append(json.dumps(clause))
                        return f"({json_col}->'{key}')::float > (${len(parameters)}::jsonb)::float"
                    elif op == "$gte":
                        parameters.append(json.dumps(clause))
                        return f"({json_col}->'{key}')::float >= (${len(parameters)}::jsonb)::float"
                    elif op == "$in":
                        # Ensure clause is a list
                        if not isinstance(clause, list):
                            raise FilterError(
                                "argument to $in filter must be a list"
                            )
                        # Append the Python list as a parameter; many drivers can convert Python lists to arrays
                        parameters.append(clause)
                        # Cast the parameter to a text array type
                        return f"(metadata->>'{key}')::text = ANY(${len(parameters)}::text[])"
                    elif op == "$contains":
                        # Accept a scalar and wrap it in a list so the jsonb
                        # containment check below always compares against an array
                        if isinstance(clause, (int, float, str)):
                            clause = [clause]
                        parameters.append(json.dumps(clause))
                        return (
                            f"{json_col}->'{key}' @> ${len(parameters)}::jsonb"
                        )

        def parse_filter(filter_dict: dict) -> str:
            filter_conditions = []
            for key, value in filter_dict.items():
                if key == "$and":
                    and_conditions = [
                        parse_filter(f) for f in value if f
                    ]  # Skip empty dictionaries
                    if and_conditions:
                        filter_conditions.append(
                            f"({' AND '.join(and_conditions)})"
                        )
                elif key == "$or":
                    or_conditions = [
                        parse_filter(f) for f in value if f
                    ]  # Skip empty dictionaries
                    if or_conditions:
                        filter_conditions.append(
                            f"({' OR '.join(or_conditions)})"
                        )
                else:
                    filter_conditions.append(parse_condition(key, value))

            # Check if there is only a single condition
            if len(filter_conditions) == 1:
                return filter_conditions[0]
            else:
                return " AND ".join(filter_conditions)

        where_clause = parse_filter(filters)
        return where_clause
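
    # Example of the generated SQL (illustrative): for
    #
    #     filters = {"document_id": {"$eq": some_uuid}, "page": {"$gt": 5}}
    #
    # parse_filter produces a clause like
    #
    #     "document_id = $1 AND (metadata->'page')::float > ($2::jsonb)::float"
    #
    # while appending some_uuid and json.dumps(5) to the parameters list.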

    async def list_indices(
        self,
        offset: int,
        limit: int,
        filters: Optional[dict[str, Any]] = None,
    ) -> dict:
        where_clauses = []
        params: list[Any] = [self.project_name]  # Start with schema name
        param_count = 1

        # Handle filtering
        if filters:
            if "table_name" in filters:
                where_clauses.append(f"i.tablename = ${param_count + 1}")
                params.append(filters["table_name"])
                param_count += 1
            if "index_method" in filters:
                where_clauses.append(f"am.amname = ${param_count + 1}")
                params.append(filters["index_method"])
                param_count += 1
            if "index_name" in filters:
                where_clauses.append(
                    f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
                )
                params.append(f"%{filters['index_name']}%")
                param_count += 1

        where_clause = " AND ".join(where_clauses) if where_clauses else ""
        if where_clause:
            where_clause = "AND " + where_clause

        query = f"""
        WITH index_info AS (
            SELECT
                i.indexname as name,
                i.tablename as table_name,
                i.indexdef as definition,
                am.amname as method,
                pg_relation_size(c.oid) as size_in_bytes,
                c.reltuples::bigint as row_estimate,
                COALESCE(psat.idx_scan, 0) as number_of_scans,
                COALESCE(psat.idx_tup_read, 0) as tuples_read,
                COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
                COUNT(*) OVER() as total_count
            FROM pg_indexes i
            JOIN pg_class c ON c.relname = i.indexname
            JOIN pg_am am ON c.relam = am.oid
            LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
                AND psat.schemaname = i.schemaname
            WHERE i.schemaname = $1
            AND i.indexdef LIKE '%vector%'
            {where_clause}
        )
        SELECT *
        FROM index_info
        ORDER BY name
        LIMIT ${param_count + 1}
        OFFSET ${param_count + 2}
        """

        # Add limit and offset to params
        params.extend([limit, offset])

        results = await self.connection_manager.fetch_query(query, params)

        indices = []
        total_entries = 0
        if results:
            total_entries = results[0]["total_count"]

        for result in results:
            index_info = {
                "name": result["name"],
                "table_name": result["table_name"],
                "definition": result["definition"],
                "size_in_bytes": result["size_in_bytes"],
                "row_estimate": result["row_estimate"],
                "number_of_scans": result["number_of_scans"],
                "tuples_read": result["tuples_read"],
                "tuples_fetched": result["tuples_fetched"],
            }
            indices.append(index_info)

        # Calculate pagination info
        total_pages = (total_entries + limit - 1) // limit if limit > 0 else 1
        current_page = (offset // limit) + 1 if limit > 0 else 1

        page_info = {
            "total_entries": total_entries,
            "total_pages": total_pages,
            "current_page": current_page,
            "limit": limit,
            "offset": offset,
            "has_previous": offset > 0,
            "has_next": offset + limit < total_entries,
            "previous_offset": max(0, offset - limit) if offset > 0 else None,
            "next_offset": (
                offset + limit if offset + limit < total_entries else None
            ),
        }

        return {"indices": indices, "page_info": page_info}

    async def delete_index(
        self,
        index_name: str,
        table_name: Optional[VectorTableName] = None,
        concurrently: bool = True,
    ) -> None:
        """
        Deletes a vector index.

        Args:
            index_name (str): Name of the index to delete
            table_name (VectorTableName, optional): Table the index belongs to
            concurrently (bool): Whether to drop the index concurrently

        Raises:
            ArgError: If table name is invalid or index doesn't exist
            Exception: If index deletion fails
        """
        # Validate table name and get column name
        if table_name == VectorTableName.CHUNKS:
            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
            col_name = "vec"
        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.GRAPHS_ENTITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.COMMUNITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
            )
            col_name = "description_embedding"
        else:
            raise ArgError("invalid table name")

        # Extract schema and base table name
        schema_name, base_table_name = table_name_str.split(".")

        # Verify index exists and is a vector index
        query = """
        SELECT indexdef
        FROM pg_indexes
        WHERE indexname = $1
        AND schemaname = $2
        AND tablename = $3
        AND indexdef LIKE $4
        """
        result = await self.connection_manager.fetchrow_query(
            query, (index_name, schema_name, base_table_name, f"%({col_name}%")
        )

        if not result:
            raise ArgError(
                f"Vector index '{index_name}' does not exist on table {table_name_str}"
            )

        # Drop the index
        concurrently_sql = "CONCURRENTLY" if concurrently else ""
        drop_query = (
            f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
        )

        try:
            if concurrently:
                async with (
                    self.connection_manager.pool.get_connection() as conn  # type: ignore
                ):
                    # Disable automatic transaction management
                    await conn.execute(
                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
                    )
                    await conn.execute(drop_query)
            else:
                await self.connection_manager.execute_query(drop_query)
        except Exception as e:
            raise Exception(f"Failed to delete index: {e}")

    async def get_semantic_neighbors(
        self,
        offset: int,
        limit: int,
        document_id: UUID,
        id: UUID,
        similarity_threshold: float = 0.5,
    ) -> list[dict[str, Any]]:
        table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        query = f"""
        WITH target_vector AS (
            SELECT vec FROM {table_name}
            WHERE document_id = $1 AND id = $2
        )
        SELECT t.id, t.text, t.metadata, t.document_id, (t.vec <=> tv.vec) AS similarity
        FROM {table_name} t, target_vector tv
        WHERE (t.vec <=> tv.vec) >= $3
            AND t.document_id = $1
            AND t.id != $2
        ORDER BY similarity ASC
        LIMIT $4
        """
        results = await self.connection_manager.fetch_query(
            query,
            (str(document_id), str(id), similarity_threshold, limit),
        )

        return [
            {
                "id": str(r["id"]),
                "text": r["text"],
                "metadata": json.loads(r["metadata"]),
                "document_id": str(r["document_id"]),
                "similarity": float(r["similarity"]),
            }
            for r in results
        ]

    async def list_chunks(
        self,
        offset: int,
        limit: int,
        filters: Optional[dict[str, Any]] = None,
        include_vectors: bool = False,
    ) -> dict[str, Any]:
        """
        List chunks with pagination support.

        Args:
            offset (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return. Defaults to 10.
            filters (dict, optional): Dictionary of filters to apply. Defaults to None.
            include_vectors (bool, optional): Whether to include vector data. Defaults to False.

        Returns:
            dict: Dictionary containing:
                - results: List of chunk records
                - page_info: Pagination information, including total_entries
        """
        # Validate sort parameters
        valid_sort_columns = {
            "created_at": "metadata->>'created_at'",
            "updated_at": "metadata->>'updated_at'",
            "chunk_order": "metadata->>'chunk_order'",
            "text": "text",
        }

        # Build the select clause
        vector_select = ", vec" if include_vectors else ""
        select_clause = f"""
            id, document_id, owner_id, collection_ids,
            text, metadata{vector_select}, COUNT(*) OVER() AS total
        """

        # Build the where clause if filters are provided
        where_clause = ""
        params: list[str | int | bytes] = []
        if filters:
            where_clause = self._build_filters(filters, params)
            where_clause = f"WHERE {where_clause}"

        # Construct the final query
        query = f"""
        SELECT {select_clause}
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        {where_clause}
        LIMIT $%s
        OFFSET $%s
        """

        # Add pagination parameters; the LIMIT/OFFSET placeholders are always
        # the last two parameters, even when filters contributed their own
        params.extend([limit, offset])
        formatted_query = query % (len(params) - 1, len(params))

        # Execute the query
        results = await self.connection_manager.fetch_query(
            formatted_query, params
        )

        # Process results
        chunks = []
        total = 0
        if results:
            total = results[0].get("total", 0)
            chunks = [
                {
                    "id": str(result["id"]),
                    "document_id": str(result["document_id"]),
                    "owner_id": str(result["owner_id"]),
                    "collection_ids": result["collection_ids"],
                    "text": result["text"],
                    "metadata": json.loads(result["metadata"]),
                    "vector": (
                        json.loads(result["vec"]) if include_vectors else None
                    ),
                }
                for result in results
            ]

        # Calculate pagination info
        total_pages = (total + limit - 1) // limit if limit > 0 else 1
        current_page = (offset // limit) + 1 if limit > 0 else 1

        page_info = {
            "total_entries": total,
            "total_pages": total_pages,
            "current_page": current_page,
            "limit": limit,
            "offset": offset,
            "has_previous": offset > 0,
            "has_next": offset + limit < total,
            "previous_offset": max(0, offset - limit) if offset > 0 else None,
            "next_offset": offset + limit if offset + limit < total else None,
        }

        return {"results": chunks, "page_info": page_info}
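
    # Usage sketch (illustrative; filter shape follows _build_filters above):
    #
    #     page = await handler.list_chunks(
    #         offset=0,
    #         limit=100,
    #         filters={"document_id": {"$eq": str(document_id)}},
    #         include_vectors=False,
    #     )
    #     chunks, page_info = page["results"], page["page_info"]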

    async def search_documents(
        self,
        query_text: str,
        settings: SearchSettings,
    ) -> list[dict[str, Any]]:
        """
        Search for documents based on their metadata fields and/or body text.
        Joins with documents table to get complete document metadata.

        Args:
            query_text (str): The search query text
            settings (SearchSettings): Search settings including search preferences and filters

        Returns:
            list[dict[str, Any]]: List of documents with their search scores and complete metadata
        """
        where_clauses = []
        params: list[str | int | bytes] = [query_text]

        # Build the dynamic metadata field search expression
        metadata_fields_expr = " || ' ' || ".join(
            [
                f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
                for key in settings.metadata_keys  # type: ignore
            ]
        )

        query = f"""
        WITH
        -- Metadata search scores
        metadata_scores AS (
            SELECT DISTINCT ON (v.document_id)
                v.document_id,
                d.metadata as doc_metadata,
                CASE WHEN $1 = '' THEN 0.0
                ELSE
                    ts_rank_cd(
                        setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
                        websearch_to_tsquery('english', $1),
                        32
                    )
                END as metadata_rank
            FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
            LEFT JOIN {self._get_table_name('documents')} d ON v.document_id = d.id
            WHERE v.metadata IS NOT NULL
        ),
        -- Body search scores
        body_scores AS (
            SELECT
                document_id,
                AVG(
                    ts_rank_cd(
                        setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
                        websearch_to_tsquery('english', $1),
                        32
                    )
                ) as body_rank
            FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            WHERE $1 != ''
            {f"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""}
            GROUP BY document_id
        ),
        -- Combined scores with document metadata
        combined_scores AS (
            SELECT
                COALESCE(m.document_id, b.document_id) as document_id,
                m.doc_metadata as metadata,
                COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
                COALESCE(b.body_rank, 0) as debug_body_rank,
                CASE
                    WHEN {str(settings.search_over_metadata).lower()} AND {str(settings.search_over_body).lower()} THEN
                        COALESCE(m.metadata_rank, 0) * {settings.metadata_weight} + COALESCE(b.body_rank, 0) * {settings.title_weight}
                    WHEN {str(settings.search_over_metadata).lower()} THEN
                        COALESCE(m.metadata_rank, 0)
                    WHEN {str(settings.search_over_body).lower()} THEN
                        COALESCE(b.body_rank, 0)
                    ELSE 0
                END as rank
            FROM metadata_scores m
            FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
            WHERE (
                ($1 = '') OR
                ({str(settings.search_over_metadata).lower()} AND m.metadata_rank > 0) OR
                ({str(settings.search_over_body).lower()} AND b.body_rank > 0)
            )
        """

        # Add any additional filters
        if settings.filters:
            filter_clause = self._build_filters(settings.filters, params)
            where_clauses.append(filter_clause)

        if where_clauses:
            query += f" AND {' AND '.join(where_clauses)}"

        query += """
        )
        SELECT
            document_id,
            metadata,
            rank as score,
            debug_metadata_rank,
            debug_body_rank
        FROM combined_scores
        WHERE rank > 0
        ORDER BY rank DESC
        OFFSET ${offset_param} LIMIT ${limit_param}
        """.format(
            offset_param=len(params) + 1,
            limit_param=len(params) + 2,
        )

        # Add offset and limit to params
        params.extend([settings.offset, settings.limit])

        # Execute query
        results = await self.connection_manager.fetch_query(query, params)

        # Format results with complete document metadata
        return [
            {
                "document_id": str(r["document_id"]),
                "metadata": (
                    json.loads(r["metadata"])
                    if isinstance(r["metadata"], str)
                    else r["metadata"]
                ),
                "score": float(r["score"]),
                "debug_metadata_rank": float(r["debug_metadata_rank"]),
                "debug_body_rank": float(r["debug_body_rank"]),
            }
            for r in results
        ]

    def _get_index_options(
        self,
        method: IndexMethod,
        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
    ) -> str:
        if method == IndexMethod.ivfflat:
            if isinstance(index_arguments, IndexArgsIVFFlat):
                return f"WITH (lists={index_arguments.n_lists})"
            else:
                # Default value if no arguments provided
                return "WITH (lists=100)"
        elif method == IndexMethod.hnsw:
            if isinstance(index_arguments, IndexArgsHNSW):
                return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
            else:
                # Default values if no arguments provided
                return "WITH (m=16, ef_construction=64)"
        else:
            return ""  # No options for other methods
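
    # For reference (derived from the branches above): an ivfflat index without
    # arguments gets "WITH (lists=100)", an hnsw index without arguments gets
    # "WITH (m=16, ef_construction=64)", and any other method gets no options.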