chunks.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312
  1. import copy
  2. import json
  3. import logging
  4. import math
  5. import time
  6. import uuid
  7. from typing import Any, Optional, TypedDict
  8. from uuid import UUID
  9. import numpy as np
  10. from core.base import (
  11. ChunkSearchResult,
  12. Handler,
  13. IndexArgsHNSW,
  14. IndexArgsIVFFlat,
  15. IndexMeasure,
  16. IndexMethod,
  17. R2RException,
  18. SearchSettings,
  19. VectorEntry,
  20. VectorQuantizationType,
  21. VectorTableName,
  22. )
  23. from core.base.utils import _decorate_vector_type
  24. from .base import PostgresConnectionManager
  25. from .filters import apply_filters
  26. from .utils import psql_quote_literal
  27. logger = logging.getLogger()
  28. def index_measure_to_ops(
  29. measure: IndexMeasure,
  30. quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
  31. ):
  32. return _decorate_vector_type(measure.ops, quantization_type)
  33. def quantize_vector_to_binary(
  34. vector: list[float] | np.ndarray,
  35. threshold: float = 0.0,
  36. ) -> bytes:
  37. """Quantizes a float vector to a binary vector string for PostgreSQL bit
  38. type. Used when quantization_type is INT1.
  39. Args:
  40. vector (List[float] | np.ndarray): Input vector of floats
  41. threshold (float, optional): Threshold for binarization. Defaults to 0.0.
  42. Returns:
  43. str: Binary string representation for PostgreSQL bit type
  44. """
  45. # Convert input to numpy array if it isn't already
  46. if not isinstance(vector, np.ndarray):
  47. vector = np.array(vector)
  48. # Convert to binary (1 where value > threshold, 0 otherwise)
  49. binary_vector = (vector > threshold).astype(int)
  50. # Convert to string of 1s and 0s
  51. # Convert to string of 1s and 0s, then to bytes
  52. binary_string = "".join(map(str, binary_vector))
  53. return binary_string.encode("ascii")
  54. class HybridSearchIntermediateResult(TypedDict):
  55. semantic_rank: int
  56. full_text_rank: int
  57. data: ChunkSearchResult
  58. rrf_score: float
  59. class PostgresChunksHandler(Handler):
  60. TABLE_NAME = VectorTableName.CHUNKS
  61. def __init__(
  62. self,
  63. project_name: str,
  64. connection_manager: PostgresConnectionManager,
  65. dimension: int | float,
  66. quantization_type: VectorQuantizationType,
  67. ):
  68. super().__init__(project_name, connection_manager)
  69. self.dimension = dimension
  70. self.quantization_type = quantization_type
  71. async def create_tables(self):
  72. # First check if table already exists and validate dimensions
  73. table_exists_query = """
  74. SELECT EXISTS (
  75. SELECT FROM pg_tables
  76. WHERE schemaname = $1
  77. AND tablename = $2
  78. );
  79. """
  80. table_name = VectorTableName.CHUNKS
  81. table_exists = await self.connection_manager.fetch_query(
  82. table_exists_query, (self.project_name, table_name)
  83. )
  84. if len(table_exists) > 0 and table_exists[0]["exists"]:
  85. # Table exists, check vector dimension
  86. vector_dim_query = """
  87. SELECT a.atttypmod as dimension
  88. FROM pg_attribute a
  89. JOIN pg_class c ON a.attrelid = c.oid
  90. JOIN pg_namespace n ON c.relnamespace = n.oid
  91. WHERE n.nspname = $1
  92. AND c.relname = $2
  93. AND a.attname = 'vec';
  94. """
  95. vector_dim_result = await self.connection_manager.fetch_query(
  96. vector_dim_query, (self.project_name, table_name)
  97. )
  98. if vector_dim_result and len(vector_dim_result) > 0:
  99. existing_dimension = vector_dim_result[0]["dimension"]
  100. # In pgvector, dimension is stored as atttypmod - 4
  101. if existing_dimension > 0: # If it has a specific dimension
  102. # Compare with provided dimension
  103. if (
  104. self.dimension > 0
  105. and existing_dimension != self.dimension
  106. ):
  107. raise ValueError(
  108. f"Dimension mismatch: Table '{self.project_name}.{table_name}' was created with "
  109. f"dimension {existing_dimension}, but {self.dimension} was provided. "
  110. f"You must use the same dimension for existing tables."
  111. )
  112. # Check for old table name
  113. check_query = """
  114. SELECT EXISTS (
  115. SELECT FROM pg_tables
  116. WHERE schemaname = $1
  117. AND tablename = $2
  118. );
  119. """
  120. old_table_exists = await self.connection_manager.fetch_query(
  121. check_query, (self.project_name, self.project_name)
  122. )
  123. if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
  124. raise ValueError(
  125. f"Found old vector table '{self.project_name}.{self.project_name}'. "
  126. "Please run `r2r db upgrade` with the CLI, or to run manually, "
  127. "run in R2R/py/migrations with 'alembic upgrade head' to update "
  128. "your database schema to the new version."
  129. )
  130. binary_col = (
  131. ""
  132. if self.quantization_type != VectorQuantizationType.INT1
  133. else f"vec_binary bit({self.dimension}),"
  134. )
  135. if self.dimension > 0:
  136. vector_col = f"vec vector({self.dimension})"
  137. else:
  138. vector_col = "vec vector"
  139. query = f"""
  140. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
  141. id UUID PRIMARY KEY,
  142. document_id UUID,
  143. owner_id UUID,
  144. collection_ids UUID[],
  145. {vector_col},
  146. {binary_col}
  147. text TEXT,
  148. metadata JSONB,
  149. fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
  150. );
  151. CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
  152. CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
  153. CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
  154. CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
  155. """
  156. await self.connection_manager.execute_query(query)
  157. async def upsert(self, entry: VectorEntry) -> None:
  158. """Upsert function that handles vector quantization only when
  159. quantization_type is INT1.
  160. Matches the table schema where vec_binary column only exists for INT1
  161. quantization.
  162. """
  163. # Check the quantization type to determine which columns to use
  164. if self.quantization_type == VectorQuantizationType.INT1:
  165. bit_dim = (
  166. "" if math.isnan(self.dimension) else f"({self.dimension})"
  167. )
  168. # For quantized vectors, use vec_binary column
  169. query = f"""
  170. INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  171. (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
  172. VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
  173. ON CONFLICT (id) DO UPDATE SET
  174. document_id = EXCLUDED.document_id,
  175. owner_id = EXCLUDED.owner_id,
  176. collection_ids = EXCLUDED.collection_ids,
  177. vec = EXCLUDED.vec,
  178. vec_binary = EXCLUDED.vec_binary,
  179. text = EXCLUDED.text,
  180. metadata = EXCLUDED.metadata;
  181. """
  182. await self.connection_manager.execute_query(
  183. query,
  184. (
  185. entry.id,
  186. entry.document_id,
  187. entry.owner_id,
  188. entry.collection_ids,
  189. str(entry.vector.data),
  190. quantize_vector_to_binary(
  191. entry.vector.data
  192. ), # Convert to binary
  193. entry.text,
  194. json.dumps(entry.metadata),
  195. ),
  196. )
  197. else:
  198. # For regular vectors, use vec column only
  199. query = f"""
  200. INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  201. (id, document_id, owner_id, collection_ids, vec, text, metadata)
  202. VALUES ($1, $2, $3, $4, $5, $6, $7)
  203. ON CONFLICT (id) DO UPDATE SET
  204. document_id = EXCLUDED.document_id,
  205. owner_id = EXCLUDED.owner_id,
  206. collection_ids = EXCLUDED.collection_ids,
  207. vec = EXCLUDED.vec,
  208. text = EXCLUDED.text,
  209. metadata = EXCLUDED.metadata;
  210. """
  211. await self.connection_manager.execute_query(
  212. query,
  213. (
  214. entry.id,
  215. entry.document_id,
  216. entry.owner_id,
  217. entry.collection_ids,
  218. str(entry.vector.data),
  219. entry.text,
  220. json.dumps(entry.metadata),
  221. ),
  222. )
  223. async def upsert_entries(self, entries: list[VectorEntry]) -> None:
  224. """Batch upsert function that handles vector quantization only when
  225. quantization_type is INT1.
  226. Matches the table schema where vec_binary column only exists for INT1
  227. quantization.
  228. """
  229. if self.quantization_type == VectorQuantizationType.INT1:
  230. bit_dim = (
  231. "" if math.isnan(self.dimension) else f"({self.dimension})"
  232. )
  233. # For quantized vectors, use vec_binary column
  234. query = f"""
  235. INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  236. (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
  237. VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
  238. ON CONFLICT (id) DO UPDATE SET
  239. document_id = EXCLUDED.document_id,
  240. owner_id = EXCLUDED.owner_id,
  241. collection_ids = EXCLUDED.collection_ids,
  242. vec = EXCLUDED.vec,
  243. vec_binary = EXCLUDED.vec_binary,
  244. text = EXCLUDED.text,
  245. metadata = EXCLUDED.metadata;
  246. """
  247. bin_params = [
  248. (
  249. entry.id,
  250. entry.document_id,
  251. entry.owner_id,
  252. entry.collection_ids,
  253. str(entry.vector.data),
  254. quantize_vector_to_binary(
  255. entry.vector.data
  256. ), # Convert to binary
  257. entry.text,
  258. json.dumps(entry.metadata),
  259. )
  260. for entry in entries
  261. ]
  262. await self.connection_manager.execute_many(query, bin_params)
  263. else:
  264. # For regular vectors, use vec column only
  265. query = f"""
  266. INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  267. (id, document_id, owner_id, collection_ids, vec, text, metadata)
  268. VALUES ($1, $2, $3, $4, $5, $6, $7)
  269. ON CONFLICT (id) DO UPDATE SET
  270. document_id = EXCLUDED.document_id,
  271. owner_id = EXCLUDED.owner_id,
  272. collection_ids = EXCLUDED.collection_ids,
  273. vec = EXCLUDED.vec,
  274. text = EXCLUDED.text,
  275. metadata = EXCLUDED.metadata;
  276. """
  277. params = [
  278. (
  279. entry.id,
  280. entry.document_id,
  281. entry.owner_id,
  282. entry.collection_ids,
  283. str(entry.vector.data),
  284. entry.text,
  285. json.dumps(entry.metadata),
  286. )
  287. for entry in entries
  288. ]
  289. await self.connection_manager.execute_many(query, params)
  290. async def semantic_search(
  291. self, query_vector: list[float], search_settings: SearchSettings
  292. ) -> list[ChunkSearchResult]:
  293. try:
  294. imeasure_obj = IndexMeasure(
  295. search_settings.chunk_settings.index_measure
  296. )
  297. except ValueError:
  298. raise ValueError("Invalid index measure") from None
  299. table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
  300. cols = [
  301. f"{table_name}.id",
  302. f"{table_name}.document_id",
  303. f"{table_name}.owner_id",
  304. f"{table_name}.collection_ids",
  305. f"{table_name}.text",
  306. ]
  307. params: list[str | int | bytes] = []
  308. # For binary vectors (INT1), implement two-stage search
  309. if self.quantization_type == VectorQuantizationType.INT1:
  310. # Convert query vector to binary format
  311. binary_query = quantize_vector_to_binary(query_vector)
  312. # TODO - Put depth multiplier in config / settings
  313. extended_limit = (
  314. search_settings.limit * 20
  315. ) # Get 20x candidates for re-ranking
  316. if (
  317. imeasure_obj == IndexMeasure.hamming_distance
  318. or imeasure_obj == IndexMeasure.jaccard_distance
  319. ):
  320. binary_search_measure_repr = imeasure_obj.pgvector_repr
  321. else:
  322. binary_search_measure_repr = (
  323. IndexMeasure.hamming_distance.pgvector_repr
  324. )
  325. # Use binary column and binary-specific distance measures for first stage
  326. bit_dim = (
  327. "" if math.isnan(self.dimension) else f"({self.dimension})"
  328. )
  329. stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit{bit_dim}"
  330. stage1_param = binary_query
  331. cols.append(
  332. f"{table_name}.vec"
  333. ) # Need original vector for re-ranking
  334. if search_settings.include_metadatas:
  335. cols.append(f"{table_name}.metadata")
  336. select_clause = ", ".join(cols)
  337. where_clause = ""
  338. params.append(stage1_param)
  339. if search_settings.filters:
  340. where_clause, params = apply_filters(
  341. search_settings.filters, params, mode="where_clause"
  342. )
  343. vector_dim = (
  344. "" if math.isnan(self.dimension) else f"({self.dimension})"
  345. )
  346. # First stage: Get candidates using binary search
  347. query = f"""
  348. WITH candidates AS (
  349. SELECT {select_clause},
  350. ({stage1_distance}) as binary_distance
  351. FROM {table_name}
  352. {where_clause}
  353. ORDER BY {stage1_distance}
  354. LIMIT ${len(params) + 1}
  355. OFFSET ${len(params) + 2}
  356. )
  357. -- Second stage: Re-rank using original vectors
  358. SELECT
  359. id,
  360. document_id,
  361. owner_id,
  362. collection_ids,
  363. text,
  364. {"metadata," if search_settings.include_metadatas else ""}
  365. (vec <=> ${len(params) + 4}::vector{vector_dim}) as distance
  366. FROM candidates
  367. ORDER BY distance
  368. LIMIT ${len(params) + 3}
  369. """
  370. params.extend(
  371. [
  372. extended_limit, # First stage limit
  373. search_settings.offset,
  374. search_settings.limit, # Final limit
  375. str(query_vector), # For re-ranking
  376. ]
  377. )
  378. else:
  379. # Standard float vector handling
  380. vector_dim = (
  381. "" if math.isnan(self.dimension) else f"({self.dimension})"
  382. )
  383. distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector{vector_dim}"
  384. query_param = str(query_vector)
  385. if search_settings.include_scores:
  386. cols.append(f"({distance_calc}) AS distance")
  387. if search_settings.include_metadatas:
  388. cols.append(f"{table_name}.metadata")
  389. select_clause = ", ".join(cols)
  390. where_clause = ""
  391. params.append(query_param)
  392. if search_settings.filters:
  393. where_clause, new_params = apply_filters(
  394. search_settings.filters,
  395. params,
  396. mode="where_clause", # Get just conditions without WHERE
  397. )
  398. params = new_params
  399. query = f"""
  400. SELECT {select_clause}
  401. FROM {table_name}
  402. {where_clause}
  403. ORDER BY {distance_calc}
  404. LIMIT ${len(params) + 1}
  405. OFFSET ${len(params) + 2}
  406. """
  407. params.extend([search_settings.limit, search_settings.offset])
  408. results = await self.connection_manager.fetch_query(query, params)
  409. return [
  410. ChunkSearchResult(
  411. id=UUID(str(result["id"])),
  412. document_id=UUID(str(result["document_id"])),
  413. owner_id=UUID(str(result["owner_id"])),
  414. collection_ids=result["collection_ids"],
  415. text=result["text"],
  416. score=(
  417. (1 - float(result["distance"]))
  418. if "distance" in result
  419. else -1
  420. ),
  421. metadata=(
  422. json.loads(result["metadata"])
  423. if search_settings.include_metadatas
  424. else {}
  425. ),
  426. )
  427. for result in results
  428. ]
  429. async def full_text_search(
  430. self, query_text: str, search_settings: SearchSettings
  431. ) -> list[ChunkSearchResult]:
  432. conditions = []
  433. params: list[str | int | bytes] = [query_text]
  434. conditions.append("fts @@ websearch_to_tsquery('english', $1)")
  435. if search_settings.filters:
  436. filter_condition, params = apply_filters(
  437. search_settings.filters, params, mode="condition_only"
  438. )
  439. if filter_condition:
  440. conditions.append(filter_condition)
  441. where_clause = "WHERE " + " AND ".join(conditions)
  442. query = f"""
  443. SELECT
  444. id,
  445. document_id,
  446. owner_id,
  447. collection_ids,
  448. text,
  449. metadata,
  450. ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
  451. FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  452. {where_clause}
  453. ORDER BY rank DESC
  454. OFFSET ${len(params) + 1}
  455. LIMIT ${len(params) + 2}
  456. """
  457. params.extend(
  458. [
  459. search_settings.offset,
  460. search_settings.hybrid_settings.full_text_limit,
  461. ]
  462. )
  463. results = await self.connection_manager.fetch_query(query, params)
  464. return [
  465. ChunkSearchResult(
  466. id=UUID(str(r["id"])),
  467. document_id=UUID(str(r["document_id"])),
  468. owner_id=UUID(str(r["owner_id"])),
  469. collection_ids=r["collection_ids"],
  470. text=r["text"],
  471. score=float(r["rank"]),
  472. metadata=json.loads(r["metadata"]),
  473. )
  474. for r in results
  475. ]
  476. async def hybrid_search(
  477. self,
  478. query_text: str,
  479. query_vector: list[float],
  480. search_settings: SearchSettings,
  481. *args,
  482. **kwargs,
  483. ) -> list[ChunkSearchResult]:
  484. if search_settings.hybrid_settings is None:
  485. raise ValueError(
  486. "Please provide a valid `hybrid_settings` in the `search_settings`."
  487. )
  488. if (
  489. search_settings.hybrid_settings.full_text_limit
  490. < search_settings.limit
  491. ):
  492. raise ValueError(
  493. "The `full_text_limit` must be greater than or equal to the `limit`."
  494. )
  495. semantic_settings = copy.deepcopy(search_settings)
  496. semantic_settings.limit += search_settings.offset
  497. full_text_settings = copy.deepcopy(search_settings)
  498. full_text_settings.hybrid_settings.full_text_limit += (
  499. search_settings.offset
  500. )
  501. semantic_results: list[ChunkSearchResult] = await self.semantic_search(
  502. query_vector, semantic_settings
  503. )
  504. full_text_results: list[
  505. ChunkSearchResult
  506. ] = await self.full_text_search(query_text, full_text_settings)
  507. semantic_limit = search_settings.limit
  508. full_text_limit = search_settings.hybrid_settings.full_text_limit
  509. semantic_weight = search_settings.hybrid_settings.semantic_weight
  510. full_text_weight = search_settings.hybrid_settings.full_text_weight
  511. rrf_k = search_settings.hybrid_settings.rrf_k
  512. combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}
  513. for rank, result in enumerate(semantic_results, 1):
  514. combined_results[result.id] = {
  515. "semantic_rank": rank,
  516. "full_text_rank": full_text_limit,
  517. "data": result,
  518. "rrf_score": 0.0, # Initialize with 0, will be calculated later
  519. }
  520. for rank, result in enumerate(full_text_results, 1):
  521. if result.id in combined_results:
  522. combined_results[result.id]["full_text_rank"] = rank
  523. else:
  524. combined_results[result.id] = {
  525. "semantic_rank": semantic_limit,
  526. "full_text_rank": rank,
  527. "data": result,
  528. "rrf_score": 0.0, # Initialize with 0, will be calculated later
  529. }
  530. combined_results = {
  531. k: v
  532. for k, v in combined_results.items()
  533. if v["semantic_rank"] <= semantic_limit * 2
  534. and v["full_text_rank"] <= full_text_limit * 2
  535. }
  536. for hyb_result in combined_results.values():
  537. semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
  538. full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
  539. hyb_result["rrf_score"] = (
  540. semantic_score * semantic_weight
  541. + full_text_score * full_text_weight
  542. ) / (semantic_weight + full_text_weight)
  543. sorted_results = sorted(
  544. combined_results.values(),
  545. key=lambda x: x["rrf_score"],
  546. reverse=True,
  547. )
  548. offset_results = sorted_results[
  549. search_settings.offset : search_settings.offset
  550. + search_settings.limit
  551. ]
  552. return [
  553. ChunkSearchResult(
  554. id=result["data"].id,
  555. document_id=result["data"].document_id,
  556. owner_id=result["data"].owner_id,
  557. collection_ids=result["data"].collection_ids,
  558. text=result["data"].text,
  559. score=result["rrf_score"],
  560. metadata={
  561. **result["data"].metadata,
  562. "semantic_rank": result["semantic_rank"],
  563. "full_text_rank": result["full_text_rank"],
  564. },
  565. )
  566. for result in offset_results
  567. ]
  568. async def delete(
  569. self, filters: dict[str, Any]
  570. ) -> dict[str, dict[str, str]]:
  571. params: list[str | int | bytes] = []
  572. where_clause, params = apply_filters(
  573. filters, params, mode="condition_only"
  574. )
  575. query = f"""
  576. DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  577. WHERE {where_clause}
  578. RETURNING id, document_id, text;
  579. """
  580. results = await self.connection_manager.fetch_query(query, params)
  581. return {
  582. str(result["id"]): {
  583. "status": "deleted",
  584. "id": str(result["id"]),
  585. "document_id": str(result["document_id"]),
  586. "text": result["text"],
  587. }
  588. for result in results
  589. }
  590. async def assign_document_chunks_to_collection(
  591. self, document_id: UUID, collection_id: UUID
  592. ) -> None:
  593. query = f"""
  594. UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  595. SET collection_ids = array_append(collection_ids, $1)
  596. WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
  597. """
  598. logger.info(query)
  599. logger.info(f"collection_id: {collection_id}")
  600. logger.info(f"document_id: {document_id}")
  601. return await self.connection_manager.execute_query(
  602. query, (str(collection_id), str(document_id))
  603. )
  604. async def remove_document_from_collection_vector(
  605. self, document_id: UUID, collection_id: UUID
  606. ) -> None:
  607. query = f"""
  608. UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  609. SET collection_ids = array_remove(collection_ids, $1)
  610. WHERE document_id = $2;
  611. """
  612. await self.connection_manager.execute_query(
  613. query, (collection_id, document_id)
  614. )
  615. async def delete_user_vector(self, owner_id: UUID) -> None:
  616. query = f"""
  617. DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  618. WHERE owner_id = $1;
  619. """
  620. await self.connection_manager.execute_query(query, (owner_id,))
  621. async def delete_collection_vector(self, collection_id: UUID) -> None:
  622. query = f"""
  623. DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  624. WHERE $1 = ANY(collection_ids)
  625. RETURNING collection_ids
  626. """
  627. await self.connection_manager.fetchrow_query(query, (collection_id,))
  628. return None
  629. async def list_document_chunks(
  630. self,
  631. document_id: UUID,
  632. offset: int,
  633. limit: int,
  634. include_vectors: bool = False,
  635. ) -> dict[str, Any]:
  636. vector_select = ", vec" if include_vectors else ""
  637. limit_clause = f"LIMIT {limit}" if limit > -1 else ""
  638. query = f"""
  639. SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
  640. FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  641. WHERE document_id = $1
  642. ORDER BY (metadata->>'chunk_order')::integer
  643. OFFSET $2
  644. {limit_clause};
  645. """
  646. params = [document_id, offset]
  647. results = await self.connection_manager.fetch_query(query, params)
  648. chunks = []
  649. total = 0
  650. if results:
  651. total = results[0].get("total", 0)
  652. chunks = [
  653. {
  654. "id": result["id"],
  655. "document_id": result["document_id"],
  656. "owner_id": result["owner_id"],
  657. "collection_ids": result["collection_ids"],
  658. "text": result["text"],
  659. "metadata": json.loads(result["metadata"]),
  660. "vector": (
  661. json.loads(result["vec"]) if include_vectors else None
  662. ),
  663. }
  664. for result in results
  665. ]
  666. return {"results": chunks, "total_entries": total}
  667. async def get_chunk(self, id: UUID) -> dict:
  668. query = f"""
  669. SELECT id, document_id, owner_id, collection_ids, text, metadata
  670. FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
  671. WHERE id = $1;
  672. """
  673. result = await self.connection_manager.fetchrow_query(query, (id,))
  674. if result:
  675. return {
  676. "id": result["id"],
  677. "document_id": result["document_id"],
  678. "owner_id": result["owner_id"],
  679. "collection_ids": result["collection_ids"],
  680. "text": result["text"],
  681. "metadata": json.loads(result["metadata"]),
  682. }
  683. raise R2RException(
  684. message=f"Chunk with ID {id} not found", status_code=404
  685. )
  686. async def create_index(
  687. self,
  688. table_name: Optional[VectorTableName] = None,
  689. index_measure: IndexMeasure = IndexMeasure.cosine_distance,
  690. index_method: IndexMethod = IndexMethod.auto,
  691. index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
  692. index_name: Optional[str] = None,
  693. index_column: Optional[str] = None,
  694. concurrently: bool = True,
  695. ) -> None:
  696. """Creates an index for the collection.
  697. Note:
  698. When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
  699. process that enables performant indexes to be built for large collections with low end
  700. database hardware.
  701. Those steps are:
  702. - Creates a new table with a different name
  703. - Randomly selects records from the existing table
  704. - Inserts the random records from the existing table into the new table
  705. - Creates the requested vector index on the new table
  706. - Upserts all data from the existing table into the new table
  707. - Drops the existing table
  708. - Renames the new table to the existing tables name
  709. If you create dependencies (like views) on the table that underpins
  710. a `vecs.Collection` the `create_index` step may require you to drop those dependencies before
  711. it will succeed.
  712. Args:
  713. index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
  714. index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
  715. index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments
  716. index_name (str, optional): The name of the index to create. Defaults to None.
  717. concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
  718. Raises:
  719. ValueError: If an invalid index method is used, or if *replace* is False and an index already exists.
  720. """
  721. if table_name == VectorTableName.CHUNKS:
  722. table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}" # TODO - Fix bug in vector table naming convention
  723. if index_column:
  724. col_name = index_column
  725. else:
  726. col_name = (
  727. "vec"
  728. if (
  729. index_measure != IndexMeasure.hamming_distance
  730. and index_measure != IndexMeasure.jaccard_distance
  731. )
  732. else "vec_binary"
  733. )
  734. elif table_name == VectorTableName.ENTITIES_DOCUMENT:
  735. table_name_str = (
  736. f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
  737. )
  738. col_name = "description_embedding"
  739. elif table_name == VectorTableName.GRAPHS_ENTITIES:
  740. table_name_str = (
  741. f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
  742. )
  743. col_name = "description_embedding"
  744. elif table_name == VectorTableName.COMMUNITIES:
  745. table_name_str = (
  746. f"{self.project_name}.{VectorTableName.COMMUNITIES}"
  747. )
  748. col_name = "embedding"
  749. else:
  750. raise ValueError("invalid table name")
  751. if index_method not in (
  752. IndexMethod.ivfflat,
  753. IndexMethod.hnsw,
  754. IndexMethod.auto,
  755. ):
  756. raise ValueError("invalid index method")
  757. if index_arguments:
  758. # Disallow case where user submits index arguments but uses the
  759. # IndexMethod.auto index (index build arguments should only be
  760. # used with a specific index)
  761. if index_method == IndexMethod.auto:
  762. raise ValueError(
  763. "Index build parameters are not allowed when using the IndexMethod.auto index."
  764. )
  765. # Disallow case where user specifies one index type but submits
  766. # index build arguments for the other index type
  767. if (
  768. isinstance(index_arguments, IndexArgsHNSW)
  769. and index_method != IndexMethod.hnsw
  770. ) or (
  771. isinstance(index_arguments, IndexArgsIVFFlat)
  772. and index_method != IndexMethod.ivfflat
  773. ):
  774. raise ValueError(
  775. f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
  776. )
  777. if index_method == IndexMethod.auto:
  778. index_method = IndexMethod.hnsw
  779. ops = index_measure_to_ops(
  780. index_measure # , quantization_type=self.quantization_type
  781. )
  782. if ops is None:
  783. raise ValueError("Unknown index measure")
  784. concurrently_sql = "CONCURRENTLY" if concurrently else ""
  785. index_name = (
  786. index_name
  787. or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
  788. )
  789. create_index_sql = f"""
  790. CREATE INDEX {concurrently_sql} {index_name}
  791. ON {table_name_str}
  792. USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
  793. """
  794. try:
  795. if concurrently:
  796. async with (
  797. self.connection_manager.pool.get_connection() as conn # type: ignore
  798. ):
  799. # Disable automatic transaction management
  800. await conn.execute(
  801. "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
  802. )
  803. await conn.execute(create_index_sql)
  804. else:
  805. # Non-concurrent index creation can use normal query execution
  806. await self.connection_manager.execute_query(create_index_sql)
  807. except Exception as e:
  808. raise Exception(f"Failed to create index: {e}") from e
  809. return None
  async def list_indices(
      self,
      offset: int,
      limit: int,
      filters: Optional[dict[str, Any]] = None,
  ) -> dict:
      """List the vector indexes defined in this project's schema.

      Only indexes whose definition contains the text ``vector`` are
      reported. Results are ordered by index name and paginated.

      Args:
          offset (int): Number of matching indexes to skip.
          limit (int): Maximum number of indexes to return.
          filters (dict, optional): Optional filters. Recognised keys are
              ``table_name`` (exact table-name match), ``index_method``
              (exact ``pg_am.amname`` match) and ``index_name``
              (case-insensitive substring match). Other keys are ignored.

      Returns:
          dict: ``{"indices": [...], "total_entries": int}``. Each entry has
          ``name``, ``table_name``, ``definition``, ``size_in_bytes``,
          ``row_estimate``, ``number_of_scans``, ``tuples_read`` and
          ``tuples_fetched``. ``total_entries`` is the number of matching
          indexes before LIMIT/OFFSET, taken from the first returned row, so
          it is 0 whenever the requested page is empty.
      """
      # $1 is always the schema name; each filter value is bound to the next
      # placeholder, followed by LIMIT and OFFSET.
      params: list[Any] = [self.project_name]
      where_clauses: list[str] = []

      def bind(value: Any) -> str:
          """Append *value* to ``params`` and return its ``$n`` placeholder."""
          params.append(value)
          return f"${len(params)}"

      if filters:
          if "table_name" in filters:
              where_clauses.append(f"i.tablename = {bind(filters['table_name'])}")
          if "index_method" in filters:
              where_clauses.append(f"am.amname = {bind(filters['index_method'])}")
          if "index_name" in filters:
              pattern = bind(f"%{filters['index_name']}%")
              where_clauses.append(f"LOWER(i.indexname) LIKE LOWER({pattern})")
      where_sql = (
          f"AND {' AND '.join(where_clauses)}" if where_clauses else ""
      )
      limit_param = bind(limit)
      offset_param = bind(offset)

      # The pg_class row is resolved inside the index's own schema: matching
      # on relname alone also picks up same-named indexes in other schemas,
      # duplicating rows and skewing total_count / size statistics.
      query = f"""
          WITH index_info AS (
              SELECT
                  i.indexname AS name,
                  i.tablename AS table_name,
                  i.indexdef AS definition,
                  am.amname AS method,
                  pg_relation_size(c.oid) AS size_in_bytes,
                  c.reltuples::bigint AS row_estimate,
                  COALESCE(psat.idx_scan, 0) AS number_of_scans,
                  COALESCE(psat.idx_tup_read, 0) AS tuples_read,
                  COALESCE(psat.idx_tup_fetch, 0) AS tuples_fetched,
                  COUNT(*) OVER() AS total_count
              FROM pg_indexes i
              JOIN pg_namespace n ON n.nspname = i.schemaname
              JOIN pg_class c ON c.relname = i.indexname
                  AND c.relnamespace = n.oid
              JOIN pg_am am ON c.relam = am.oid
              LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelid = c.oid
              WHERE i.schemaname = $1
                AND i.indexdef LIKE '%vector%'
                {where_sql}
          )
          SELECT *
          FROM index_info
          ORDER BY name
          LIMIT {limit_param}
          OFFSET {offset_param}
      """
      results = await self.connection_manager.fetch_query(query, params)
      if not results:
          return {"indices": [], "total_entries": 0}

      fields = (
          "name",
          "table_name",
          "definition",
          "size_in_bytes",
          "row_estimate",
          "number_of_scans",
          "tuples_read",
          "tuples_fetched",
      )
      indices = [{field: row[field] for field in fields} for row in results]
      return {"indices": indices, "total_entries": results[0]["total_count"]}
  async def delete_index(
      self,
      index_name: str,
      table_name: Optional[VectorTableName] = None,
      concurrently: bool = True,
  ) -> None:
      """Delete a vector index from one of the project's vector tables.

      Before dropping, the index is looked up in ``pg_indexes`` by name,
      project schema and table. Its definition must also mention the table's
      embedding column, so only an existing vector index on that table can be
      removed.

      Args:
          index_name (str): Name of the index to delete.
          table_name (VectorTableName): Table the index belongs to. Although
              it defaults to None, a supported table is required; None raises
              ``ValueError``.
          concurrently (bool): Whether to issue ``DROP INDEX CONCURRENTLY``.

      Raises:
          ValueError: If the table name is not supported, or no matching
              vector index exists on that table.
          Exception: If the ``DROP INDEX`` statement fails.
      """
      # Embedding column per table. It must agree with the column create_index
      # builds indexes on, otherwise the existence check below never matches.
      if table_name == VectorTableName.CHUNKS:
          # "%(vec%" also matches "(vec_binary". NOTE(review): indexes that
          # create_index built on a caller-supplied index_column will not match.
          table, col_name = VectorTableName.CHUNKS, "vec"
      elif table_name == VectorTableName.ENTITIES_DOCUMENT:
          table, col_name = (
              VectorTableName.ENTITIES_DOCUMENT,
              "description_embedding",
          )
      elif table_name == VectorTableName.GRAPHS_ENTITIES:
          table, col_name = (
              VectorTableName.GRAPHS_ENTITIES,
              "description_embedding",
          )
      elif table_name == VectorTableName.COMMUNITIES:
          # create_index indexes the communities table on "embedding";
          # checking for "description_embedding" could never find those indexes.
          table, col_name = VectorTableName.COMMUNITIES, "embedding"
      else:
          raise ValueError("invalid table name")

      schema_name = self.project_name
      base_table_name = f"{table}"
      table_name_str = f"{schema_name}.{base_table_name}"

      # Verify the index exists and is defined over the embedding column.
      existing = await self.connection_manager.fetchrow_query(
          """
          SELECT indexdef
          FROM pg_indexes
          WHERE indexname = $1
            AND schemaname = $2
            AND tablename = $3
            AND indexdef LIKE $4
          """,
          (index_name, schema_name, base_table_name, f"%({col_name}%"),
      )
      if not existing:
          raise ValueError(
              f"Vector index '{index_name}' does not exist on table {table_name_str}"
          )

      def quote_ident(name: str) -> str:
          # Double-quote an identifier (escaping embedded quotes) so exactly
          # the names verified above are dropped: no case folding and no
          # way for the name to inject extra SQL.
          return '"' + str(name).replace('"', '""') + '"'

      concurrently_sql = "CONCURRENTLY" if concurrently else ""
      drop_query = (
          f"DROP INDEX {concurrently_sql} "
          f"{quote_ident(schema_name)}.{quote_ident(index_name)}"
      )
      try:
          if concurrently:
              # PostgreSQL rejects DROP INDEX CONCURRENTLY inside a transaction
              # block, so it runs on a dedicated pool connection.
              # NOTE(review): presumably execute_query may wrap statements in a
              # transaction - confirm.
              async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                  await conn.execute(
                      "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
                  )
                  await conn.execute(drop_query)
          else:
              await self.connection_manager.execute_query(drop_query)
      except Exception as e:
          raise Exception(f"Failed to delete index: {e}") from e
  async def list_chunks(
      self,
      offset: int,
      limit: int,
      filters: Optional[dict[str, Any]] = None,
      include_vectors: bool = False,
  ) -> dict[str, Any]:
      """List chunks with pagination support.

      Args:
          offset (int): Number of matching chunks to skip.
          limit (int): Maximum number of chunks to return.
          filters (dict, optional): Filters turned into a WHERE clause by
              ``apply_filters``. Defaults to None (no filtering).
          include_vectors (bool, optional): Whether to select the ``vec``
              column and return it as ``vector``. Defaults to False.

      Returns:
          dict: Dictionary containing:
              - results: List of chunk records (``id``, ``document_id``,
                ``owner_id``, ``collection_ids``, ``text``, ``metadata``,
                ``vector``).
              - total_entries: Number of chunks matching the filters, taken
                from the first returned row (0 when the page is empty).
      """

      def decode(value: Any) -> Any:
          # JSON-ish columns may arrive as text or already decoded, depending
          # on the driver's codecs; only parse text.
          return json.loads(value) if isinstance(value, str) else value

      vector_select = ", vec" if include_vectors else ""
      select_clause = f"""
          id, document_id, owner_id, collection_ids,
          text, metadata{vector_select}, COUNT(*) OVER() AS total_entries
      """
      params: list[str | int | bytes] = []
      where_clause = ""
      if filters:
          where_clause, params = apply_filters(
              filters, params, mode="where_clause"
          )

      # LIMIT/OFFSET without ORDER BY returns an unpredictable subset in
      # PostgreSQL, so pages could overlap or skip chunks. Ordering by id
      # makes pagination stable (assumes id is unique - TODO confirm PK).
      query = f"""
          SELECT {select_clause}
          FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
          {where_clause}
          ORDER BY id
          LIMIT ${len(params) + 1}
          OFFSET ${len(params) + 2}
      """
      params.extend([limit, offset])
      # Debug level: the full query and bound filter values are noisy at INFO.
      logger.debug(query)
      logger.debug(params)

      results = await self.connection_manager.fetch_query(query, params)
      if not results:
          return {"results": [], "total_entries": 0}

      chunks = [
          {
              "id": str(result["id"]),
              "document_id": str(result["document_id"]),
              "owner_id": str(result["owner_id"]),
              "collection_ids": result["collection_ids"],
              "text": result["text"],
              "metadata": decode(result["metadata"]),
              "vector": decode(result["vec"]) if include_vectors else None,
          }
          for result in results
      ]
      return {
          "results": chunks,
          "total_entries": results[0].get("total_entries", 0),
      }
  async def search_documents(
      self,
      query_text: str,
      settings: SearchSettings,
  ) -> list[dict[str, Any]]:
      """Search documents by full-text rank over chunk metadata and/or text.

      Chunk rows are ranked against ``query_text`` (parsed with
      ``websearch_to_tsquery('english', ...)``) in two ways. A metadata score
      is computed over selected metadata keys, and a body score is the AVG of
      chunk-text ranks. The scores are combined per document, which is joined
      with the documents table to return the document's own metadata.

      Optional ``settings`` attributes, read with ``getattr`` and defaulted:
          search_over_body (bool, default True)
          search_over_metadata (bool, default True)
          metadata_weight (float, default 3.0): multiplier on the metadata
              rank when both searches are enabled.
          title_weight (float, default 1.0): multiplier on the body rank when
              both searches are enabled.
          metadata_keys (list[str], default ["title", "description"]): chunk
              metadata keys whose text is searched; None uses the default.
      ``settings.filters``, ``settings.offset`` and ``settings.limit`` are
      always read.

      Args:
          query_text (str): The search query text.
          settings (SearchSettings): Search settings, including search
              preferences and filters.

      Returns:
          list[dict[str, Any]]: One dict per document with ``document_id``,
          ``metadata``, ``score``, ``debug_metadata_rank`` and
          ``debug_body_rank``, ordered by descending score. Only documents
          with a positive score are returned, so an empty ``query_text``
          yields an empty list.
      """
      params: list[str | int | bytes] = [query_text]
      search_over_body = getattr(settings, "search_over_body", True)
      search_over_metadata = getattr(settings, "search_over_metadata", True)
      metadata_weight = getattr(settings, "metadata_weight", 3.0)
      # NOTE(review): this weight scales the *body* rank below; the name looks
      # mismatched - confirm intent.
      title_weight = getattr(settings, "title_weight", 1.0)
      metadata_keys = getattr(settings, "metadata_keys", None)
      if metadata_keys is None:
          metadata_keys = ["title", "description"]

      # SQL boolean literals for the two feature toggles.
      use_meta = str(search_over_metadata).lower()
      use_body = str(search_over_body).lower()

      # Keys are embedded as quoted SQL literals. With no keys the joined
      # expression would be empty, and "to_tsvector('english', )" is a syntax
      # error, so fall back to an empty-string literal.
      metadata_fields_expr = (
          " || ' ' || ".join(
              f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
              for key in metadata_keys  # type: ignore
          )
          or "''"
      )
      body_match_sql = (
          "AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)"
          if search_over_body
          else ""
      )
      chunks_table = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
      documents_table = self._get_table_name("documents")

      query = f"""
          WITH
          -- Metadata search scores, one row per document. DISTINCT ON keeps the
          -- first row of each group in ORDER BY order; without an ORDER BY that
          -- row is arbitrary, so keep the best-ranked chunk explicitly.
          metadata_scores AS (
              SELECT DISTINCT ON (v.document_id)
                  v.document_id,
                  d.metadata as doc_metadata,
                  CASE WHEN $1 = '' THEN 0.0
                  ELSE
                      ts_rank_cd(
                          setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
                          websearch_to_tsquery('english', $1),
                          32
                      )
                  END as metadata_rank
              FROM {chunks_table} v
              LEFT JOIN {documents_table} d ON v.document_id = d.id
              WHERE v.metadata IS NOT NULL
              ORDER BY v.document_id, metadata_rank DESC
          ),
          -- Body search scores: AVG of chunk-text ranks per document
          body_scores AS (
              SELECT
                  document_id,
                  AVG(
                      ts_rank_cd(
                          setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
                          websearch_to_tsquery('english', $1),
                          32
                      )
                  ) as body_rank
              FROM {chunks_table}
              WHERE $1 != ''
              {body_match_sql}
              GROUP BY document_id
          ),
          -- Combined scores with document metadata
          combined_scores AS (
              SELECT
                  COALESCE(m.document_id, b.document_id) as document_id,
                  m.doc_metadata as metadata,
                  COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
                  COALESCE(b.body_rank, 0) as debug_body_rank,
                  CASE
                      WHEN {use_meta} AND {use_body} THEN
                          COALESCE(m.metadata_rank, 0) * {metadata_weight} + COALESCE(b.body_rank, 0) * {title_weight}
                      WHEN {use_meta} THEN
                          COALESCE(m.metadata_rank, 0)
                      WHEN {use_body} THEN
                          COALESCE(b.body_rank, 0)
                      ELSE 0
                  END as rank
              FROM metadata_scores m
              FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
              WHERE (
                  ($1 = '') OR
                  ({use_meta} AND m.metadata_rank > 0) OR
                  ({use_body} AND b.body_rank > 0)
              )
      """
      # Additional filters extend the combined_scores WHERE clause; skip an
      # empty clause so no dangling "AND" is emitted.
      if settings.filters:
          filter_clause, params = apply_filters(settings.filters, params)
          if filter_clause:
              query += f" AND {filter_clause}"

      # Placeholders are numbered after any parameters the filters added.
      query += f"""
          )
          SELECT
              document_id,
              metadata,
              rank as score,
              debug_metadata_rank,
              debug_body_rank
          FROM combined_scores
          WHERE rank > 0
          ORDER BY rank DESC
          OFFSET ${len(params) + 1} LIMIT ${len(params) + 2}
      """
      params.extend([settings.offset, settings.limit])

      results = await self.connection_manager.fetch_query(query, params)
      return [
          {
              "document_id": str(r["document_id"]),
              "metadata": (
                  json.loads(r["metadata"])
                  if isinstance(r["metadata"], str)
                  else r["metadata"]
              ),
              "score": float(r["score"]),
              "debug_metadata_rank": float(r["debug_metadata_rank"]),
              "debug_body_rank": float(r["debug_body_rank"]),
          }
          for r in results
      ]
  def _get_index_options(
      self,
      method: IndexMethod,
      index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
  ) -> str:
      """Build the ``WITH (...)`` storage-parameter clause for a new index.

      Args:
          method: Index access method being created.
          index_arguments: Method-specific build arguments. If they are None
              or of the wrong type for ``method``, fixed fallback values are
              used instead.

      Returns:
          str: ``WITH (lists=...)`` for IVFFlat,
          ``WITH (m=..., ef_construction=...)`` for HNSW, and an empty string
          for any other method.
      """
      if method == IndexMethod.ivfflat:
          if not isinstance(index_arguments, IndexArgsIVFFlat):
              # Fallback list count when no IVFFlat arguments were supplied.
              return "WITH (lists=100)"
          return f"WITH (lists={index_arguments.n_lists})"
      if method == IndexMethod.hnsw:
          if not isinstance(index_arguments, IndexArgsHNSW):
              # Fallback build parameters when no HNSW arguments were supplied.
              return "WITH (m=16, ef_construction=64)"
          return (
              f"WITH (m={index_arguments.m}, "
              f"ef_construction={index_arguments.ef_construction})"
          )
      # Other methods get no storage parameters.
      return ""