- import copy
- import json
- import logging
- import time
- import uuid
- from typing import Any, Optional, TypedDict
- from uuid import UUID
- import numpy as np
- from core.base import (
- ChunkSearchResult,
- Handler,
- IndexArgsHNSW,
- IndexArgsIVFFlat,
- IndexMeasure,
- IndexMethod,
- R2RException,
- SearchSettings,
- VectorEntry,
- VectorQuantizationType,
- VectorTableName,
- )
- from .base import PostgresConnectionManager
- from .filters import apply_filters
- from .vecs.exc import ArgError, FilterError
- from core.base.utils import _decorate_vector_type
- logger = logging.getLogger()
- def psql_quote_literal(value: str) -> str:
- """
- Safely quote a string literal for PostgreSQL to prevent SQL injection.
- This is a simple implementation - in production, you should use proper parameterization
- or your database driver's quoting functions.
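- For example (illustrative): psql_quote_literal("O'Brien") returns "'O''Brien'".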
- """
- return "'" + value.replace("'", "''") + "'"
- def index_measure_to_ops(
- measure: IndexMeasure,
- quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
- ):
- return _decorate_vector_type(measure.ops, quantization_type)
- def quantize_vector_to_binary(
- vector: list[float] | np.ndarray,
- threshold: float = 0.0,
- ) -> bytes:
- """
- Quantizes a float vector to a binary vector string for PostgreSQL bit type.
- Used when quantization_type is INT1.
- Args:
- vector (List[float] | np.ndarray): Input vector of floats
- threshold (float, optional): Threshold for binarization. Defaults to 0.0.
- Returns:
- bytes: ASCII-encoded string of 1s and 0s, suitable for casting to the PostgreSQL bit type
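- Example (illustrative): quantize_vector_to_binary([0.2, -0.1, 0.7]) returns b"101".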
- """
- # Convert input to numpy array if it isn't already
- if not isinstance(vector, np.ndarray):
- vector = np.array(vector)
- # Convert to binary (1 where value > threshold, 0 otherwise)
- binary_vector = (vector > threshold).astype(int)
- # Convert to a string of 1s and 0s, then encode to ASCII bytes
- binary_string = "".join(map(str, binary_vector))
- return binary_string.encode("ascii")
- class HybridSearchIntermediateResult(TypedDict):
- semantic_rank: int
- full_text_rank: int
- data: ChunkSearchResult
- rrf_score: float
- class PostgresChunksHandler(Handler):
- TABLE_NAME = VectorTableName.CHUNKS
- def __init__(
- self,
- project_name: str,
- connection_manager: PostgresConnectionManager,
- dimension: int,
- quantization_type: VectorQuantizationType,
- ):
- super().__init__(project_name, connection_manager)
- self.dimension = dimension
- self.quantization_type = quantization_type
- async def create_tables(self):
- # Check for old table name first
- check_query = """
- SELECT EXISTS (
- SELECT FROM pg_tables
- WHERE schemaname = $1
- AND tablename = $2
- );
- """
- old_table_exists = await self.connection_manager.fetch_query(
- check_query, (self.project_name, self.project_name)
- )
- if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
- raise ValueError(
- f"Found old vector table '{self.project_name}.{self.project_name}'. "
- "Please run `r2r db upgrade` with the CLI, or to run manually, "
- "run in R2R/py/migrations with 'alembic upgrade head' to update "
- "your database schema to the new version."
- )
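- # The bit-string column below exists only when INT1 quantization is enabled; it
- # backs the coarse first stage of the two-stage search in semantic_search.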
- binary_col = (
- ""
- if self.quantization_type != VectorQuantizationType.INT1
- else f"vec_binary bit({self.dimension}),"
- )
- query = f"""
- CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
- id UUID PRIMARY KEY,
- document_id UUID,
- owner_id UUID,
- collection_ids UUID[],
- vec vector({self.dimension}),
- {binary_col}
- text TEXT,
- metadata JSONB,
- fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
- );
- CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
- CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
- CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
- CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
- """
- await self.connection_manager.execute_query(query)
- async def upsert(self, entry: VectorEntry) -> None:
- """
- Upsert function that handles vector quantization only when quantization_type is INT1.
- Matches the table schema where vec_binary column only exists for INT1 quantization.
- """
- # Check the quantization type to determine which columns to use
- if self.quantization_type == VectorQuantizationType.INT1:
- # For quantized vectors, use vec_binary column
- query = f"""
- INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
- VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
- ON CONFLICT (id) DO UPDATE SET
- document_id = EXCLUDED.document_id,
- owner_id = EXCLUDED.owner_id,
- collection_ids = EXCLUDED.collection_ids,
- vec = EXCLUDED.vec,
- vec_binary = EXCLUDED.vec_binary,
- text = EXCLUDED.text,
- metadata = EXCLUDED.metadata;
- """
- await self.connection_manager.execute_query(
- query,
- (
- entry.id,
- entry.document_id,
- entry.owner_id,
- entry.collection_ids,
- str(entry.vector.data),
- quantize_vector_to_binary(
- entry.vector.data
- ), # Convert to binary
- entry.text,
- json.dumps(entry.metadata),
- ),
- )
- else:
- # For regular vectors, use vec column only
- query = f"""
- INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- (id, document_id, owner_id, collection_ids, vec, text, metadata)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
- ON CONFLICT (id) DO UPDATE SET
- document_id = EXCLUDED.document_id,
- owner_id = EXCLUDED.owner_id,
- collection_ids = EXCLUDED.collection_ids,
- vec = EXCLUDED.vec,
- text = EXCLUDED.text,
- metadata = EXCLUDED.metadata;
- """
- await self.connection_manager.execute_query(
- query,
- (
- entry.id,
- entry.document_id,
- entry.owner_id,
- entry.collection_ids,
- str(entry.vector.data),
- entry.text,
- json.dumps(entry.metadata),
- ),
- )
- async def upsert_entries(self, entries: list[VectorEntry]) -> None:
- """
- Batch upsert function that handles vector quantization only when quantization_type is INT1.
- Matches the table schema where vec_binary column only exists for INT1 quantization.
- """
- if self.quantization_type == VectorQuantizationType.INT1:
- # For quantized vectors, use vec_binary column
- query = f"""
- INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
- VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
- ON CONFLICT (id) DO UPDATE SET
- document_id = EXCLUDED.document_id,
- owner_id = EXCLUDED.owner_id,
- collection_ids = EXCLUDED.collection_ids,
- vec = EXCLUDED.vec,
- vec_binary = EXCLUDED.vec_binary,
- text = EXCLUDED.text,
- metadata = EXCLUDED.metadata;
- """
- bin_params = [
- (
- entry.id,
- entry.document_id,
- entry.owner_id,
- entry.collection_ids,
- str(entry.vector.data),
- quantize_vector_to_binary(
- entry.vector.data
- ), # Convert to binary
- entry.text,
- json.dumps(entry.metadata),
- )
- for entry in entries
- ]
- await self.connection_manager.execute_many(query, bin_params)
- else:
- # For regular vectors, use vec column only
- query = f"""
- INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- (id, document_id, owner_id, collection_ids, vec, text, metadata)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
- ON CONFLICT (id) DO UPDATE SET
- document_id = EXCLUDED.document_id,
- owner_id = EXCLUDED.owner_id,
- collection_ids = EXCLUDED.collection_ids,
- vec = EXCLUDED.vec,
- text = EXCLUDED.text,
- metadata = EXCLUDED.metadata;
- """
- params = [
- (
- entry.id,
- entry.document_id,
- entry.owner_id,
- entry.collection_ids,
- str(entry.vector.data),
- entry.text,
- json.dumps(entry.metadata),
- )
- for entry in entries
- ]
- await self.connection_manager.execute_many(query, params)
- async def semantic_search(
- self, query_vector: list[float], search_settings: SearchSettings
- ) -> list[ChunkSearchResult]:
- try:
- imeasure_obj = IndexMeasure(
- search_settings.chunk_settings.index_measure
- )
- except ValueError as e:
- raise ValueError(f"Invalid index measure: {search_settings.chunk_settings.index_measure}") from e
- table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
- cols = [
- f"{table_name}.id",
- f"{table_name}.document_id",
- f"{table_name}.owner_id",
- f"{table_name}.collection_ids",
- f"{table_name}.text",
- ]
- params: list[str | int | bytes] = []
- # For binary vectors (INT1), implement two-stage search
- if self.quantization_type == VectorQuantizationType.INT1:
- # Convert query vector to binary format
- binary_query = quantize_vector_to_binary(query_vector)
- # TODO - Put depth multiplier in config / settings
- extended_limit = (
- search_settings.limit * 20
- ) # Get 20x candidates for re-ranking
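- # Hamming and Jaccard can be evaluated directly on the bit column; any other
- # measure falls back to Hamming for the coarse pass and is refined against the
- # original float vectors in the second stage.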
- if (
- imeasure_obj == IndexMeasure.hamming_distance
- or imeasure_obj == IndexMeasure.jaccard_distance
- ):
- binary_search_measure_repr = imeasure_obj.pgvector_repr
- else:
- binary_search_measure_repr = (
- IndexMeasure.hamming_distance.pgvector_repr
- )
- # Use binary column and binary-specific distance measures for first stage
- stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit({self.dimension})"
- stage1_param = binary_query
- cols.append(
- f"{table_name}.vec"
- ) # Need original vector for re-ranking
- if search_settings.include_metadatas:
- cols.append(f"{table_name}.metadata")
- select_clause = ", ".join(cols)
- where_clause = ""
- params.append(stage1_param)
- if search_settings.filters:
- where_clause, params = apply_filters(
- search_settings.filters, params, mode="where_clause"
- )
- # First stage: Get candidates using binary search
- query = f"""
- WITH candidates AS (
- SELECT {select_clause},
- ({stage1_distance}) as binary_distance
- FROM {table_name}
- {where_clause}
- ORDER BY {stage1_distance}
- LIMIT ${len(params) + 1}
- OFFSET ${len(params) + 2}
- )
- -- Second stage: Re-rank using original vectors
- SELECT
- id,
- document_id,
- owner_id,
- collection_ids,
- text,
- {"metadata," if search_settings.include_metadatas else ""}
- (vec <=> ${len(params) + 4}::vector({self.dimension})) as distance
- FROM candidates
- ORDER BY distance
- LIMIT ${len(params) + 3}
- """
- params.extend(
- [
- extended_limit, # First stage limit
- search_settings.offset,
- search_settings.limit, # Final limit
- str(query_vector), # For re-ranking
- ]
- )
- else:
- # Standard float vector handling
- distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector({self.dimension})"
- query_param = str(query_vector)
- if search_settings.include_scores:
- cols.append(f"({distance_calc}) AS distance")
- if search_settings.include_metadatas:
- cols.append(f"{table_name}.metadata")
- select_clause = ", ".join(cols)
- where_clause = ""
- params.append(query_param)
- if search_settings.filters:
- where_clause, new_params = apply_filters(
- search_settings.filters,
- params,
- mode="where_clause", # Get just conditions without WHERE
- )
- params = new_params
- query = f"""
- SELECT {select_clause}
- FROM {table_name}
- {where_clause}
- ORDER BY {distance_calc}
- LIMIT ${len(params) + 1}
- OFFSET ${len(params) + 2}
- """
- params.extend([search_settings.limit, search_settings.offset])
- results = await self.connection_manager.fetch_query(query, params)
- return [
- ChunkSearchResult(
- id=UUID(str(result["id"])),
- document_id=UUID(str(result["document_id"])),
- owner_id=UUID(str(result["owner_id"])),
- collection_ids=result["collection_ids"],
- text=result["text"],
- score=(
- (1 - float(result["distance"]))
- if "distance" in result
- else -1
- ),
- metadata=(
- json.loads(result["metadata"])
- if search_settings.include_metadatas
- else {}
- ),
- )
- for result in results
- ]
- async def full_text_search(
- self, query_text: str, search_settings: SearchSettings
- ) -> list[ChunkSearchResult]:
- conditions = []
- params: list[str | int | bytes] = [query_text]
- conditions.append("fts @@ websearch_to_tsquery('english', $1)")
- if search_settings.filters:
- filter_condition, params = apply_filters(
- search_settings.filters, params, mode="condition_only"
- )
- if filter_condition:
- conditions.append(filter_condition)
- where_clause = "WHERE " + " AND ".join(conditions)
- query = f"""
- SELECT
- id,
- document_id,
- owner_id,
- collection_ids,
- text,
- metadata,
- ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- {where_clause}
- ORDER BY rank DESC
- OFFSET ${len(params)+1}
- LIMIT ${len(params)+2}
- """
- params.extend(
- [
- search_settings.offset,
- search_settings.hybrid_settings.full_text_limit,
- ]
- )
- results = await self.connection_manager.fetch_query(query, params)
- return [
- ChunkSearchResult(
- id=UUID(str(r["id"])),
- document_id=UUID(str(r["document_id"])),
- owner_id=UUID(str(r["owner_id"])),
- collection_ids=r["collection_ids"],
- text=r["text"],
- score=float(r["rank"]),
- metadata=json.loads(r["metadata"]),
- )
- for r in results
- ]
- async def hybrid_search(
- self,
- query_text: str,
- query_vector: list[float],
- search_settings: SearchSettings,
- *args,
- **kwargs,
- ) -> list[ChunkSearchResult]:
- if search_settings.hybrid_settings is None:
- raise ValueError(
- "Please provide a valid `hybrid_settings` in the `search_settings`."
- )
- if (
- search_settings.hybrid_settings.full_text_limit
- < search_settings.limit
- ):
- raise ValueError(
- "The `full_text_limit` must be greater than or equal to the `limit`."
- )
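- # Both sub-searches are widened by the requested offset so that, after rank
- # fusion, slicing [offset : offset + limit] still leaves enough candidates.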
- semantic_settings = copy.deepcopy(search_settings)
- semantic_settings.limit += search_settings.offset
- full_text_settings = copy.deepcopy(search_settings)
- full_text_settings.hybrid_settings.full_text_limit += (
- search_settings.offset
- )
- semantic_results: list[ChunkSearchResult] = await self.semantic_search(
- query_vector, semantic_settings
- )
- full_text_results: list[ChunkSearchResult] = (
- await self.full_text_search(query_text, full_text_settings)
- )
- semantic_limit = search_settings.limit
- full_text_limit = search_settings.hybrid_settings.full_text_limit
- semantic_weight = search_settings.hybrid_settings.semantic_weight
- full_text_weight = search_settings.hybrid_settings.full_text_weight
- rrf_k = search_settings.hybrid_settings.rrf_k
- combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}
- for rank, result in enumerate(semantic_results, 1):
- combined_results[result.id] = {
- "semantic_rank": rank,
- "full_text_rank": full_text_limit,
- "data": result,
- "rrf_score": 0.0, # Initialize with 0, will be calculated later
- }
- for rank, result in enumerate(full_text_results, 1):
- if result.id in combined_results:
- combined_results[result.id]["full_text_rank"] = rank
- else:
- combined_results[result.id] = {
- "semantic_rank": semantic_limit,
- "full_text_rank": rank,
- "data": result,
- "rrf_score": 0.0, # Initialize with 0, will be calculated later
- }
- combined_results = {
- k: v
- for k, v in combined_results.items()
- if v["semantic_rank"] <= semantic_limit * 2
- and v["full_text_rank"] <= full_text_limit * 2
- }
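- # Weighted Reciprocal Rank Fusion: each modality contributes 1 / (rrf_k + rank),
- # and the two contributions are combined as a weighted average.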
- for hyb_result in combined_results.values():
- semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
- full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
- hyb_result["rrf_score"] = (
- semantic_score * semantic_weight
- + full_text_score * full_text_weight
- ) / (semantic_weight + full_text_weight)
- sorted_results = sorted(
- combined_results.values(),
- key=lambda x: x["rrf_score"],
- reverse=True,
- )
- offset_results = sorted_results[
- search_settings.offset : search_settings.offset
- + search_settings.limit
- ]
- return [
- ChunkSearchResult(
- id=result["data"].id,
- document_id=result["data"].document_id,
- owner_id=result["data"].owner_id,
- collection_ids=result["data"].collection_ids,
- text=result["data"].text,
- score=result["rrf_score"],
- metadata={
- **result["data"].metadata,
- "semantic_rank": result["semantic_rank"],
- "full_text_rank": result["full_text_rank"],
- },
- )
- for result in offset_results
- ]
- async def delete(
- self, filters: dict[str, Any]
- ) -> dict[str, dict[str, str]]:
- params: list[str | int | bytes] = []
- where_clause, params = apply_filters(
- filters, params, mode="condition_only"
- )
- query = f"""
- DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE {where_clause}
- RETURNING id, document_id, text;
- """
- results = await self.connection_manager.fetch_query(query, params)
- return {
- str(result["id"]): {
- "status": "deleted",
- "id": str(result["id"]),
- "document_id": str(result["document_id"]),
- "text": result["text"],
- }
- for result in results
- }
- async def assign_document_chunks_to_collection(
- self, document_id: UUID, collection_id: UUID
- ) -> None:
- query = f"""
- UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- SET collection_ids = array_append(collection_ids, $1)
- WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
- """
- await self.connection_manager.execute_query(
- query, (str(collection_id), str(document_id))
- )
- async def remove_document_from_collection_vector(
- self, document_id: UUID, collection_id: UUID
- ) -> None:
- query = f"""
- UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- SET collection_ids = array_remove(collection_ids, $1)
- WHERE document_id = $2;
- """
- await self.connection_manager.execute_query(
- query, (collection_id, document_id)
- )
- async def delete_user_vector(self, owner_id: UUID) -> None:
- query = f"""
- DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE owner_id = $1;
- """
- await self.connection_manager.execute_query(query, (owner_id,))
- async def delete_collection_vector(self, collection_id: UUID) -> None:
- query = f"""
- DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE $1 = ANY(collection_ids)
- RETURNING collection_ids
- """
- await self.connection_manager.fetchrow_query(
- query, (collection_id,)
- )
- return None
- async def list_document_chunks(
- self,
- document_id: UUID,
- offset: int,
- limit: int,
- include_vectors: bool = False,
- ) -> dict[str, Any]:
- vector_select = ", vec" if include_vectors else ""
- limit_clause = f"LIMIT {limit}" if limit > -1 else ""
- query = f"""
- SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE document_id = $1
- ORDER BY (metadata->>'chunk_order')::integer
- OFFSET $2
- {limit_clause};
- """
- params = [document_id, offset]
- results = await self.connection_manager.fetch_query(query, params)
- chunks = []
- total = 0
- if results:
- total = results[0].get("total", 0)
- chunks = [
- {
- "id": result["id"],
- "document_id": result["document_id"],
- "owner_id": result["owner_id"],
- "collection_ids": result["collection_ids"],
- "text": result["text"],
- "metadata": json.loads(result["metadata"]),
- "vector": (
- json.loads(result["vec"]) if include_vectors else None
- ),
- }
- for result in results
- ]
- return {"results": chunks, "total_entries": total}
- async def get_chunk(self, id: UUID) -> dict:
- query = f"""
- SELECT id, document_id, owner_id, collection_ids, text, metadata
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE id = $1;
- """
- result = await self.connection_manager.fetchrow_query(query, (id,))
- if result:
- return {
- "id": result["id"],
- "document_id": result["document_id"],
- "owner_id": result["owner_id"],
- "collection_ids": result["collection_ids"],
- "text": result["text"],
- "metadata": json.loads(result["metadata"]),
- }
- raise R2RException(
- message=f"Chunk with ID {id} not found", status_code=404
- )
- async def create_index(
- self,
- table_name: Optional[VectorTableName] = None,
- index_measure: IndexMeasure = IndexMeasure.cosine_distance,
- index_method: IndexMethod = IndexMethod.auto,
- index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
- index_name: Optional[str] = None,
- index_column: Optional[str] = None,
- concurrently: bool = True,
- ) -> None:
- """
- Creates an index for the collection.
- Note:
- When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
- process that enables performant indexes to be built for large collections with low end
- database hardware.
- Those steps are:
- - Creates a new table with a different name
- - Randomly selects records from the existing table
- - Inserts the random records from the existing table into the new table
- - Creates the requested vector index on the new table
- - Upserts all data from the existing table into the new table
- - Drops the existing table
- Renames the new table to the existing table's name
- If you create dependencies (like views) on the table that underpins
- a `vecs.Collection` the `create_index` step may require you to drop those dependencies before
- it will succeed.
- Args:
- index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
- index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
- index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments
- index_name (str, optional): The name of the index to create. Defaults to None.
- concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
- Raises:
- ArgError: If an invalid table name or index method is used, or if the index arguments do not match the chosen index method.
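- Example (illustrative; assumes an initialized handler instance named `handler`):
- await handler.create_index(
- table_name=VectorTableName.CHUNKS,
- index_measure=IndexMeasure.cosine_distance,
- index_method=IndexMethod.hnsw,
- )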
- """
- if table_name == VectorTableName.CHUNKS:
- table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}" # TODO - Fix bug in vector table naming convention
- if index_column:
- col_name = index_column
- else:
- col_name = (
- "vec"
- if (
- index_measure != IndexMeasure.hamming_distance
- and index_measure != IndexMeasure.jaccard_distance
- )
- else "vec_binary"
- )
- elif table_name == VectorTableName.ENTITIES_DOCUMENT:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
- )
- col_name = "description_embedding"
- elif table_name == VectorTableName.GRAPHS_ENTITIES:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
- )
- col_name = "description_embedding"
- elif table_name == VectorTableName.COMMUNITIES:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.COMMUNITIES}"
- )
- col_name = "embedding"
- else:
- raise ArgError("invalid table name")
- if index_method not in (
- IndexMethod.ivfflat,
- IndexMethod.hnsw,
- IndexMethod.auto,
- ):
- raise ArgError("invalid index method")
- if index_arguments:
- # Disallow case where user submits index arguments but uses the
- # IndexMethod.auto index (index build arguments should only be
- # used with a specific index)
- if index_method == IndexMethod.auto:
- raise ArgError(
- "Index build parameters are not allowed when using the IndexMethod.auto index."
- )
- # Disallow case where user specifies one index type but submits
- # index build arguments for the other index type
- if (
- isinstance(index_arguments, IndexArgsHNSW)
- and index_method != IndexMethod.hnsw
- ) or (
- isinstance(index_arguments, IndexArgsIVFFlat)
- and index_method != IndexMethod.ivfflat
- ):
- raise ArgError(
- f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
- )
- if index_method == IndexMethod.auto:
- index_method = IndexMethod.hnsw
- ops = index_measure_to_ops(
- index_measure # , quantization_type=self.quantization_type
- )
- if ops is None:
- raise ArgError("Unknown index measure")
- concurrently_sql = "CONCURRENTLY" if concurrently else ""
- index_name = (
- index_name
- or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
- )
- create_index_sql = f"""
- CREATE INDEX {concurrently_sql} {index_name}
- ON {table_name_str}
- USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
- """
- try:
- if concurrently:
- async with (
- self.connection_manager.pool.get_connection() as conn # type: ignore
- ):
- # Disable automatic transaction management
- await conn.execute(
- "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
- )
- await conn.execute(create_index_sql)
- else:
- # Non-concurrent index creation can use normal query execution
- await self.connection_manager.execute_query(create_index_sql)
- except Exception as e:
- raise Exception(f"Failed to create index: {e}")
- return None
- async def list_indices(
- self,
- offset: int,
- limit: int,
- filters: Optional[dict[str, Any]] = None,
- ) -> dict:
- where_clauses = []
- params: list[Any] = [self.project_name] # Start with schema name
- param_count = 1
- # Handle filtering
- if filters:
- if "table_name" in filters:
- where_clauses.append(f"i.tablename = ${param_count + 1}")
- params.append(filters["table_name"])
- param_count += 1
- if "index_method" in filters:
- where_clauses.append(f"am.amname = ${param_count + 1}")
- params.append(filters["index_method"])
- param_count += 1
- if "index_name" in filters:
- where_clauses.append(
- f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
- )
- params.append(f"%{filters['index_name']}%")
- param_count += 1
- where_clause = " AND ".join(where_clauses) if where_clauses else ""
- if where_clause:
- where_clause = f"AND {where_clause}"
- query = f"""
- WITH index_info AS (
- SELECT
- i.indexname as name,
- i.tablename as table_name,
- i.indexdef as definition,
- am.amname as method,
- pg_relation_size(c.oid) as size_in_bytes,
- c.reltuples::bigint as row_estimate,
- COALESCE(psat.idx_scan, 0) as number_of_scans,
- COALESCE(psat.idx_tup_read, 0) as tuples_read,
- COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
- COUNT(*) OVER() as total_count
- FROM pg_indexes i
- JOIN pg_class c ON c.relname = i.indexname
- JOIN pg_am am ON c.relam = am.oid
- LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
- AND psat.schemaname = i.schemaname
- WHERE i.schemaname = $1
- AND i.indexdef LIKE '%vector%'
- {where_clause}
- )
- SELECT *
- FROM index_info
- ORDER BY name
- LIMIT ${param_count + 1}
- OFFSET ${param_count + 2}
- """
- # Add limit and offset to params
- params.extend([limit, offset])
- results = await self.connection_manager.fetch_query(query, params)
- indices = []
- total_entries = 0
- if results:
- total_entries = results[0]["total_count"]
- for result in results:
- index_info = {
- "name": result["name"],
- "table_name": result["table_name"],
- "definition": result["definition"],
- "size_in_bytes": result["size_in_bytes"],
- "row_estimate": result["row_estimate"],
- "number_of_scans": result["number_of_scans"],
- "tuples_read": result["tuples_read"],
- "tuples_fetched": result["tuples_fetched"],
- }
- indices.append(index_info)
- # Calculate pagination info
- total_pages = (total_entries + limit - 1) // limit if limit > 0 else 1
- current_page = (offset // limit) + 1 if limit > 0 else 1
- page_info = {
- "total_entries": total_entries,
- "total_pages": total_pages,
- "current_page": current_page,
- "limit": limit,
- "offset": offset,
- "has_previous": offset > 0,
- "has_next": offset + limit < total_entries,
- "previous_offset": max(0, offset - limit) if offset > 0 else None,
- "next_offset": (
- offset + limit if offset + limit < total_entries else None
- ),
- }
- return {"indices": indices, "page_info": page_info}
- async def delete_index(
- self,
- index_name: str,
- table_name: Optional[VectorTableName] = None,
- concurrently: bool = True,
- ) -> None:
- """
- Deletes a vector index.
- Args:
- index_name (str): Name of the index to delete
- table_name (VectorTableName, optional): Table the index belongs to
- concurrently (bool): Whether to drop the index concurrently
- Raises:
- ArgError: If table name is invalid or index doesn't exist
- Exception: If index deletion fails
- """
- # Validate table name and get column name
- if table_name == VectorTableName.CHUNKS:
- table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
- col_name = "vec"
- elif table_name == VectorTableName.ENTITIES_DOCUMENT:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
- )
- col_name = "description_embedding"
- elif table_name == VectorTableName.GRAPHS_ENTITIES:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
- )
- col_name = "description_embedding"
- elif table_name == VectorTableName.COMMUNITIES:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.COMMUNITIES}"
- )
- col_name = "description_embedding"
- else:
- raise ArgError("invalid table name")
- # Extract schema and base table name
- schema_name, base_table_name = table_name_str.split(".")
- # Verify index exists and is a vector index
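- # (the LIKE pattern on indexdef requires the definition to reference the expected
- # vector column, so unrelated indexes cannot be dropped through this path)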
- query = """
- SELECT indexdef
- FROM pg_indexes
- WHERE indexname = $1
- AND schemaname = $2
- AND tablename = $3
- AND indexdef LIKE $4
- """
- result = await self.connection_manager.fetchrow_query(
- query, (index_name, schema_name, base_table_name, f"%({col_name}%")
- )
- if not result:
- raise ArgError(
- f"Vector index '{index_name}' does not exist on table {table_name_str}"
- )
- # Drop the index
- concurrently_sql = "CONCURRENTLY" if concurrently else ""
- drop_query = (
- f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
- )
- try:
- if concurrently:
- async with (
- self.connection_manager.pool.get_connection() as conn # type: ignore
- ):
- # Disable automatic transaction management
- await conn.execute(
- "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
- )
- await conn.execute(drop_query)
- else:
- await self.connection_manager.execute_query(drop_query)
- except Exception as e:
- raise Exception(f"Failed to delete index: {e}")
- async def get_semantic_neighbors(
- self,
- offset: int,
- limit: int,
- document_id: UUID,
- id: UUID,
- similarity_threshold: float = 0.5,
- ) -> list[dict[str, Any]]:
- table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
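- # Note: `<=>` is pgvector's cosine distance operator, so smaller values mean
- # closer neighbors; the threshold and ordering below operate on distance.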
- query = f"""
- WITH target_vector AS (
- SELECT vec FROM {table_name}
- WHERE document_id = $1 AND id = $2
- )
- SELECT t.id, t.text, t.metadata, t.document_id, (t.vec <=> tv.vec) AS similarity
- FROM {table_name} t, target_vector tv
- WHERE (t.vec <=> tv.vec) >= $3
- AND t.document_id = $1
- AND t.id != $2
- ORDER BY similarity ASC
- LIMIT $4
- """
- results = await self.connection_manager.fetch_query(
- query,
- (str(document_id), str(id), similarity_threshold, limit),
- )
- return [
- {
- "id": str(r["id"]),
- "text": r["text"],
- "metadata": json.loads(r["metadata"]),
- "document_id": str(r["document_id"]),
- "similarity": float(r["similarity"]),
- }
- for r in results
- ]
- async def list_chunks(
- self,
- offset: int,
- limit: int,
- filters: Optional[dict[str, Any]] = None,
- include_vectors: bool = False,
- ) -> dict[str, Any]:
- """
- List chunks with pagination support.
- Args:
- offset (int): Number of records to skip.
- limit (int): Maximum number of records to return.
- filters (dict, optional): Dictionary of filters to apply. Defaults to None.
- include_vectors (bool, optional): Whether to include vector data. Defaults to False.
- Returns:
- dict: Dictionary containing:
- - results: List of chunk records
- - page_info: Pagination information, including total_entries
- """
- vector_select = ", vec" if include_vectors else ""
- select_clause = f"""
- id, document_id, owner_id, collection_ids,
- text, metadata{vector_select}, COUNT(*) OVER() AS total
- """
- params: list[str | int | bytes] = []
- where_clause = ""
- if filters:
- where_clause, params = apply_filters(
- filters, params, mode="where_clause"
- )
- query = f"""
- SELECT {select_clause}
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- {where_clause}
- LIMIT ${len(params) + 1}
- OFFSET ${len(params) + 2}
- """
- params.extend([limit, offset])
- # Execute the query
- results = await self.connection_manager.fetch_query(query, params)
- # Process results
- chunks = []
- total = 0
- if results:
- total = results[0].get("total", 0)
- chunks = [
- {
- "id": str(result["id"]),
- "document_id": str(result["document_id"]),
- "owner_id": str(result["owner_id"]),
- "collection_ids": result["collection_ids"],
- "text": result["text"],
- "metadata": json.loads(result["metadata"]),
- "vector": (
- json.loads(result["vec"]) if include_vectors else None
- ),
- }
- for result in results
- ]
- # Calculate pagination info
- total_pages = (total + limit - 1) // limit if limit > 0 else 1
- current_page = (offset // limit) + 1 if limit > 0 else 1
- page_info = {
- "total_entries": total,
- "total_pages": total_pages,
- "current_page": current_page,
- "limit": limit,
- "offset": offset,
- "has_previous": offset > 0,
- "has_next": offset + limit < total,
- "previous_offset": max(0, offset - limit) if offset > 0 else None,
- "next_offset": offset + limit if offset + limit < total else None,
- }
- return {"results": chunks, "page_info": page_info}
- async def search_documents(
- self,
- query_text: str,
- settings: SearchSettings,
- ) -> list[dict[str, Any]]:
- """
- Search for documents based on their metadata fields and/or body text.
- Joins with documents table to get complete document metadata.
- Args:
- query_text (str): The search query text
- settings (SearchSettings): Search settings including search preferences and filters
- Returns:
- list[dict[str, Any]]: List of documents with their search scores and complete metadata
- """
- where_clauses = []
- params: list[str | int | bytes] = [query_text]
- # Build the dynamic metadata field search expression
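- # (e.g. with metadata_keys ["title", "author"] this expands to
- # "COALESCE(v.metadata->>'title', '') || ' ' || COALESCE(v.metadata->>'author', '')")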
- metadata_fields_expr = " || ' ' || ".join(
- [
- f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
- for key in settings.metadata_keys # type: ignore
- ]
- )
- query = f"""
- WITH
- -- Metadata search scores
- metadata_scores AS (
- SELECT DISTINCT ON (v.document_id)
- v.document_id,
- d.metadata as doc_metadata,
- CASE WHEN $1 = '' THEN 0.0
- ELSE
- ts_rank_cd(
- setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
- websearch_to_tsquery('english', $1),
- 32
- )
- END as metadata_rank
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
- LEFT JOIN {self._get_table_name('documents')} d ON v.document_id = d.id
- WHERE v.metadata IS NOT NULL
- ),
- -- Body search scores
- body_scores AS (
- SELECT
- document_id,
- AVG(
- ts_rank_cd(
- setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
- websearch_to_tsquery('english', $1),
- 32
- )
- ) as body_rank
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE $1 != ''
- {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""}
- GROUP BY document_id
- ),
- -- Combined scores with document metadata
- combined_scores AS (
- SELECT
- COALESCE(m.document_id, b.document_id) as document_id,
- m.doc_metadata as metadata,
- COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
- COALESCE(b.body_rank, 0) as debug_body_rank,
- CASE
- WHEN {str(settings.search_over_metadata).lower()} AND {str(settings.search_over_body).lower()} THEN
- COALESCE(m.metadata_rank, 0) * {settings.metadata_weight} + COALESCE(b.body_rank, 0) * {settings.title_weight}
- WHEN {str(settings.search_over_metadata).lower()} THEN
- COALESCE(m.metadata_rank, 0)
- WHEN {str(settings.search_over_body).lower()} THEN
- COALESCE(b.body_rank, 0)
- ELSE 0
- END as rank
- FROM metadata_scores m
- FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
- WHERE (
- ($1 = '') OR
- ({str(settings.search_over_metadata).lower()} AND m.metadata_rank > 0) OR
- ({str(settings.search_over_body).lower()} AND b.body_rank > 0)
- )
- """
- # Add any additional filters
- if settings.filters:
- filter_clause, params = apply_filters(settings.filters, params)
- where_clauses.append(filter_clause)
- if where_clauses:
- query += f" AND {' AND '.join(where_clauses)}"
- query += """
- )
- SELECT
- document_id,
- metadata,
- rank as score,
- debug_metadata_rank,
- debug_body_rank
- FROM combined_scores
- WHERE rank > 0
- ORDER BY rank DESC
- OFFSET ${offset_param} LIMIT ${limit_param}
- """.format(
- offset_param=len(params) + 1,
- limit_param=len(params) + 2,
- )
- # Add offset and limit to params
- params.extend([settings.offset, settings.limit])
- # Execute query
- results = await self.connection_manager.fetch_query(query, params)
- # Format results with complete document metadata
- return [
- {
- "document_id": str(r["document_id"]),
- "metadata": (
- json.loads(r["metadata"])
- if isinstance(r["metadata"], str)
- else r["metadata"]
- ),
- "score": float(r["score"]),
- "debug_metadata_rank": float(r["debug_metadata_rank"]),
- "debug_body_rank": float(r["debug_body_rank"]),
- }
- for r in results
- ]
- def _get_index_options(
- self,
- method: IndexMethod,
- index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
- ) -> str:
- if method == IndexMethod.ivfflat:
- if isinstance(index_arguments, IndexArgsIVFFlat):
- return f"WITH (lists={index_arguments.n_lists})"
- else:
- # Default value if no arguments provided
- return "WITH (lists=100)"
- elif method == IndexMethod.hnsw:
- if isinstance(index_arguments, IndexArgsHNSW):
- return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
- else:
- # Default values if no arguments provided
- return "WITH (m=16, ef_construction=64)"
- else:
- return "" # No options for other methods