import copy
import json
import logging
import time
import uuid
from typing import Any, Optional, TypedDict
from uuid import UUID

import numpy as np

from core.base import (
    ChunkSearchResult,
    Handler,
    IndexArgsHNSW,
    IndexArgsIVFFlat,
    IndexMeasure,
    IndexMethod,
    R2RException,
    SearchSettings,
    VectorEntry,
    VectorQuantizationType,
    VectorTableName,
)
from core.base.utils import _decorate_vector_type

from .base import PostgresConnectionManager
from .filters import apply_filters
from .vecs.exc import ArgError, FilterError

logger = logging.getLogger()


def psql_quote_literal(value: str) -> str:
    """
    Safely quote a string literal for PostgreSQL to prevent SQL injection.

    This is a simple implementation - in production, you should use proper
    parameterization or your database driver's quoting functions.
    """
    return "'" + value.replace("'", "''") + "'"
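
# Example (illustrative only): psql_quote_literal("it's") returns "'it''s'",
# doubling the embedded quote so the literal can be spliced into generated SQL.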

def index_measure_to_ops(
    measure: IndexMeasure,
    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
):
    return _decorate_vector_type(measure.ops, quantization_type)


def quantize_vector_to_binary(
    vector: list[float] | np.ndarray,
    threshold: float = 0.0,
) -> bytes:
    """
    Quantizes a float vector to a binary vector for the PostgreSQL bit type.
    Used when quantization_type is INT1.

    Args:
        vector (list[float] | np.ndarray): Input vector of floats
        threshold (float, optional): Threshold for binarization. Defaults to 0.0.

    Returns:
        bytes: ASCII-encoded string of 1s and 0s for the PostgreSQL bit type
    """
    # Convert the input to a numpy array if it isn't one already
    if not isinstance(vector, np.ndarray):
        vector = np.array(vector)

    # Binarize: 1 where value > threshold, 0 otherwise
    binary_vector = (vector > threshold).astype(int)

    # Convert to a string of 1s and 0s, then encode as ASCII bytes
    binary_string = "".join(map(str, binary_vector))
    return binary_string.encode("ascii")
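
# Example (illustrative only): quantize_vector_to_binary([0.3, -0.1, 0.0, 2.5])
# returns b"1001" -- values strictly above the default threshold of 0.0 become 1.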

class HybridSearchIntermediateResult(TypedDict):
    semantic_rank: int
    full_text_rank: int
    data: ChunkSearchResult
    rrf_score: float


class PostgresChunksHandler(Handler):
    TABLE_NAME = VectorTableName.CHUNKS

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        dimension: int,
        quantization_type: VectorQuantizationType,
    ):
        super().__init__(project_name, connection_manager)
        self.dimension = dimension
        self.quantization_type = quantization_type

    async def create_tables(self):
        # Check for the old table name first
        check_query = """
        SELECT EXISTS (
            SELECT FROM pg_tables
            WHERE schemaname = $1
            AND tablename = $2
        );
        """
        old_table_exists = await self.connection_manager.fetch_query(
            check_query, (self.project_name, self.project_name)
        )

        if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
            raise ValueError(
                f"Found old vector table '{self.project_name}.{self.project_name}'. "
                "Please run `r2r db upgrade` with the CLI, or to run manually, "
                "run in R2R/py/migrations with 'alembic upgrade head' to update "
                "your database schema to the new version."
            )

        binary_col = (
            ""
            if self.quantization_type != VectorQuantizationType.INT1
            else f"vec_binary bit({self.dimension}),"
        )

        query = f"""
        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
            id UUID PRIMARY KEY,
            document_id UUID,
            owner_id UUID,
            collection_ids UUID[],
            vec vector({self.dimension}),
            {binary_col}
            text TEXT,
            metadata JSONB,
            fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
        );
        CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
        CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
        CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
        CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
        """
        await self.connection_manager.execute_query(query)

    async def upsert(self, entry: VectorEntry) -> None:
        """
        Upsert function that handles vector quantization only when quantization_type is INT1.
        Matches the table schema where the vec_binary column only exists for INT1 quantization.
        """
        # Check the quantization type to determine which columns to use
        if self.quantization_type == VectorQuantizationType.INT1:
            # For quantized vectors, use the vec_binary column as well
            query = f"""
            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            vec_binary = EXCLUDED.vec_binary,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
            await self.connection_manager.execute_query(
                query,
                (
                    entry.id,
                    entry.document_id,
                    entry.owner_id,
                    entry.collection_ids,
                    str(entry.vector.data),
                    quantize_vector_to_binary(
                        entry.vector.data
                    ),  # Convert to binary
                    entry.text,
                    json.dumps(entry.metadata),
                ),
            )
        else:
            # For regular vectors, use the vec column only
            query = f"""
            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            (id, document_id, owner_id, collection_ids, vec, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
            await self.connection_manager.execute_query(
                query,
                (
                    entry.id,
                    entry.document_id,
                    entry.owner_id,
                    entry.collection_ids,
                    str(entry.vector.data),
                    entry.text,
                    json.dumps(entry.metadata),
                ),
            )
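
    # A minimal usage sketch (illustrative only; the field names follow how
    # VectorEntry attributes are accessed above, but the exact constructor and
    # the surrounding setup are assumptions):
    #
    #     entry = VectorEntry(
    #         id=uuid.uuid4(),
    #         document_id=doc_id,
    #         owner_id=user_id,
    #         collection_ids=[collection_id],
    #         vector=vector,  # embedding wrapper exposing `.data`
    #         text="chunk text",
    #         metadata={"chunk_order": 0},
    #     )
    #     await handler.upsert(entry)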

    async def upsert_entries(self, entries: list[VectorEntry]) -> None:
        """
        Batch upsert function that handles vector quantization only when quantization_type is INT1.
        Matches the table schema where the vec_binary column only exists for INT1 quantization.
        """
        if self.quantization_type == VectorQuantizationType.INT1:
            # For quantized vectors, use the vec_binary column as well
            query = f"""
            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            vec_binary = EXCLUDED.vec_binary,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
            bin_params = [
                (
                    entry.id,
                    entry.document_id,
                    entry.owner_id,
                    entry.collection_ids,
                    str(entry.vector.data),
                    quantize_vector_to_binary(
                        entry.vector.data
                    ),  # Convert to binary
                    entry.text,
                    json.dumps(entry.metadata),
                )
                for entry in entries
            ]
            await self.connection_manager.execute_many(query, bin_params)
        else:
            # For regular vectors, use the vec column only
            query = f"""
            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            (id, document_id, owner_id, collection_ids, vec, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
            params = [
                (
                    entry.id,
                    entry.document_id,
                    entry.owner_id,
                    entry.collection_ids,
                    str(entry.vector.data),
                    entry.text,
                    json.dumps(entry.metadata),
                )
                for entry in entries
            ]
            await self.connection_manager.execute_many(query, params)

    async def semantic_search(
        self, query_vector: list[float], search_settings: SearchSettings
    ) -> list[ChunkSearchResult]:
        try:
            imeasure_obj = IndexMeasure(
                search_settings.chunk_settings.index_measure
            )
        except ValueError:
            raise ValueError("Invalid index measure")

        table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        cols = [
            f"{table_name}.id",
            f"{table_name}.document_id",
            f"{table_name}.owner_id",
            f"{table_name}.collection_ids",
            f"{table_name}.text",
        ]
        params: list[str | int | bytes] = []

        # For binary vectors (INT1), implement a two-stage search
        if self.quantization_type == VectorQuantizationType.INT1:
            # Convert the query vector to binary format
            binary_query = quantize_vector_to_binary(query_vector)
            # TODO - Put depth multiplier in config / settings
            extended_limit = (
                search_settings.limit * 20
            )  # Get 20x candidates for re-ranking

            if (
                imeasure_obj == IndexMeasure.hamming_distance
                or imeasure_obj == IndexMeasure.jaccard_distance
            ):
                binary_search_measure_repr = imeasure_obj.pgvector_repr
            else:
                binary_search_measure_repr = (
                    IndexMeasure.hamming_distance.pgvector_repr
                )

            # Use the binary column and binary-specific distance measures for the first stage
            stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit({self.dimension})"
            stage1_param = binary_query

            cols.append(
                f"{table_name}.vec"
            )  # Need the original vector for re-ranking
            if search_settings.include_metadatas:
                cols.append(f"{table_name}.metadata")

            select_clause = ", ".join(cols)
            where_clause = ""
            params.append(stage1_param)
            if search_settings.filters:
                where_clause, params = apply_filters(
                    search_settings.filters, params, mode="where_clause"
                )

            # First stage: get candidates using the binary search
            query = f"""
            WITH candidates AS (
                SELECT {select_clause},
                    ({stage1_distance}) as binary_distance
                FROM {table_name}
                {where_clause}
                ORDER BY {stage1_distance}
                LIMIT ${len(params) + 1}
                OFFSET ${len(params) + 2}
            )
            -- Second stage: Re-rank using original vectors
            SELECT
                id,
                document_id,
                owner_id,
                collection_ids,
                text,
                {"metadata," if search_settings.include_metadatas else ""}
                (vec <=> ${len(params) + 4}::vector({self.dimension})) as distance
            FROM candidates
            ORDER BY distance
            LIMIT ${len(params) + 3}
            """

            params.extend(
                [
                    extended_limit,  # First-stage limit
                    search_settings.offset,
                    search_settings.limit,  # Final limit
                    str(query_vector),  # For re-ranking
                ]
            )
        else:
            # Standard float vector handling
            distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector({self.dimension})"
            query_param = str(query_vector)

            if search_settings.include_scores:
                cols.append(f"({distance_calc}) AS distance")
            if search_settings.include_metadatas:
                cols.append(f"{table_name}.metadata")

            select_clause = ", ".join(cols)
            where_clause = ""
            params.append(query_param)
            if search_settings.filters:
                where_clause, new_params = apply_filters(
                    search_settings.filters,
                    params,
                    mode="where_clause",  # Get just the conditions, without WHERE
                )
                params = new_params

            query = f"""
            SELECT {select_clause}
            FROM {table_name}
            {where_clause}
            ORDER BY {distance_calc}
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
            """
            params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            ChunkSearchResult(
                id=UUID(str(result["id"])),
                document_id=UUID(str(result["document_id"])),
                owner_id=UUID(str(result["owner_id"])),
                collection_ids=result["collection_ids"],
                text=result["text"],
                score=(
                    (1 - float(result["distance"]))
                    if "distance" in result
                    else -1
                ),
                metadata=(
                    json.loads(result["metadata"])
                    if search_settings.include_metadatas
                    else {}
                ),
            )
            for result in results
        ]
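
    # A minimal call sketch (illustrative only; the SearchSettings construction
    # shown here is an assumption about its fields):
    #
    #     settings = SearchSettings(limit=10, offset=0, include_scores=True)
    #     results = await handler.semantic_search(
    #         query_vector=embedding, search_settings=settings
    #     )
    #
    # Note that the returned `score` is computed as 1 - distance, so with cosine
    # distance a score near 1.0 indicates a near-identical direction.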

    async def full_text_search(
        self, query_text: str, search_settings: SearchSettings
    ) -> list[ChunkSearchResult]:
        conditions = []
        params: list[str | int | bytes] = [query_text]

        conditions.append("fts @@ websearch_to_tsquery('english', $1)")

        if search_settings.filters:
            filter_condition, params = apply_filters(
                search_settings.filters, params, mode="condition_only"
            )
            if filter_condition:
                conditions.append(filter_condition)

        where_clause = "WHERE " + " AND ".join(conditions)

        query = f"""
        SELECT
            id,
            document_id,
            owner_id,
            collection_ids,
            text,
            metadata,
            ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        {where_clause}
        ORDER BY rank DESC
        OFFSET ${len(params) + 1}
        LIMIT ${len(params) + 2}
        """
        params.extend(
            [
                search_settings.offset,
                search_settings.hybrid_settings.full_text_limit,
            ]
        )

        results = await self.connection_manager.fetch_query(query, params)
        return [
            ChunkSearchResult(
                id=UUID(str(r["id"])),
                document_id=UUID(str(r["document_id"])),
                owner_id=UUID(str(r["owner_id"])),
                collection_ids=r["collection_ids"],
                text=r["text"],
                score=float(r["rank"]),
                metadata=json.loads(r["metadata"]),
            )
            for r in results
        ]

    async def hybrid_search(
        self,
        query_text: str,
        query_vector: list[float],
        search_settings: SearchSettings,
        *args,
        **kwargs,
    ) -> list[ChunkSearchResult]:
        if search_settings.hybrid_settings is None:
            raise ValueError(
                "Please provide a valid `hybrid_settings` in the `search_settings`."
            )
        if (
            search_settings.hybrid_settings.full_text_limit
            < search_settings.limit
        ):
            raise ValueError(
                "The `full_text_limit` must be greater than or equal to the `limit`."
            )

        semantic_settings = copy.deepcopy(search_settings)
        semantic_settings.limit += search_settings.offset

        full_text_settings = copy.deepcopy(search_settings)
        full_text_settings.hybrid_settings.full_text_limit += (
            search_settings.offset
        )

        semantic_results: list[ChunkSearchResult] = await self.semantic_search(
            query_vector, semantic_settings
        )
        full_text_results: list[ChunkSearchResult] = (
            await self.full_text_search(query_text, full_text_settings)
        )

        semantic_limit = search_settings.limit
        full_text_limit = search_settings.hybrid_settings.full_text_limit
        semantic_weight = search_settings.hybrid_settings.semantic_weight
        full_text_weight = search_settings.hybrid_settings.full_text_weight
        rrf_k = search_settings.hybrid_settings.rrf_k

        combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}

        for rank, result in enumerate(semantic_results, 1):
            combined_results[result.id] = {
                "semantic_rank": rank,
                "full_text_rank": full_text_limit,
                "data": result,
                "rrf_score": 0.0,  # Initialize with 0, will be calculated later
            }

        for rank, result in enumerate(full_text_results, 1):
            if result.id in combined_results:
                combined_results[result.id]["full_text_rank"] = rank
            else:
                combined_results[result.id] = {
                    "semantic_rank": semantic_limit,
                    "full_text_rank": rank,
                    "data": result,
                    "rrf_score": 0.0,  # Initialize with 0, will be calculated later
                }

        combined_results = {
            k: v
            for k, v in combined_results.items()
            if v["semantic_rank"] <= semantic_limit * 2
            and v["full_text_rank"] <= full_text_limit * 2
        }

        for hyb_result in combined_results.values():
            semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
            full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
            hyb_result["rrf_score"] = (
                semantic_score * semantic_weight
                + full_text_score * full_text_weight
            ) / (semantic_weight + full_text_weight)

        sorted_results = sorted(
            combined_results.values(),
            key=lambda x: x["rrf_score"],
            reverse=True,
        )
        offset_results = sorted_results[
            search_settings.offset : search_settings.offset
            + search_settings.limit
        ]

        return [
            ChunkSearchResult(
                id=result["data"].id,
                document_id=result["data"].document_id,
                owner_id=result["data"].owner_id,
                collection_ids=result["data"].collection_ids,
                text=result["data"].text,
                score=result["rrf_score"],
                metadata={
                    **result["data"].metadata,
                    "semantic_rank": result["semantic_rank"],
                    "full_text_rank": result["full_text_rank"],
                },
            )
            for result in offset_results
        ]
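
    # The fusion above is a weighted reciprocal rank fusion (RRF). For a chunk with
    # semantic rank r_s and full-text rank r_f, the score computed in the loop is
    #
    #     rrf = (w_s * 1/(k + r_s) + w_f * 1/(k + r_f)) / (w_s + w_f)
    #
    # For example, with k=60 and w_s = w_f = 1, a chunk ranked 1st semantically and
    # 3rd in full text scores (1/61 + 1/63) / 2 ≈ 0.0161, which then drives the
    # final descending sort.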

    async def delete(
        self, filters: dict[str, Any]
    ) -> dict[str, dict[str, str]]:
        params: list[str | int | bytes] = []
        where_clause, params = apply_filters(
            filters, params, mode="condition_only"
        )

        query = f"""
        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE {where_clause}
        RETURNING id, document_id, text;
        """
        results = await self.connection_manager.fetch_query(query, params)

        return {
            str(result["id"]): {
                "status": "deleted",
                "id": str(result["id"]),
                "document_id": str(result["document_id"]),
                "text": result["text"],
            }
            for result in results
        }

    async def assign_document_chunks_to_collection(
        self, document_id: UUID, collection_id: UUID
    ) -> None:
        query = f"""
        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        SET collection_ids = array_append(collection_ids, $1)
        WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
        """
        result = await self.connection_manager.execute_query(
            query, (str(collection_id), str(document_id))
        )
        return result

    async def remove_document_from_collection_vector(
        self, document_id: UUID, collection_id: UUID
    ) -> None:
        query = f"""
        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        SET collection_ids = array_remove(collection_ids, $1)
        WHERE document_id = $2;
        """
        await self.connection_manager.execute_query(
            query, (collection_id, document_id)
        )

    async def delete_user_vector(self, owner_id: UUID) -> None:
        query = f"""
        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE owner_id = $1;
        """
        await self.connection_manager.execute_query(query, (owner_id,))

    async def delete_collection_vector(self, collection_id: UUID) -> None:
        query = f"""
        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE $1 = ANY(collection_ids)
        RETURNING collection_ids
        """
        await self.connection_manager.fetchrow_query(query, (collection_id,))
        return None

    async def list_document_chunks(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        include_vectors: bool = False,
    ) -> dict[str, Any]:
        vector_select = ", vec" if include_vectors else ""
        limit_clause = f"LIMIT {limit}" if limit > -1 else ""

        query = f"""
        SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE document_id = $1
        ORDER BY (metadata->>'chunk_order')::integer
        OFFSET $2
        {limit_clause};
        """
        params = [document_id, offset]

        results = await self.connection_manager.fetch_query(query, params)

        chunks = []
        total = 0
        if results:
            total = results[0].get("total", 0)
            chunks = [
                {
                    "id": result["id"],
                    "document_id": result["document_id"],
                    "owner_id": result["owner_id"],
                    "collection_ids": result["collection_ids"],
                    "text": result["text"],
                    "metadata": json.loads(result["metadata"]),
                    "vector": (
                        json.loads(result["vec"]) if include_vectors else None
                    ),
                }
                for result in results
            ]

        return {"results": chunks, "total_entries": total}

    async def get_chunk(self, id: UUID) -> dict:
        query = f"""
        SELECT id, document_id, owner_id, collection_ids, text, metadata
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE id = $1;
        """

        result = await self.connection_manager.fetchrow_query(query, (id,))

        if result:
            return {
                "id": result["id"],
                "document_id": result["document_id"],
                "owner_id": result["owner_id"],
                "collection_ids": result["collection_ids"],
                "text": result["text"],
                "metadata": json.loads(result["metadata"]),
            }
        raise R2RException(
            message=f"Chunk with ID {id} not found", status_code=404
        )

    async def create_index(
        self,
        table_name: Optional[VectorTableName] = None,
        index_measure: IndexMeasure = IndexMeasure.cosine_distance,
        index_method: IndexMethod = IndexMethod.auto,
        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
        index_name: Optional[str] = None,
        index_column: Optional[str] = None,
        concurrently: bool = True,
    ) -> None:
        """
        Creates an index for the collection.

        Note:
            When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a
            multi-step process that enables performant indexes to be built for large
            collections with low-end database hardware.

            Those steps are:
            - Creates a new table with a different name
            - Randomly selects records from the existing table
            - Inserts the random records from the existing table into the new table
            - Creates the requested vector index on the new table
            - Upserts all data from the existing table into the new table
            - Drops the existing table
            - Renames the new table to the existing table's name

            If you create dependencies (like views) on the table that underpins
            a `vecs.Collection`, the `create_index` step may require you to drop those
            dependencies before it will succeed.

        Args:
            table_name (VectorTableName, optional): The table to create the index on.
            index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
            index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
            index_arguments (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments.
            index_name (str, optional): The name of the index to create. Defaults to None.
            index_column (str, optional): The column to index. Defaults to None.
            concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.

        Raises:
            ArgError: If an invalid index method is used, or if *replace* is False and an index already exists.
        """
        if table_name == VectorTableName.CHUNKS:
            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"  # TODO - Fix bug in vector table naming convention
            if index_column:
                col_name = index_column
            else:
                col_name = (
                    "vec"
                    if (
                        index_measure != IndexMeasure.hamming_distance
                        and index_measure != IndexMeasure.jaccard_distance
                    )
                    else "vec_binary"
                )
        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.GRAPHS_ENTITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.COMMUNITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
            )
            col_name = "embedding"
        else:
            raise ArgError("invalid table name")

        if index_method not in (
            IndexMethod.ivfflat,
            IndexMethod.hnsw,
            IndexMethod.auto,
        ):
            raise ArgError("invalid index method")

        if index_arguments:
            # Disallow the case where the user submits index arguments but uses the
            # IndexMethod.auto index (index build arguments should only be
            # used with a specific index)
            if index_method == IndexMethod.auto:
                raise ArgError(
                    "Index build parameters are not allowed when using the IndexMethod.auto index."
                )
            # Disallow the case where the user specifies one index type but submits
            # index build arguments for the other index type
            if (
                isinstance(index_arguments, IndexArgsHNSW)
                and index_method != IndexMethod.hnsw
            ) or (
                isinstance(index_arguments, IndexArgsIVFFlat)
                and index_method != IndexMethod.ivfflat
            ):
                raise ArgError(
                    f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
                )

        if index_method == IndexMethod.auto:
            index_method = IndexMethod.hnsw

        ops = index_measure_to_ops(
            index_measure  # , quantization_type=self.quantization_type
        )
        if ops is None:
            raise ArgError("Unknown index measure")

        concurrently_sql = "CONCURRENTLY" if concurrently else ""

        index_name = (
            index_name
            or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
        )

        create_index_sql = f"""
        CREATE INDEX {concurrently_sql} {index_name}
        ON {table_name_str}
        USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
        """

        try:
            if concurrently:
                async with (
                    self.connection_manager.pool.get_connection() as conn  # type: ignore
                ):
                    # Disable automatic transaction management
                    await conn.execute(
                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
                    )
                    await conn.execute(create_index_sql)
            else:
                # Non-concurrent index creation can use normal query execution
                await self.connection_manager.execute_query(create_index_sql)
        except Exception as e:
            raise Exception(f"Failed to create index: {e}")

        return None
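
    # A minimal call sketch (illustrative only; IndexArgsHNSW field names are taken
    # from their use in _get_index_options below):
    #
    #     await handler.create_index(
    #         table_name=VectorTableName.CHUNKS,
    #         index_measure=IndexMeasure.cosine_distance,
    #         index_method=IndexMethod.hnsw,
    #         index_arguments=IndexArgsHNSW(m=16, ef_construction=64),
    #         concurrently=True,
    #     )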

    async def list_indices(
        self,
        offset: int,
        limit: int,
        filters: Optional[dict[str, Any]] = None,
    ) -> dict:
        where_clauses = []
        params: list[Any] = [self.project_name]  # Start with the schema name
        param_count = 1

        # Handle filtering
        if filters:
            if "table_name" in filters:
                where_clauses.append(f"i.tablename = ${param_count + 1}")
                params.append(filters["table_name"])
                param_count += 1
            if "index_method" in filters:
                where_clauses.append(f"am.amname = ${param_count + 1}")
                params.append(filters["index_method"])
                param_count += 1
            if "index_name" in filters:
                where_clauses.append(
                    f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
                )
                params.append(f"%{filters['index_name']}%")
                param_count += 1

        where_clause = " AND ".join(where_clauses) if where_clauses else ""
        if where_clause:
            where_clause = "AND " + where_clause

        query = f"""
        WITH index_info AS (
            SELECT
                i.indexname as name,
                i.tablename as table_name,
                i.indexdef as definition,
                am.amname as method,
                pg_relation_size(c.oid) as size_in_bytes,
                c.reltuples::bigint as row_estimate,
                COALESCE(psat.idx_scan, 0) as number_of_scans,
                COALESCE(psat.idx_tup_read, 0) as tuples_read,
                COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
                COUNT(*) OVER() as total_count
            FROM pg_indexes i
            JOIN pg_class c ON c.relname = i.indexname
            JOIN pg_am am ON c.relam = am.oid
            LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
                AND psat.schemaname = i.schemaname
            WHERE i.schemaname = $1
            AND i.indexdef LIKE '%vector%'
            {where_clause}
        )
        SELECT *
        FROM index_info
        ORDER BY name
        LIMIT ${param_count + 1}
        OFFSET ${param_count + 2}
        """

        # Add limit and offset to params
        params.extend([limit, offset])

        results = await self.connection_manager.fetch_query(query, params)

        indices = []
        total_entries = 0
        if results:
            total_entries = results[0]["total_count"]
            for result in results:
                index_info = {
                    "name": result["name"],
                    "table_name": result["table_name"],
                    "definition": result["definition"],
                    "size_in_bytes": result["size_in_bytes"],
                    "row_estimate": result["row_estimate"],
                    "number_of_scans": result["number_of_scans"],
                    "tuples_read": result["tuples_read"],
                    "tuples_fetched": result["tuples_fetched"],
                }
                indices.append(index_info)

        # Calculate pagination info
        total_pages = (total_entries + limit - 1) // limit if limit > 0 else 1
        current_page = (offset // limit) + 1 if limit > 0 else 1

        page_info = {
            "total_entries": total_entries,
            "total_pages": total_pages,
            "current_page": current_page,
            "limit": limit,
            "offset": offset,
            "has_previous": offset > 0,
            "has_next": offset + limit < total_entries,
            "previous_offset": max(0, offset - limit) if offset > 0 else None,
            "next_offset": (
                offset + limit if offset + limit < total_entries else None
            ),
        }

        return {"indices": indices, "page_info": page_info}

    async def delete_index(
        self,
        index_name: str,
        table_name: Optional[VectorTableName] = None,
        concurrently: bool = True,
    ) -> None:
        """
        Deletes a vector index.

        Args:
            index_name (str): Name of the index to delete
            table_name (VectorTableName, optional): Table the index belongs to
            concurrently (bool): Whether to drop the index concurrently

        Raises:
            ArgError: If the table name is invalid or the index doesn't exist
            Exception: If index deletion fails
        """
        # Validate the table name and get the column name
        if table_name == VectorTableName.CHUNKS:
            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
            col_name = "vec"
        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.GRAPHS_ENTITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.COMMUNITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
            )
            col_name = "description_embedding"
        else:
            raise ArgError("invalid table name")

        # Extract the schema and base table name
        schema_name, base_table_name = table_name_str.split(".")

        # Verify the index exists and is a vector index
        query = """
        SELECT indexdef
        FROM pg_indexes
        WHERE indexname = $1
        AND schemaname = $2
        AND tablename = $3
        AND indexdef LIKE $4
        """
        result = await self.connection_manager.fetchrow_query(
            query, (index_name, schema_name, base_table_name, f"%({col_name}%")
        )

        if not result:
            raise ArgError(
                f"Vector index '{index_name}' does not exist on table {table_name_str}"
            )

        # Drop the index
        concurrently_sql = "CONCURRENTLY" if concurrently else ""
        drop_query = (
            f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
        )

        try:
            if concurrently:
                async with (
                    self.connection_manager.pool.get_connection() as conn  # type: ignore
                ):
                    # Disable automatic transaction management
                    await conn.execute(
                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
                    )
                    await conn.execute(drop_query)
            else:
                await self.connection_manager.execute_query(drop_query)
        except Exception as e:
            raise Exception(f"Failed to delete index: {e}")

    async def get_semantic_neighbors(
        self,
        offset: int,
        limit: int,
        document_id: UUID,
        id: UUID,
        similarity_threshold: float = 0.5,
    ) -> list[dict[str, Any]]:
        table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        # Note: `<=>` is pgvector's cosine *distance* operator (lower = closer),
        # so the `similarity` column below is a distance value.
        query = f"""
        WITH target_vector AS (
            SELECT vec FROM {table_name}
            WHERE document_id = $1 AND id = $2
        )
        SELECT t.id, t.text, t.metadata, t.document_id, (t.vec <=> tv.vec) AS similarity
        FROM {table_name} t, target_vector tv
        WHERE (t.vec <=> tv.vec) >= $3
            AND t.document_id = $1
            AND t.id != $2
        ORDER BY similarity ASC
        LIMIT $4
        """
        results = await self.connection_manager.fetch_query(
            query,
            (str(document_id), str(id), similarity_threshold, limit),
        )

        return [
            {
                "id": str(r["id"]),
                "text": r["text"],
                "metadata": json.loads(r["metadata"]),
                "document_id": str(r["document_id"]),
                "similarity": float(r["similarity"]),
            }
            for r in results
        ]

    async def list_chunks(
        self,
        offset: int,
        limit: int,
        filters: Optional[dict[str, Any]] = None,
        include_vectors: bool = False,
    ) -> dict[str, Any]:
        """
        List chunks with pagination support.

        Args:
            offset (int): Number of records to skip.
            limit (int): Maximum number of records to return.
            filters (dict, optional): Dictionary of filters to apply. Defaults to None.
            include_vectors (bool, optional): Whether to include vector data. Defaults to False.

        Returns:
            dict: Dictionary containing:
                - results: List of chunk records
                - page_info: Pagination information (including total_entries)
        """
        vector_select = ", vec" if include_vectors else ""
        select_clause = f"""
            id, document_id, owner_id, collection_ids,
            text, metadata{vector_select}, COUNT(*) OVER() AS total
        """

        params: list[str | int | bytes] = []
        where_clause = ""
        if filters:
            where_clause, params = apply_filters(
                filters, params, mode="where_clause"
            )

        query = f"""
        SELECT {select_clause}
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        {where_clause}
        LIMIT ${len(params) + 1}
        OFFSET ${len(params) + 2}
        """

        params.extend([limit, offset])

        # Execute the query
        results = await self.connection_manager.fetch_query(query, params)

        # Process results
        chunks = []
        total = 0
        if results:
            total = results[0].get("total", 0)
            chunks = [
                {
                    "id": str(result["id"]),
                    "document_id": str(result["document_id"]),
                    "owner_id": str(result["owner_id"]),
                    "collection_ids": result["collection_ids"],
                    "text": result["text"],
                    "metadata": json.loads(result["metadata"]),
                    "vector": (
                        json.loads(result["vec"]) if include_vectors else None
                    ),
                }
                for result in results
            ]

        # Calculate pagination info
        total_pages = (total + limit - 1) // limit if limit > 0 else 1
        current_page = (offset // limit) + 1 if limit > 0 else 1

        page_info = {
            "total_entries": total,
            "total_pages": total_pages,
            "current_page": current_page,
            "limit": limit,
            "offset": offset,
            "has_previous": offset > 0,
            "has_next": offset + limit < total,
            "previous_offset": max(0, offset - limit) if offset > 0 else None,
            "next_offset": offset + limit if offset + limit < total else None,
        }

        return {"results": chunks, "page_info": page_info}
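
    # Pagination sketch (illustrative only): with offset=20, limit=10, and 45
    # matching chunks, page_info works out to total_pages=5, current_page=3,
    # has_previous=True, has_next=True, previous_offset=10, next_offset=30.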

    async def search_documents(
        self,
        query_text: str,
        settings: SearchSettings,
    ) -> list[dict[str, Any]]:
        """
        Search for documents based on their metadata fields and/or body text.
        Joins with the documents table to get complete document metadata.

        Args:
            query_text (str): The search query text
            settings (SearchSettings): Search settings including search preferences and filters

        Returns:
            list[dict[str, Any]]: List of documents with their search scores and complete metadata
        """
        where_clauses = []
        params: list[str | int | bytes] = [query_text]

        # Build the dynamic metadata field search expression
        metadata_fields_expr = " || ' ' || ".join(
            [
                f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
                for key in settings.metadata_keys  # type: ignore
            ]
        )

        query = f"""
        WITH
        -- Metadata search scores
        metadata_scores AS (
            SELECT DISTINCT ON (v.document_id)
                v.document_id,
                d.metadata as doc_metadata,
                CASE WHEN $1 = '' THEN 0.0
                ELSE
                    ts_rank_cd(
                        setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
                        websearch_to_tsquery('english', $1),
                        32
                    )
                END as metadata_rank
            FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
            LEFT JOIN {self._get_table_name('documents')} d ON v.document_id = d.id
            WHERE v.metadata IS NOT NULL
        ),
        -- Body search scores
        body_scores AS (
            SELECT
                document_id,
                AVG(
                    ts_rank_cd(
                        setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
                        websearch_to_tsquery('english', $1),
                        32
                    )
                ) as body_rank
            FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
            WHERE $1 != ''
            {f"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""}
            GROUP BY document_id
        ),
        -- Combined scores with document metadata
        combined_scores AS (
            SELECT
                COALESCE(m.document_id, b.document_id) as document_id,
                m.doc_metadata as metadata,
                COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
                COALESCE(b.body_rank, 0) as debug_body_rank,
                CASE
                    WHEN {str(settings.search_over_metadata).lower()} AND {str(settings.search_over_body).lower()} THEN
                        COALESCE(m.metadata_rank, 0) * {settings.metadata_weight} + COALESCE(b.body_rank, 0) * {settings.title_weight}
                    WHEN {str(settings.search_over_metadata).lower()} THEN
                        COALESCE(m.metadata_rank, 0)
                    WHEN {str(settings.search_over_body).lower()} THEN
                        COALESCE(b.body_rank, 0)
                    ELSE 0
                END as rank
            FROM metadata_scores m
            FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
            WHERE (
                ($1 = '') OR
                ({str(settings.search_over_metadata).lower()} AND m.metadata_rank > 0) OR
                ({str(settings.search_over_body).lower()} AND b.body_rank > 0)
            )
        """

        # Add any additional filters
        if settings.filters:
            filter_clause, params = apply_filters(settings.filters, params)
            where_clauses.append(filter_clause)

        if where_clauses:
            query += f" AND {' AND '.join(where_clauses)}"

        query += """
        )
        SELECT
            document_id,
            metadata,
            rank as score,
            debug_metadata_rank,
            debug_body_rank
        FROM combined_scores
        WHERE rank > 0
        ORDER BY rank DESC
        OFFSET ${offset_param} LIMIT ${limit_param}
        """.format(
            offset_param=len(params) + 1,
            limit_param=len(params) + 2,
        )

        # Add offset and limit to params
        params.extend([settings.offset, settings.limit])

        # Execute the query
        results = await self.connection_manager.fetch_query(query, params)

        # Format results with complete document metadata
        return [
            {
                "document_id": str(r["document_id"]),
                "metadata": (
                    json.loads(r["metadata"])
                    if isinstance(r["metadata"], str)
                    else r["metadata"]
                ),
                "score": float(r["score"]),
                "debug_metadata_rank": float(r["debug_metadata_rank"]),
                "debug_body_rank": float(r["debug_body_rank"]),
            }
            for r in results
        ]

    def _get_index_options(
        self,
        method: IndexMethod,
        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
    ) -> str:
        if method == IndexMethod.ivfflat:
            if isinstance(index_arguments, IndexArgsIVFFlat):
                return f"WITH (lists={index_arguments.n_lists})"
            else:
                # Default value if no arguments are provided
                return "WITH (lists=100)"
        elif method == IndexMethod.hnsw:
            if isinstance(index_arguments, IndexArgsHNSW):
                return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
            else:
                # Default values if no arguments are provided
                return "WITH (m=16, ef_construction=64)"
        else:
            return ""  # No options for other methods
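
    # For example, _get_index_options(IndexMethod.hnsw, None) returns
    # "WITH (m=16, ef_construction=64)", and passing an IndexArgsIVFFlat with
    # n_lists=200 together with IndexMethod.ivfflat returns "WITH (lists=200)".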