- import copy
- import json
- import logging
- import time
- import uuid
- from typing import Any, Optional, TypedDict
- from uuid import UUID
- import numpy as np
- from core.base import (
- ChunkSearchResult,
- Handler,
- IndexArgsHNSW,
- IndexArgsIVFFlat,
- IndexMeasure,
- IndexMethod,
- R2RException,
- SearchSettings,
- VectorEntry,
- VectorQuantizationType,
- VectorTableName,
- )
- from .base import PostgresConnectionManager
- from .filters import apply_filters
- from .vecs.exc import ArgError, FilterError
- from core.base.utils import _decorate_vector_type
- logger = logging.getLogger()
- def psql_quote_literal(value: str) -> str:
- """
- Safely quote a string literal for PostgreSQL to prevent SQL injection.
- This is a simple implementation - in production, you should use proper parameterization
- or your database driver's quoting functions.
- """
- return "'" + value.replace("'", "''") + "'"
- def index_measure_to_ops(
- measure: IndexMeasure,
- quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
- ):
- return _decorate_vector_type(measure.ops, quantization_type)
- def quantize_vector_to_binary(
- vector: list[float] | np.ndarray,
- threshold: float = 0.0,
- ) -> bytes:
- """
- Quantizes a float vector to a binary vector string for PostgreSQL bit type.
- Used when quantization_type is INT1.
- Args:
- vector (List[float] | np.ndarray): Input vector of floats
- threshold (float, optional): Threshold for binarization. Defaults to 0.0.
- Returns:
- bytes: ASCII-encoded string of 1s and 0s for the PostgreSQL bit type
- """
- # Convert input to numpy array if it isn't already
- if not isinstance(vector, np.ndarray):
- vector = np.array(vector)
- # Convert to binary (1 where value > threshold, 0 otherwise)
- binary_vector = (vector > threshold).astype(int)
- # Convert to a string of 1s and 0s, then encode to bytes
- binary_string = "".join(map(str, binary_vector))
- return binary_string.encode("ascii")
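- # Illustrative example (not executed): values strictly above the threshold map to 1,
- # everything else to 0, and the result is ASCII bytes suitable for a bit(n) column, e.g.
- #   quantize_vector_to_binary([0.12, -0.40, 0.05, 0.0])  ->  b"1010"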
- class HybridSearchIntermediateResult(TypedDict):
- semantic_rank: int
- full_text_rank: int
- data: ChunkSearchResult
- rrf_score: float
- class PostgresChunksHandler(Handler):
- TABLE_NAME = VectorTableName.CHUNKS
- def __init__(
- self,
- project_name: str,
- connection_manager: PostgresConnectionManager,
- dimension: int,
- quantization_type: VectorQuantizationType,
- ):
- super().__init__(project_name, connection_manager)
- self.dimension = dimension
- self.quantization_type = quantization_type
- async def create_tables(self):
- # Check for old table name first
- check_query = """
- SELECT EXISTS (
- SELECT FROM pg_tables
- WHERE schemaname = $1
- AND tablename = $2
- );
- """
- old_table_exists = await self.connection_manager.fetch_query(
- check_query, (self.project_name, self.project_name)
- )
- if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
- raise ValueError(
- f"Found old vector table '{self.project_name}.{self.project_name}'. "
- "Please run `r2r db upgrade` with the CLI, or to run manually, "
- "run in R2R/py/migrations with 'alembic upgrade head' to update "
- "your database schema to the new version."
- )
- binary_col = (
- ""
- if self.quantization_type != VectorQuantizationType.INT1
- else f"vec_binary bit({self.dimension}),"
- )
- query = f"""
- CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
- id UUID PRIMARY KEY,
- document_id UUID,
- owner_id UUID,
- collection_ids UUID[],
- vec vector({self.dimension}),
- {binary_col}
- text TEXT,
- metadata JSONB,
- fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
- );
- CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
- CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
- CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
- CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
- """
- await self.connection_manager.execute_query(query)
- async def upsert(self, entry: VectorEntry) -> None:
- """
- Upsert function that handles vector quantization only when quantization_type is INT1.
- Matches the table schema where vec_binary column only exists for INT1 quantization.
- """
- # Check the quantization type to determine which columns to use
- if self.quantization_type == VectorQuantizationType.INT1:
- # For quantized vectors, use vec_binary column
- query = f"""
- INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
- VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
- ON CONFLICT (id) DO UPDATE SET
- document_id = EXCLUDED.document_id,
- owner_id = EXCLUDED.owner_id,
- collection_ids = EXCLUDED.collection_ids,
- vec = EXCLUDED.vec,
- vec_binary = EXCLUDED.vec_binary,
- text = EXCLUDED.text,
- metadata = EXCLUDED.metadata;
- """
- await self.connection_manager.execute_query(
- query,
- (
- entry.id,
- entry.document_id,
- entry.owner_id,
- entry.collection_ids,
- str(entry.vector.data),
- quantize_vector_to_binary(
- entry.vector.data
- ), # Convert to binary
- entry.text,
- json.dumps(entry.metadata),
- ),
- )
- else:
- # For regular vectors, use vec column only
- query = f"""
- INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- (id, document_id, owner_id, collection_ids, vec, text, metadata)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
- ON CONFLICT (id) DO UPDATE SET
- document_id = EXCLUDED.document_id,
- owner_id = EXCLUDED.owner_id,
- collection_ids = EXCLUDED.collection_ids,
- vec = EXCLUDED.vec,
- text = EXCLUDED.text,
- metadata = EXCLUDED.metadata;
- """
- await self.connection_manager.execute_query(
- query,
- (
- entry.id,
- entry.document_id,
- entry.owner_id,
- entry.collection_ids,
- str(entry.vector.data),
- entry.text,
- json.dumps(entry.metadata),
- ),
- )
- async def upsert_entries(self, entries: list[VectorEntry]) -> None:
- """
- Batch upsert function that handles vector quantization only when quantization_type is INT1.
- Matches the table schema where vec_binary column only exists for INT1 quantization.
- """
- if self.quantization_type == VectorQuantizationType.INT1:
- # For quantized vectors, use vec_binary column
- query = f"""
- INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
- VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
- ON CONFLICT (id) DO UPDATE SET
- document_id = EXCLUDED.document_id,
- owner_id = EXCLUDED.owner_id,
- collection_ids = EXCLUDED.collection_ids,
- vec = EXCLUDED.vec,
- vec_binary = EXCLUDED.vec_binary,
- text = EXCLUDED.text,
- metadata = EXCLUDED.metadata;
- """
- bin_params = [
- (
- entry.id,
- entry.document_id,
- entry.owner_id,
- entry.collection_ids,
- str(entry.vector.data),
- quantize_vector_to_binary(
- entry.vector.data
- ), # Convert to binary
- entry.text,
- json.dumps(entry.metadata),
- )
- for entry in entries
- ]
- await self.connection_manager.execute_many(query, bin_params)
- else:
- # For regular vectors, use vec column only
- query = f"""
- INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- (id, document_id, owner_id, collection_ids, vec, text, metadata)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
- ON CONFLICT (id) DO UPDATE SET
- document_id = EXCLUDED.document_id,
- owner_id = EXCLUDED.owner_id,
- collection_ids = EXCLUDED.collection_ids,
- vec = EXCLUDED.vec,
- text = EXCLUDED.text,
- metadata = EXCLUDED.metadata;
- """
- params = [
- (
- entry.id,
- entry.document_id,
- entry.owner_id,
- entry.collection_ids,
- str(entry.vector.data),
- entry.text,
- json.dumps(entry.metadata),
- )
- for entry in entries
- ]
- await self.connection_manager.execute_many(query, params)
- async def semantic_search(
- self, query_vector: list[float], search_settings: SearchSettings
- ) -> list[ChunkSearchResult]:
- try:
- imeasure_obj = IndexMeasure(
- search_settings.chunk_settings.index_measure
- )
- except ValueError as e:
- raise ValueError(
- f"Invalid index measure: {search_settings.chunk_settings.index_measure}"
- ) from e
- table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
- cols = [
- f"{table_name}.id",
- f"{table_name}.document_id",
- f"{table_name}.owner_id",
- f"{table_name}.collection_ids",
- f"{table_name}.text",
- ]
- params: list[str | int | bytes] = []
- # For binary vectors (INT1), implement two-stage search
- if self.quantization_type == VectorQuantizationType.INT1:
- # Convert query vector to binary format
- binary_query = quantize_vector_to_binary(query_vector)
- # TODO - Put depth multiplier in config / settings
- extended_limit = (
- search_settings.limit * 20
- ) # Get 20x candidates for re-ranking
- if (
- imeasure_obj == IndexMeasure.hamming_distance
- or imeasure_obj == IndexMeasure.jaccard_distance
- ):
- binary_search_measure_repr = imeasure_obj.pgvector_repr
- else:
- binary_search_measure_repr = (
- IndexMeasure.hamming_distance.pgvector_repr
- )
- # Use binary column and binary-specific distance measures for first stage
- stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit({self.dimension})"
- stage1_param = binary_query
- cols.append(
- f"{table_name}.vec"
- ) # Need original vector for re-ranking
- if search_settings.include_metadatas:
- cols.append(f"{table_name}.metadata")
- select_clause = ", ".join(cols)
- where_clause = ""
- params.append(stage1_param)
- if search_settings.filters:
- where_clause, params = apply_filters(
- search_settings.filters, params, mode="where_clause"
- )
- # First stage: Get candidates using binary search
- query = f"""
- WITH candidates AS (
- SELECT {select_clause},
- ({stage1_distance}) as binary_distance
- FROM {table_name}
- {where_clause}
- ORDER BY {stage1_distance}
- LIMIT ${len(params) + 1}
- OFFSET ${len(params) + 2}
- )
- -- Second stage: Re-rank using original vectors
- SELECT
- id,
- document_id,
- owner_id,
- collection_ids,
- text,
- {"metadata," if search_settings.include_metadatas else ""}
- (vec <=> ${len(params) + 4}::vector({self.dimension})) as distance
- FROM candidates
- ORDER BY distance
- LIMIT ${len(params) + 3}
- """
- params.extend(
- [
- extended_limit, # First stage limit
- search_settings.offset,
- search_settings.limit, # Final limit
- str(query_vector), # For re-ranking
- ]
- )
- else:
- # Standard float vector handling
- distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector({self.dimension})"
- query_param = str(query_vector)
- if search_settings.include_scores:
- cols.append(f"({distance_calc}) AS distance")
- if search_settings.include_metadatas:
- cols.append(f"{table_name}.metadata")
- select_clause = ", ".join(cols)
- where_clause = ""
- params.append(query_param)
- if search_settings.filters:
- where_clause, new_params = apply_filters(
- search_settings.filters,
- params,
- mode="where_clause", # Get just conditions without WHERE
- )
- params = new_params
- query = f"""
- SELECT {select_clause}
- FROM {table_name}
- {where_clause}
- ORDER BY {distance_calc}
- LIMIT ${len(params) + 1}
- OFFSET ${len(params) + 2}
- """
- params.extend([search_settings.limit, search_settings.offset])
- results = await self.connection_manager.fetch_query(query, params)
- return [
- ChunkSearchResult(
- id=UUID(str(result["id"])),
- document_id=UUID(str(result["document_id"])),
- owner_id=UUID(str(result["owner_id"])),
- collection_ids=result["collection_ids"],
- text=result["text"],
- score=(
- (1 - float(result["distance"]))
- if "distance" in result
- else -1
- ),
- metadata=(
- json.loads(result["metadata"])
- if search_settings.include_metadatas
- else {}
- ),
- )
- for result in results
- ]
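- # Worked example of the INT1 two-stage retrieval above: with limit = 10, the first
- # stage pulls limit * 20 = 200 candidates ordered by Hamming distance on vec_binary
- # (e.g. query bits 1011 vs stored bits 1001 differ in one position, distance 1), and
- # the second stage re-ranks only those 200 rows by exact cosine distance on vec, so
- # the expensive float comparison never scans the full table.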
- async def full_text_search(
- self, query_text: str, search_settings: SearchSettings
- ) -> list[ChunkSearchResult]:
- conditions = []
- params: list[str | int | bytes] = [query_text]
- conditions.append("fts @@ websearch_to_tsquery('english', $1)")
- if search_settings.filters:
- filter_condition, params = apply_filters(
- search_settings.filters, params, mode="condition_only"
- )
- if filter_condition:
- conditions.append(filter_condition)
- where_clause = "WHERE " + " AND ".join(conditions)
- query = f"""
- SELECT
- id,
- document_id,
- owner_id,
- collection_ids,
- text,
- metadata,
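- -- normalization option 32 rescales the rank into [0, 1) as rank / (rank + 1)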
- ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- {where_clause}
- ORDER BY rank DESC
- OFFSET ${len(params)+1}
- LIMIT ${len(params)+2}
- """
- params.extend(
- [
- search_settings.offset,
- search_settings.hybrid_settings.full_text_limit,
- ]
- )
- results = await self.connection_manager.fetch_query(query, params)
- return [
- ChunkSearchResult(
- id=UUID(str(r["id"])),
- document_id=UUID(str(r["document_id"])),
- owner_id=UUID(str(r["owner_id"])),
- collection_ids=r["collection_ids"],
- text=r["text"],
- score=float(r["rank"]),
- metadata=json.loads(r["metadata"]),
- )
- for r in results
- ]
- async def hybrid_search(
- self,
- query_text: str,
- query_vector: list[float],
- search_settings: SearchSettings,
- *args,
- **kwargs,
- ) -> list[ChunkSearchResult]:
- if search_settings.hybrid_settings is None:
- raise ValueError(
- "Please provide a valid `hybrid_settings` in the `search_settings`."
- )
- if (
- search_settings.hybrid_settings.full_text_limit
- < search_settings.limit
- ):
- raise ValueError(
- "The `full_text_limit` must be greater than or equal to the `limit`."
- )
- semantic_settings = copy.deepcopy(search_settings)
- semantic_settings.limit += search_settings.offset
- full_text_settings = copy.deepcopy(search_settings)
- full_text_settings.hybrid_settings.full_text_limit += (
- search_settings.offset
- )
- semantic_results: list[ChunkSearchResult] = await self.semantic_search(
- query_vector, semantic_settings
- )
- full_text_results: list[ChunkSearchResult] = (
- await self.full_text_search(query_text, full_text_settings)
- )
- semantic_limit = search_settings.limit
- full_text_limit = search_settings.hybrid_settings.full_text_limit
- semantic_weight = search_settings.hybrid_settings.semantic_weight
- full_text_weight = search_settings.hybrid_settings.full_text_weight
- rrf_k = search_settings.hybrid_settings.rrf_k
- combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}
- for rank, result in enumerate(semantic_results, 1):
- combined_results[result.id] = {
- "semantic_rank": rank,
- "full_text_rank": full_text_limit,
- "data": result,
- "rrf_score": 0.0, # Initialize with 0, will be calculated later
- }
- for rank, result in enumerate(full_text_results, 1):
- if result.id in combined_results:
- combined_results[result.id]["full_text_rank"] = rank
- else:
- combined_results[result.id] = {
- "semantic_rank": semantic_limit,
- "full_text_rank": rank,
- "data": result,
- "rrf_score": 0.0, # Initialize with 0, will be calculated later
- }
- combined_results = {
- k: v
- for k, v in combined_results.items()
- if v["semantic_rank"] <= semantic_limit * 2
- and v["full_text_rank"] <= full_text_limit * 2
- }
- for hyb_result in combined_results.values():
- semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
- full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
- hyb_result["rrf_score"] = (
- semantic_score * semantic_weight
- + full_text_score * full_text_weight
- ) / (semantic_weight + full_text_weight)
- sorted_results = sorted(
- combined_results.values(),
- key=lambda x: x["rrf_score"],
- reverse=True,
- )
- offset_results = sorted_results[
- search_settings.offset : search_settings.offset
- + search_settings.limit
- ]
- return [
- ChunkSearchResult(
- id=result["data"].id,
- document_id=result["data"].document_id,
- owner_id=result["data"].owner_id,
- collection_ids=result["data"].collection_ids,
- text=result["data"].text,
- score=result["rrf_score"],
- metadata={
- **result["data"].metadata,
- "semantic_rank": result["semantic_rank"],
- "full_text_rank": result["full_text_rank"],
- },
- )
- for result in offset_results
- ]
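- # Worked example of the RRF combination above (illustrative numbers): with rrf_k = 50,
- # semantic_weight = 5.0 and full_text_weight = 1.0, a chunk ranked 1st semantically
- # and 3rd in full-text search scores
- #   semantic_score  = 1 / (50 + 1) ~= 0.01961
- #   full_text_score = 1 / (50 + 3) ~= 0.01887
- #   rrf_score = (0.01961 * 5.0 + 0.01887 * 1.0) / (5.0 + 1.0) ~= 0.01948
- # Chunks missing from one result set are penalized with that set's limit as their rank.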
- async def delete(
- self, filters: dict[str, Any]
- ) -> dict[str, dict[str, str]]:
- params: list[str | int | bytes] = []
- where_clause, params = apply_filters(
- filters, params, mode="condition_only"
- )
- query = f"""
- DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE {where_clause}
- RETURNING id, document_id, text;
- """
- results = await self.connection_manager.fetch_query(query, params)
- return {
- str(result["id"]): {
- "status": "deleted",
- "id": str(result["id"]),
- "document_id": str(result["document_id"]),
- "text": result["text"],
- }
- for result in results
- }
- async def assign_document_chunks_to_collection(
- self, document_id: UUID, collection_id: UUID
- ) -> None:
- query = f"""
- UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- SET collection_ids = array_append(collection_ids, $1)
- WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
- """
- await self.connection_manager.execute_query(
- query, (str(collection_id), str(document_id))
- )
- async def remove_document_from_collection_vector(
- self, document_id: UUID, collection_id: UUID
- ) -> None:
- query = f"""
- UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- SET collection_ids = array_remove(collection_ids, $1)
- WHERE document_id = $2;
- """
- await self.connection_manager.execute_query(
- query, (collection_id, document_id)
- )
- async def delete_user_vector(self, owner_id: UUID) -> None:
- query = f"""
- DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE owner_id = $1;
- """
- await self.connection_manager.execute_query(query, (owner_id,))
- async def delete_collection_vector(self, collection_id: UUID) -> None:
- query = f"""
- DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE $1 = ANY(collection_ids)
- RETURNING collection_ids
- """
- await self.connection_manager.fetchrow_query(
- query, (collection_id,)
- )
- return None
- async def list_document_chunks(
- self,
- document_id: UUID,
- offset: int,
- limit: int,
- include_vectors: bool = False,
- ) -> dict[str, Any]:
- vector_select = ", vec" if include_vectors else ""
- limit_clause = f"LIMIT {limit}" if limit > -1 else ""
- query = f"""
- SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE document_id = $1
- ORDER BY (metadata->>'chunk_order')::integer
- OFFSET $2
- {limit_clause};
- """
- params = [document_id, offset]
- results = await self.connection_manager.fetch_query(query, params)
- chunks = []
- total = 0
- if results:
- total = results[0].get("total", 0)
- chunks = [
- {
- "id": result["id"],
- "document_id": result["document_id"],
- "owner_id": result["owner_id"],
- "collection_ids": result["collection_ids"],
- "text": result["text"],
- "metadata": json.loads(result["metadata"]),
- "vector": (
- json.loads(result["vec"]) if include_vectors else None
- ),
- }
- for result in results
- ]
- return {"results": chunks, "total_entries": total}
- async def get_chunk(self, id: UUID) -> dict:
- query = f"""
- SELECT id, document_id, owner_id, collection_ids, text, metadata
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE id = $1;
- """
- result = await self.connection_manager.fetchrow_query(query, (id,))
- if result:
- return {
- "id": result["id"],
- "document_id": result["document_id"],
- "owner_id": result["owner_id"],
- "collection_ids": result["collection_ids"],
- "text": result["text"],
- "metadata": json.loads(result["metadata"]),
- }
- raise R2RException(
- message=f"Chunk with ID {id} not found", status_code=404
- )
- async def create_index(
- self,
- table_name: Optional[VectorTableName] = None,
- index_measure: IndexMeasure = IndexMeasure.cosine_distance,
- index_method: IndexMethod = IndexMethod.auto,
- index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
- index_name: Optional[str] = None,
- index_column: Optional[str] = None,
- concurrently: bool = True,
- ) -> None:
- """
- Creates an index for the collection.
- Note:
- When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
- process that enables performant indexes to be built for large collections with low end
- database hardware.
- Those steps are:
- - Creates a new table with a different name
- - Randomly selects records from the existing table
- - Inserts the random records from the existing table into the new table
- - Creates the requested vector index on the new table
- - Upserts all data from the existing table into the new table
- - Drops the existing table
- - Renames the new table to the existing tables name
- If you create dependencies (like views) on the table that underpins
- a `vecs.Collection` the `create_index` step may require you to drop those dependencies before
- it will succeed.
- Args:
- index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
- index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
- index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments
- index_name (str, optional): The name of the index to create. Defaults to None.
- concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
- Raises:
- ArgError: If an invalid table name or index method is used, or if the supplied index arguments do not match the chosen index method.
- """
- if table_name == VectorTableName.CHUNKS:
- table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}" # TODO - Fix bug in vector table naming convention
- if index_column:
- col_name = index_column
- else:
- col_name = (
- "vec"
- if (
- index_measure != IndexMeasure.hamming_distance
- and index_measure != IndexMeasure.jaccard_distance
- )
- else "vec_binary"
- )
- elif table_name == VectorTableName.ENTITIES_DOCUMENT:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
- )
- col_name = "description_embedding"
- elif table_name == VectorTableName.GRAPHS_ENTITIES:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
- )
- col_name = "description_embedding"
- elif table_name == VectorTableName.COMMUNITIES:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.COMMUNITIES}"
- )
- col_name = "embedding"
- else:
- raise ArgError("invalid table name")
- if index_method not in (
- IndexMethod.ivfflat,
- IndexMethod.hnsw,
- IndexMethod.auto,
- ):
- raise ArgError("invalid index method")
- if index_arguments:
- # Disallow case where user submits index arguments but uses the
- # IndexMethod.auto index (index build arguments should only be
- # used with a specific index)
- if index_method == IndexMethod.auto:
- raise ArgError(
- "Index build parameters are not allowed when using the IndexMethod.auto index."
- )
- # Disallow case where user specifies one index type but submits
- # index build arguments for the other index type
- if (
- isinstance(index_arguments, IndexArgsHNSW)
- and index_method != IndexMethod.hnsw
- ) or (
- isinstance(index_arguments, IndexArgsIVFFlat)
- and index_method != IndexMethod.ivfflat
- ):
- raise ArgError(
- f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
- )
- if index_method == IndexMethod.auto:
- index_method = IndexMethod.hnsw
- ops = index_measure_to_ops(
- index_measure # , quantization_type=self.quantization_type
- )
- if ops is None:
- raise ArgError("Unknown index measure")
- concurrently_sql = "CONCURRENTLY" if concurrently else ""
- index_name = (
- index_name
- or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
- )
- create_index_sql = f"""
- CREATE INDEX {concurrently_sql} {index_name}
- ON {table_name_str}
- USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
- """
- try:
- if concurrently:
- async with (
- self.connection_manager.pool.get_connection() as conn # type: ignore
- ):
- # Disable automatic transaction management
- await conn.execute(
- "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
- )
- await conn.execute(create_index_sql)
- else:
- # Non-concurrent index creation can use normal query execution
- await self.connection_manager.execute_query(create_index_sql)
- except Exception as e:
- raise Exception(f"Failed to create index: {e}") from e
- return None
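- # Usage sketch (illustrative, assuming an initialized handler; the argument values
- # are examples, not required defaults):
- #   await chunks_handler.create_index(
- #       table_name=VectorTableName.CHUNKS,
- #       index_measure=IndexMeasure.cosine_distance,
- #       index_method=IndexMethod.hnsw,
- #       index_arguments=IndexArgsHNSW(m=16, ef_construction=64),
- #       concurrently=True,
- #   )
- # This issues DDL of the form
- #   CREATE INDEX CONCURRENTLY ix_... ON <project>.chunks USING hnsw (vec <ops>) WITH (m=16, ef_construction=64);
- # where <ops> is the operator class returned by index_measure_to_ops for the chosen
- # measure and quantization type.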
- async def list_indices(
- self,
- offset: int,
- limit: int,
- filters: Optional[dict[str, Any]] = None,
- ) -> dict:
- where_clauses = []
- params: list[Any] = [self.project_name] # Start with schema name
- param_count = 1
- # Handle filtering
- if filters:
- if "table_name" in filters:
- where_clauses.append(f"i.tablename = ${param_count + 1}")
- params.append(filters["table_name"])
- param_count += 1
- if "index_method" in filters:
- where_clauses.append(f"am.amname = ${param_count + 1}")
- params.append(filters["index_method"])
- param_count += 1
- if "index_name" in filters:
- where_clauses.append(
- f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
- )
- params.append(f"%{filters['index_name']}%")
- param_count += 1
- where_clause = " AND ".join(where_clauses) if where_clauses else ""
- if where_clause:
- where_clause = "AND " + where_clause
- query = f"""
- WITH index_info AS (
- SELECT
- i.indexname as name,
- i.tablename as table_name,
- i.indexdef as definition,
- am.amname as method,
- pg_relation_size(c.oid) as size_in_bytes,
- c.reltuples::bigint as row_estimate,
- COALESCE(psat.idx_scan, 0) as number_of_scans,
- COALESCE(psat.idx_tup_read, 0) as tuples_read,
- COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
- COUNT(*) OVER() as total_count
- FROM pg_indexes i
- JOIN pg_class c ON c.relname = i.indexname
- JOIN pg_am am ON c.relam = am.oid
- LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
- AND psat.schemaname = i.schemaname
- WHERE i.schemaname = $1
- AND i.indexdef LIKE '%vector%'
- {where_clause}
- )
- SELECT *
- FROM index_info
- ORDER BY name
- LIMIT ${param_count + 1}
- OFFSET ${param_count + 2}
- """
- # Add limit and offset to params
- params.extend([limit, offset])
- results = await self.connection_manager.fetch_query(query, params)
- indices = []
- total_entries = 0
- if results:
- total_entries = results[0]["total_count"]
- for result in results:
- index_info = {
- "name": result["name"],
- "table_name": result["table_name"],
- "definition": result["definition"],
- "size_in_bytes": result["size_in_bytes"],
- "row_estimate": result["row_estimate"],
- "number_of_scans": result["number_of_scans"],
- "tuples_read": result["tuples_read"],
- "tuples_fetched": result["tuples_fetched"],
- }
- indices.append(index_info)
- # Calculate pagination info
- total_pages = (total_entries + limit - 1) // limit if limit > 0 else 1
- current_page = (offset // limit) + 1 if limit > 0 else 1
- page_info = {
- "total_entries": total_entries,
- "total_pages": total_pages,
- "current_page": current_page,
- "limit": limit,
- "offset": offset,
- "has_previous": offset > 0,
- "has_next": offset + limit < total_entries,
- "previous_offset": max(0, offset - limit) if offset > 0 else None,
- "next_offset": (
- offset + limit if offset + limit < total_entries else None
- ),
- }
- return {"indices": indices, "page_info": page_info}
- async def delete_index(
- self,
- index_name: str,
- table_name: Optional[VectorTableName] = None,
- concurrently: bool = True,
- ) -> None:
- """
- Deletes a vector index.
- Args:
- index_name (str): Name of the index to delete
- table_name (VectorTableName, optional): Table the index belongs to
- concurrently (bool): Whether to drop the index concurrently
- Raises:
- ArgError: If table name is invalid or index doesn't exist
- Exception: If index deletion fails
- """
- # Validate table name and get column name
- if table_name == VectorTableName.CHUNKS:
- table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
- col_name = "vec"
- elif table_name == VectorTableName.ENTITIES_DOCUMENT:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
- )
- col_name = "description_embedding"
- elif table_name == VectorTableName.GRAPHS_ENTITIES:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
- )
- col_name = "description_embedding"
- elif table_name == VectorTableName.COMMUNITIES:
- table_name_str = (
- f"{self.project_name}.{VectorTableName.COMMUNITIES}"
- )
- col_name = "description_embedding"
- else:
- raise ArgError("invalid table name")
- # Extract schema and base table name
- schema_name, base_table_name = table_name_str.split(".")
- # Verify index exists and is a vector index
- query = """
- SELECT indexdef
- FROM pg_indexes
- WHERE indexname = $1
- AND schemaname = $2
- AND tablename = $3
- AND indexdef LIKE $4
- """
- result = await self.connection_manager.fetchrow_query(
- query, (index_name, schema_name, base_table_name, f"%({col_name}%")
- )
- if not result:
- raise ArgError(
- f"Vector index '{index_name}' does not exist on table {table_name_str}"
- )
- # Drop the index
- concurrently_sql = "CONCURRENTLY" if concurrently else ""
- drop_query = (
- f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
- )
- try:
- if concurrently:
- async with (
- self.connection_manager.pool.get_connection() as conn # type: ignore
- ):
- # Disable automatic transaction management
- await conn.execute(
- "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
- )
- await conn.execute(drop_query)
- else:
- await self.connection_manager.execute_query(drop_query)
- except Exception as e:
- raise Exception(f"Failed to delete index: {e}") from e
- async def get_semantic_neighbors(
- self,
- offset: int,
- limit: int,
- document_id: UUID,
- id: UUID,
- similarity_threshold: float = 0.5,
- ) -> list[dict[str, Any]]:
- table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
- query = f"""
- WITH target_vector AS (
- SELECT vec FROM {table_name}
- WHERE document_id = $1 AND id = $2
- )
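- -- NOTE: <=> is pgvector's cosine *distance* operator (0 = identical vectors), so the
- -- value aliased "similarity" below is in fact a distance, and the $3 threshold filters
- -- on that distance.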
- SELECT t.id, t.text, t.metadata, t.document_id, (t.vec <=> tv.vec) AS similarity
- FROM {table_name} t, target_vector tv
- WHERE (t.vec <=> tv.vec) >= $3
- AND t.document_id = $1
- AND t.id != $2
- ORDER BY similarity ASC
- LIMIT $4
- """
- results = await self.connection_manager.fetch_query(
- query,
- (str(document_id), str(id), similarity_threshold, limit),
- )
- return [
- {
- "id": str(r["id"]),
- "text": r["text"],
- "metadata": json.loads(r["metadata"]),
- "document_id": str(r["document_id"]),
- "similarity": float(r["similarity"]),
- }
- for r in results
- ]
- async def list_chunks(
- self,
- offset: int,
- limit: int,
- filters: Optional[dict[str, Any]] = None,
- include_vectors: bool = False,
- ) -> dict[str, Any]:
- """
- List chunks with pagination support.
- Args:
- offset (int, optional): Number of records to skip. Defaults to 0.
- limit (int, optional): Maximum number of records to return. Defaults to 10.
- filters (dict, optional): Dictionary of filters to apply. Defaults to None.
- include_vectors (bool, optional): Whether to include vector data. Defaults to False.
- Returns:
- dict: Dictionary containing:
- - results: List of chunk records
- - total_entries: Total number of chunks matching the filters
- - page_info: Pagination information
- """
- vector_select = ", vec" if include_vectors else ""
- select_clause = f"""
- id, document_id, owner_id, collection_ids,
- text, metadata{vector_select}, COUNT(*) OVER() AS total
- """
- params: list[str | int | bytes] = []
- where_clause = ""
- if filters:
- where_clause, params = apply_filters(
- filters, params, mode="where_clause"
- )
- query = f"""
- SELECT {select_clause}
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- {where_clause}
- LIMIT ${len(params) + 1}
- OFFSET ${len(params) + 2}
- """
- params.extend([limit, offset])
- # Execute the query
- results = await self.connection_manager.fetch_query(query, params)
- # Process results
- chunks = []
- total = 0
- if results:
- total = results[0].get("total", 0)
- chunks = [
- {
- "id": str(result["id"]),
- "document_id": str(result["document_id"]),
- "owner_id": str(result["owner_id"]),
- "collection_ids": result["collection_ids"],
- "text": result["text"],
- "metadata": json.loads(result["metadata"]),
- "vector": (
- json.loads(result["vec"]) if include_vectors else None
- ),
- }
- for result in results
- ]
- # Calculate pagination info
- total_pages = (total + limit - 1) // limit if limit > 0 else 1
- current_page = (offset // limit) + 1 if limit > 0 else 1
- page_info = {
- "total_entries": total,
- "total_pages": total_pages,
- "current_page": current_page,
- "limit": limit,
- "offset": offset,
- "has_previous": offset > 0,
- "has_next": offset + limit < total,
- "previous_offset": max(0, offset - limit) if offset > 0 else None,
- "next_offset": offset + limit if offset + limit < total else None,
- }
- return {"results": chunks, "page_info": page_info}
- async def search_documents(
- self,
- query_text: str,
- settings: SearchSettings,
- ) -> list[dict[str, Any]]:
- """
- Search for documents based on their metadata fields and/or body text.
- Joins with documents table to get complete document metadata.
- Args:
- query_text (str): The search query text
- settings (SearchSettings): Search settings including search preferences and filters
- Returns:
- list[dict[str, Any]]: List of documents with their search scores and complete metadata
- """
- where_clauses = []
- params: list[str | int | bytes] = [query_text]
- # Build the dynamic metadata field search expression
- metadata_fields_expr = " || ' ' || ".join(
- [
- f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
- for key in settings.metadata_keys # type: ignore
- ]
- )
- query = f"""
- WITH
- -- Metadata search scores
- metadata_scores AS (
- SELECT DISTINCT ON (v.document_id)
- v.document_id,
- d.metadata as doc_metadata,
- CASE WHEN $1 = '' THEN 0.0
- ELSE
- ts_rank_cd(
- setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
- websearch_to_tsquery('english', $1),
- 32
- )
- END as metadata_rank
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
- LEFT JOIN {self._get_table_name('documents')} d ON v.document_id = d.id
- WHERE v.metadata IS NOT NULL
- ),
- -- Body search scores
- body_scores AS (
- SELECT
- document_id,
- AVG(
- ts_rank_cd(
- setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
- websearch_to_tsquery('english', $1),
- 32
- )
- ) as body_rank
- FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
- WHERE $1 != ''
- {f"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""}
- GROUP BY document_id
- ),
- -- Combined scores with document metadata
- combined_scores AS (
- SELECT
- COALESCE(m.document_id, b.document_id) as document_id,
- m.doc_metadata as metadata,
- COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
- COALESCE(b.body_rank, 0) as debug_body_rank,
- CASE
- WHEN {str(settings.search_over_metadata).lower()} AND {str(settings.search_over_body).lower()} THEN
- COALESCE(m.metadata_rank, 0) * {settings.metadata_weight} + COALESCE(b.body_rank, 0) * {settings.title_weight}
- WHEN {str(settings.search_over_metadata).lower()} THEN
- COALESCE(m.metadata_rank, 0)
- WHEN {str(settings.search_over_body).lower()} THEN
- COALESCE(b.body_rank, 0)
- ELSE 0
- END as rank
- FROM metadata_scores m
- FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
- WHERE (
- ($1 = '') OR
- ({str(settings.search_over_metadata).lower()} AND m.metadata_rank > 0) OR
- ({str(settings.search_over_body).lower()} AND b.body_rank > 0)
- )
- """
- # Add any additional filters
- if settings.filters:
- filter_clause, params = apply_filters(settings.filters, params)
- where_clauses.append(filter_clause)
- if where_clauses:
- query += f" AND {' AND '.join(where_clauses)}"
- query += """
- )
- SELECT
- document_id,
- metadata,
- rank as score,
- debug_metadata_rank,
- debug_body_rank
- FROM combined_scores
- WHERE rank > 0
- ORDER BY rank DESC
- OFFSET ${offset_param} LIMIT ${limit_param}
- """.format(
- offset_param=len(params) + 1,
- limit_param=len(params) + 2,
- )
- # Add offset and limit to params
- params.extend([settings.offset, settings.limit])
- # Execute query
- results = await self.connection_manager.fetch_query(query, params)
- # Format results with complete document metadata
- return [
- {
- "document_id": str(r["document_id"]),
- "metadata": (
- json.loads(r["metadata"])
- if isinstance(r["metadata"], str)
- else r["metadata"]
- ),
- "score": float(r["score"]),
- "debug_metadata_rank": float(r["debug_metadata_rank"]),
- "debug_body_rank": float(r["debug_body_rank"]),
- }
- for r in results
- ]
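- # Worked example of the combined score above (illustrative numbers): with both
- # search_over_metadata and search_over_body enabled, metadata_weight = 3.0,
- # title_weight = 1.0, metadata_rank = 0.20 and body_rank = 0.05:
- #   rank = 0.20 * 3.0 + 0.05 * 1.0 = 0.65
- # Note that the body term is scaled by settings.title_weight, mirroring the SQL above.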
- def _get_index_options(
- self,
- method: IndexMethod,
- index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
- ) -> str:
- if method == IndexMethod.ivfflat:
- if isinstance(index_arguments, IndexArgsIVFFlat):
- return f"WITH (lists={index_arguments.n_lists})"
- else:
- # Default value if no arguments provided
- return "WITH (lists=100)"
- elif method == IndexMethod.hnsw:
- if isinstance(index_arguments, IndexArgsHNSW):
- return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
- else:
- # Default values if no arguments provided
- return "WITH (m=16, ef_construction=64)"
- else:
- return "" # No options for other methods