"""
The `vecs.experimental.adapter.text` module provides adapter steps specifically designed for
handling text data. It provides two main classes, `TextEmbedding` and `ParagraphChunker`.

All public classes, enums, and functions are re-exported by `vecs.adapters` module.
"""

from typing import Any, Generator, Iterable, Literal, Optional, Tuple

from flupy import flu

from vecs.exc import MissingDependency

from .base import AdapterContext, AdapterStep
  10. TextEmbeddingModel = Literal[
  11. "all-mpnet-base-v2",
  12. "multi-qa-mpnet-base-dot-v1",
  13. "all-distilroberta-v1",
  14. "all-MiniLM-L12-v2",
  15. "multi-qa-distilbert-cos-v1",
  16. "mixedbread-ai/mxbai-embed-large-v1",
  17. "multi-qa-MiniLM-L6-cos-v1",
  18. "paraphrase-multilingual-mpnet-base-v2",
  19. "paraphrase-albert-small-v2",
  20. "paraphrase-multilingual-MiniLM-L12-v2",
  21. "paraphrase-MiniLM-L3-v2",
  22. "distiluse-base-multilingual-cased-v1",
  23. "distiluse-base-multilingual-cased-v2",
  24. ]
  25. class TextEmbedding(AdapterStep):
  26. """
  27. TextEmbedding is an AdapterStep that converts text media into
  28. embeddings using a specified sentence transformers model.
  29. """
  30. def __init__(
  31. self,
  32. *,
  33. model: TextEmbeddingModel,
  34. batch_size: int = 8,
  35. use_auth_token: str = None,
  36. ):
  37. """
  38. Initializes the TextEmbedding adapter with a sentence transformers model.
  39. Args:
  40. model (TextEmbeddingModel): The sentence transformers model to use for embeddings.
  41. batch_size (int): The number of records to encode simultaneously.
  42. use_auth_token (str): The HuggingFace Hub auth token to use for private models.
  43. Raises:
  44. MissingDependency: If the sentence_transformers library is not installed.
  45. """
  46. try:
  47. from sentence_transformers import SentenceTransformer as ST
  48. except ImportError:
  49. raise MissingDependency(
  50. "Missing feature vecs[text_embedding]. Hint: `pip install 'vecs[text_embedding]'`"
  51. )
  52. self.model = ST(model, use_auth_token=use_auth_token)
  53. self._exported_dimension = (
  54. self.model.get_sentence_embedding_dimension()
  55. )
  56. self.batch_size = batch_size
  57. @property
  58. def exported_dimension(self) -> Optional[int]:
  59. """
  60. Returns the dimension of the embeddings produced by the sentence transformers model.
  61. Returns:
  62. int: The dimension of the embeddings.
  63. """
  64. return self._exported_dimension
  65. def __call__(
  66. self,
  67. records: Iterable[Tuple[str, Any, Optional[dict]]],
  68. adapter_context: AdapterContext, # pyright: ignore
  69. ) -> Generator[Tuple[str, Any, dict], None, None]:
  70. """
  71. Converts each media in the records to an embedding and yields the result.
  72. Args:
  73. records: Iterable of tuples each containing an id, a media and an optional dict.
  74. adapter_context: Context of the adapter.
  75. Yields:
  76. Tuple[str, Any, dict]: The id, the embedding, and the metadata.
  77. """
  78. for batch in flu(records).chunk(self.batch_size):
  79. batch_records = [x for x in batch]
  80. media = [text for _, text, _ in batch_records]
  81. embeddings = self.model.encode(media, normalize_embeddings=True)
  82. for (id, _, metadata), embedding in zip(batch_records, embeddings): # type: ignore
  83. yield (id, embedding, metadata or {})
  84. class ParagraphChunker(AdapterStep):
  85. """
  86. ParagraphChunker is an AdapterStep that splits text media into
  87. paragraphs and yields each paragraph as a separate record.
  88. """
  89. def __init__(self, *, skip_during_query: bool):
  90. """
  91. Initializes the ParagraphChunker adapter.
  92. Args:
  93. skip_during_query (bool): Whether to skip chunking during querying.
  94. """
  95. self.skip_during_query = skip_during_query
  96. def __call__(
  97. self,
  98. records: Iterable[Tuple[str, Any, Optional[dict]]],
  99. adapter_context: AdapterContext,
  100. ) -> Generator[Tuple[str, Any, dict], None, None]:
  101. """
  102. Splits each media in the records into paragraphs and yields each paragraph
  103. as a separate record. If the `skip_during_query` attribute is set to True,
  104. this step is skipped during querying.
  105. Args:
  106. records (Iterable[Tuple[str, Any, Optional[dict]]]): Iterable of tuples each containing an id, a media and an optional dict.
  107. adapter_context (AdapterContext): Context of the adapter.
  108. Yields:
  109. Tuple[str, Any, dict]: The id appended with paragraph index, the paragraph, and the metadata.
  110. """
  111. if (
  112. adapter_context == AdapterContext("query")
  113. and self.skip_during_query
  114. ):
  115. for id, media, metadata in records:
  116. yield (id, media, metadata or {})
  117. else:
  118. for id, media, metadata in records:
  119. paragraphs = media.split("\n\n")
  120. for paragraph_ix, paragraph in enumerate(paragraphs):
  121. yield (
  122. f"{id}_para_{str(paragraph_ix).zfill(3)}",
  123. paragraph,
  124. metadata or {},
  125. )