__init__.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import re
  2. from typing import Set, Tuple
  3. from shared.utils.base_utils import (
  4. SearchResultsCollector,
  5. SSEFormatter,
  6. convert_nonserializable_objects,
  7. deep_update,
  8. dump_collector,
  9. dump_obj,
  10. format_search_results_for_llm,
  11. generate_default_user_collection_id,
  12. generate_document_id,
  13. generate_extraction_id,
  14. generate_id,
  15. generate_user_id,
  16. num_tokens,
  17. num_tokens_from_messages,
  18. update_settings_from_dict,
  19. validate_uuid,
  20. yield_sse_event,
  21. )
  22. from shared.utils.splitter.text import (
  23. RecursiveCharacterTextSplitter,
  24. TextSplitter,
  25. )
  26. def extract_citations(text: str) -> list[str]:
  27. """
  28. Extract citation IDs enclosed in brackets like [abc1234].
  29. Returns a list of citation IDs.
  30. Args:
  31. text: The text to search for citations. If None, returns an empty list.
  32. Returns:
  33. List of citation IDs matching the pattern [A-Za-z0-9]{7,8}
  34. """
  35. # Handle None or empty input
  36. if text is None or text == "":
  37. return []
  38. # Direct pattern to match IDs inside brackets with alphanumeric pattern
  39. CITATION_PATTERN = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
  40. sids = []
  41. for match in CITATION_PATTERN.finditer(text):
  42. sid = match.group(1)
  43. sids.append(sid)
  44. return sids
  45. def extract_citation_spans(text: str) -> dict[str, list[Tuple[int, int]]]:
  46. """
  47. Extract citation IDs with their positions in the text.
  48. Args:
  49. text: The text to search for citations. If None, returns an empty dict.
  50. Returns:
  51. Dictionary mapping citation IDs to lists of (start, end) position tuples,
  52. where start is the position of the opening bracket and end is the position
  53. just after the closing bracket.
  54. """
  55. # Handle None or empty input
  56. if text is None or text == "":
  57. return {}
  58. # Use the same pattern as the original extract_citations
  59. CITATION_PATTERN = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
  60. citation_spans: dict = {}
  61. for match in CITATION_PATTERN.finditer(text):
  62. sid = match.group(1)
  63. start = match.start()
  64. end = match.end()
  65. if sid not in citation_spans:
  66. citation_spans[sid] = []
  67. # Add the position span
  68. citation_spans[sid].append((start, end))
  69. return citation_spans
  70. class CitationTracker:
  71. """
  72. Tracks citation spans to ensure proper consolidation and deduplication.
  73. This class serves two purposes:
  74. 1. Tracking which spans have already been processed to avoid duplicate emissions
  75. 2. Maintaining a consolidated record of all citation spans for final answers
  76. The is_new_span method both checks if a span is new AND marks it as processed
  77. if it is new, which is important to understand when using this class.
  78. """
  79. def __init__(self):
  80. # Track which citation spans we've processed
  81. # Format: {citation_id: {(start, end), (start, end), ...}}
  82. self.processed_spans: dict[str, Set[Tuple[int, int]]] = {}
  83. # Track which citation IDs we've seen
  84. self.seen_citation_ids: Set[str] = set()
  85. def is_new_citation(self, citation_id: str) -> bool:
  86. """
  87. Check if this is the first occurrence of this citation ID.
  88. Args:
  89. citation_id: The citation ID to check
  90. Returns:
  91. True if this is the first time seeing this citation ID, False otherwise.
  92. Also adds the ID to seen_citation_ids if it's new.
  93. """
  94. if citation_id is None or citation_id == "":
  95. return False
  96. is_new = citation_id not in self.seen_citation_ids
  97. if is_new:
  98. self.seen_citation_ids.add(citation_id)
  99. return is_new
  100. def is_new_span(self, citation_id: str, span: Tuple[int, int]) -> bool:
  101. """
  102. Check if this span has already been processed for this citation ID.
  103. This method both checks if a span is new AND marks it as processed if it is new.
  104. Args:
  105. citation_id: The citation ID
  106. span: (start, end) position tuple
  107. Returns:
  108. True if this span hasn't been processed yet, False otherwise.
  109. Also adds the span to processed_spans if it's new.
  110. """
  111. # Handle invalid inputs
  112. if citation_id is None or citation_id == "" or span is None:
  113. return False
  114. # Initialize set for this citation ID if needed
  115. if citation_id not in self.processed_spans:
  116. self.processed_spans[citation_id] = set()
  117. # Check if we've seen this span before
  118. if span in self.processed_spans[citation_id]:
  119. return False
  120. # This is a new span, track it
  121. self.processed_spans[citation_id].add(span)
  122. return True
  123. def get_all_spans(self) -> dict[str, list[Tuple[int, int]]]:
  124. """
  125. Get all processed spans for final answer consolidation.
  126. Returns:
  127. Dictionary mapping citation IDs to lists of their (start, end) spans.
  128. """
  129. return {
  130. cid: list(spans) for cid, spans in self.processed_spans.items()
  131. }
  132. def reset(self) -> None:
  133. """
  134. Reset the tracker to its initial empty state.
  135. Useful for testing or when reusing a tracker instance.
  136. """
  137. self.processed_spans.clear()
  138. self.seen_citation_ids.clear()
  139. def find_new_citation_spans(
  140. text: str, tracker: CitationTracker
  141. ) -> dict[str, list[Tuple[int, int]]]:
  142. """
  143. Extract citation spans that haven't been processed yet.
  144. Args:
  145. text: Text to search. If None, returns an empty dict.
  146. tracker: The CitationTracker instance to check against for new spans
  147. Returns:
  148. Dictionary of citation IDs to lists of new (start, end) spans
  149. that haven't been processed by the tracker yet.
  150. """
  151. # Handle None or empty input
  152. if text is None or text == "":
  153. return {}
  154. # Get all citation spans in the text
  155. all_spans = extract_citation_spans(text)
  156. # Filter to only spans we haven't processed yet
  157. new_spans: dict = {}
  158. for cid, spans in all_spans.items():
  159. for span in spans:
  160. if tracker.is_new_span(cid, span):
  161. if cid not in new_spans:
  162. new_spans[cid] = []
  163. new_spans[cid].append(span)
  164. return new_spans
  165. __all__ = [
  166. "format_search_results_for_llm",
  167. "generate_id",
  168. "generate_document_id",
  169. "generate_extraction_id",
  170. "generate_user_id",
  171. "generate_default_user_collection_id",
  172. "validate_uuid",
  173. "yield_sse_event",
  174. "dump_collector",
  175. "dump_obj",
  176. "convert_nonserializable_objects",
  177. "num_tokens",
  178. "num_tokens_from_messages",
  179. "SSEFormatter",
  180. "SearchResultsCollector",
  181. "update_settings_from_dict",
  182. "deep_update",
  183. # Text splitter
  184. "RecursiveCharacterTextSplitter",
  185. "TextSplitter",
  186. "extract_citations",
  187. "extract_citation_spans",
  188. "CitationTracker",
  189. "find_new_citation_spans",
  190. ]