vector.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. """Abstraction for a vector that can be stored in the system."""
  2. from enum import Enum
  3. from typing import Any, Optional
  4. from uuid import UUID
  5. from pydantic import BaseModel, Field
  6. from .base import R2RSerializable
  7. class VectorType(str, Enum):
  8. FIXED = "FIXED"
  9. class IndexMethod(str, Enum):
  10. """
  11. An enum representing the index methods available.
  12. This class currently only supports the 'ivfflat' method but may
  13. expand in the future.
  14. Attributes:
  15. auto (str): Automatically choose the best available index method.
  16. ivfflat (str): The ivfflat index method.
  17. hnsw (str): The hnsw index method.
  18. """
  19. auto = "auto"
  20. ivfflat = "ivfflat"
  21. hnsw = "hnsw"
  22. def __str__(self) -> str:
  23. return self.value
  24. class IndexMeasure(str, Enum):
  25. """
  26. An enum representing the types of distance measures available for indexing.
  27. Attributes:
  28. cosine_distance (str): The cosine distance measure for indexing.
  29. l2_distance (str): The Euclidean (L2) distance measure for indexing.
  30. max_inner_product (str): The maximum inner product measure for indexing.
  31. """
  32. l2_distance = "l2_distance"
  33. max_inner_product = "max_inner_product"
  34. cosine_distance = "cosine_distance"
  35. l1_distance = "l1_distance"
  36. hamming_distance = "hamming_distance"
  37. jaccard_distance = "jaccard_distance"
  38. def __str__(self) -> str:
  39. return self.value
  40. @property
  41. def ops(self) -> str:
  42. return {
  43. IndexMeasure.l2_distance: "_l2_ops",
  44. IndexMeasure.max_inner_product: "_ip_ops",
  45. IndexMeasure.cosine_distance: "_cosine_ops",
  46. IndexMeasure.l1_distance: "_l1_ops",
  47. IndexMeasure.hamming_distance: "_hamming_ops",
  48. IndexMeasure.jaccard_distance: "_jaccard_ops",
  49. }[self]
  50. @property
  51. def pgvector_repr(self) -> str:
  52. return {
  53. IndexMeasure.l2_distance: "<->",
  54. IndexMeasure.max_inner_product: "<#>",
  55. IndexMeasure.cosine_distance: "<=>",
  56. IndexMeasure.l1_distance: "<+>",
  57. IndexMeasure.hamming_distance: "<~>",
  58. IndexMeasure.jaccard_distance: "<%>",
  59. }[self]
  60. class IndexArgsIVFFlat(R2RSerializable):
  61. """
  62. A class for arguments that can optionally be supplied to the index creation
  63. method when building an IVFFlat type index.
  64. Attributes:
  65. nlist (int): The number of IVF centroids that the index should use
  66. """
  67. n_lists: int
  68. class IndexArgsHNSW(R2RSerializable):
  69. """
  70. A class for arguments that can optionally be supplied to the index creation
  71. method when building an HNSW type index.
  72. Ref: https://github.com/pgvector/pgvector#index-options
  73. Both attributes are Optional in case the user only wants to specify one and
  74. leave the other as default
  75. Attributes:
  76. m (int): Maximum number of connections per node per layer (default: 16)
  77. ef_construction (int): Size of the dynamic candidate list for
  78. constructing the graph (default: 64)
  79. """
  80. m: Optional[int] = 16
  81. ef_construction: Optional[int] = 64
  82. class VectorTableName(str, Enum):
  83. """
  84. This enum represents the different tables where we store vectors.
  85. """
  86. CHUNKS = "chunks"
  87. ENTITIES_DOCUMENT = "documents_entities"
  88. GRAPHS_ENTITIES = "graphs_entities"
  89. # TODO: Add support for relationships
  90. # TRIPLES = "relationship"
  91. COMMUNITIES = "graphs_communities"
  92. def __str__(self) -> str:
  93. return self.value
  94. class VectorQuantizationType(str, Enum):
  95. """
  96. An enum representing the types of quantization available for vectors.
  97. Attributes:
  98. FP32 (str): 32-bit floating point quantization.
  99. FP16 (str): 16-bit floating point quantization.
  100. INT1 (str): 1-bit integer quantization.
  101. SPARSE (str): Sparse vector quantization.
  102. """
  103. FP32 = "FP32"
  104. FP16 = "FP16"
  105. INT1 = "INT1"
  106. SPARSE = "SPARSE"
  107. def __str__(self) -> str:
  108. return self.value
  109. @property
  110. def db_type(self) -> str:
  111. db_type_mapping = {
  112. "FP32": "vector",
  113. "FP16": "halfvec",
  114. "INT1": "bit",
  115. "SPARSE": "sparsevec",
  116. }
  117. return db_type_mapping[self.value]
  118. class VectorQuantizationSettings(R2RSerializable):
  119. quantization_type: VectorQuantizationType = Field(
  120. default=VectorQuantizationType.FP32
  121. )
  122. class Vector(R2RSerializable):
  123. """A vector with the option to fix the number of elements."""
  124. data: list[float]
  125. type: VectorType = Field(default=VectorType.FIXED)
  126. length: int = Field(default=-1)
  127. def __init__(self, **data):
  128. super().__init__(**data)
  129. if (
  130. self.type == VectorType.FIXED
  131. and self.length > 0
  132. and len(self.data) != self.length
  133. ):
  134. raise ValueError(
  135. f"Vector must be exactly {self.length} elements long."
  136. )
  137. def __repr__(self) -> str:
  138. return (
  139. f"Vector(data={self.data}, type={self.type}, length={self.length})"
  140. )
  141. class VectorEntry(R2RSerializable):
  142. """A vector entry that can be stored directly in supported vector databases."""
  143. id: UUID
  144. document_id: UUID
  145. owner_id: UUID
  146. collection_ids: list[UUID]
  147. vector: Vector
  148. text: str
  149. metadata: dict[str, Any]
  150. def __str__(self) -> str:
  151. """Return a string representation of the VectorEntry."""
  152. return (
  153. f"VectorEntry("
  154. f"chunk_id={self.id}, "
  155. f"document_id={self.document_id}, "
  156. f"owner_id={self.owner_id}, "
  157. f"collection_ids={self.collection_ids}, "
  158. f"vector={self.vector}, "
  159. f"text={self.text}, "
  160. f"metadata={self.metadata})"
  161. )
  162. def __repr__(self) -> str:
  163. """Return an unambiguous string representation of the VectorEntry."""
  164. return self.__str__()
  165. class StorageResult(R2RSerializable):
  166. """A result of a storage operation."""
  167. success: bool
  168. document_id: UUID
  169. num_chunks: int = 0
  170. error_message: Optional[str] = None
  171. def __str__(self) -> str:
  172. """Return a string representation of the StorageResult."""
  173. return f"StorageResult(success={self.success}, error_message={self.error_message})"
  174. def __repr__(self) -> str:
  175. """Return an unambiguous string representation of the StorageResult."""
  176. return self.__str__()
  177. class IndexConfig(BaseModel):
  178. name: Optional[str] = Field(default=None)
  179. table_name: Optional[str] = Field(default=VectorTableName.CHUNKS)
  180. index_method: Optional[str] = Field(default=IndexMethod.hnsw)
  181. index_measure: Optional[str] = Field(default=IndexMeasure.cosine_distance)
  182. index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = Field(
  183. default=None
  184. )
  185. index_name: Optional[str] = Field(default=None)
  186. index_column: Optional[str] = Field(default=None)
  187. concurrently: Optional[bool] = Field(default=True)