2fac23e4d91b_migrate_to_document_search.py

  1. """migrate_to_document_search.
  2. Revision ID: 2fac23e4d91b
  3. Revises:
  4. Create Date: 2024-11-11 11:55:49.461015
  5. """
  6. import asyncio
  7. import json
  8. import os
  9. from concurrent.futures import ThreadPoolExecutor
  10. from typing import Sequence, Union
  11. import sqlalchemy as sa
  12. from alembic import op
  13. from sqlalchemy import inspect
  14. from sqlalchemy.types import UserDefinedType
  15. from r2r import R2RAsyncClient
  16. # revision identifiers, used by Alembic.
  17. revision: str = "2fac23e4d91b"
  18. down_revision: Union[str, None] = "d342e632358a"
  19. branch_labels: Union[str, Sequence[str], None] = None
  20. depends_on: Union[str, Sequence[str], None] = None
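
# This migration adds document-level search support to `document_info`: new
# `summary` and `summary_embedding` columns, a generated `doc_search_vector`
# tsvector column with a GIN index, and a backfill of summaries and embeddings
# generated through the running R2R deployment.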

project_name = os.getenv("R2R_PROJECT_NAME")
if not project_name:
    raise ValueError(
        "Environment variable `R2R_PROJECT_NAME` must be provided to migrate; it should be set equal to the value of `project_name` in your `r2r.toml`."
    )

dimension = os.getenv("R2R_EMBEDDING_DIMENSION")
if not dimension:
    raise ValueError(
        "Environment variable `R2R_EMBEDDING_DIMENSION` must be provided to migrate; it should be set equal to the value of `base_dimension` in your `r2r.toml`."
    )


class Vector(UserDefinedType):
    def get_col_spec(self, **kw):
        return f"vector({dimension})"


def run_async(coroutine):
    """Helper function to run async code synchronously."""
    with ThreadPoolExecutor() as pool:
        return pool.submit(asyncio.run, coroutine).result()
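

# The summary-generation phase below calls the running R2R deployment for each
# document and checkpoints its progress to document_summaries.json, so an
# interrupted migration can be re-run and will skip already-summarized documents.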
async def async_generate_all_summaries():
    """Asynchronous function to generate summaries."""
    base_url = os.getenv("R2R_BASE_URL")
    if not base_url:
        raise ValueError(
            "Environment variable `R2R_BASE_URL` must be provided; it must point at the R2R deployment you wish to migrate, e.g. `http://localhost:7272`."
        )

    print(f"Using R2R Base URL: {base_url}")

    base_model = os.getenv("R2R_BASE_MODEL")
    if not base_model:
        raise ValueError(
            "Environment variable `R2R_BASE_MODEL` must be provided, e.g. `openai/gpt-4o-mini`; it will be used for generating document summaries during migration."
        )

    print(f"Using R2R Base Model: {base_model}")

    client = R2RAsyncClient(base_url)

    offset = 0
    limit = 1_000
    documents = (await client.documents_overview(offset=offset, limit=limit))[
        "results"
    ]
    # Page through the remaining documents for as long as full pages come back
    while len(documents) == offset + limit:
        offset += limit
        documents += (
            await client.documents_overview(offset=offset, limit=limit)
        )["results"]

    # Load existing summaries if they exist
    document_summaries = {}
    if os.path.exists("document_summaries.json"):
        try:
            with open("document_summaries.json", "r") as f:
                document_summaries = json.load(f)
            print(
                f"Loaded {len(document_summaries)} existing document summaries"
            )
        except json.JSONDecodeError:
            print(
                "Existing document_summaries.json was invalid, starting fresh"
            )
            document_summaries = {}

    for document in documents:
        title = document["title"]
        doc_id = str(
            document["id"]
        )  # Convert UUID to string for JSON compatibility

        # Skip if document already has a summary
        if doc_id in document_summaries:
            print(
                f"Skipping document {title} ({doc_id}) - summary already exists"
            )
            continue

        print(f"Processing document: {title} ({doc_id})")

        try:
            document_text = f"Document Title:{title}\n"
            if document["metadata"]:
                metadata = json.dumps(document["metadata"])
                document_text += f"Document Metadata:\n{metadata}\n"

            full_chunks = (
                await client.document_chunks(document["id"], limit=10)
            )["results"]

            document_text += "Document Content:\n"
            for chunk in full_chunks:
                document_text += chunk["text"]

            summary_prompt = """## Task:

Your task is to generate a descriptive summary of the document that follows. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.

### Document:

{document}

### Query:

Reminder: Your task is to generate a descriptive summary of the document that was given. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.

## Response:"""

            messages = [
                {
                    "role": "user",
                    "content": summary_prompt.format(
                        **{"document": document_text}
                    ),
                }
            ]

            summary = await client.completion(
                messages=messages, generation_config={"model": base_model}
            )
            summary_text = summary["results"]["choices"][0]["message"][
                "content"
            ]

            embedding_vector = await client.embedding(summary_text)
            # embedding_response = await openai_client.embeddings.create(
            #     model=embedding_model, input=summary_text, dimensions=dimension
            # )
            # embedding_vector = embedding_response.data[0].embedding

            # Store in our results dictionary
            document_summaries[doc_id] = {
                "summary": summary_text,
                "embedding": embedding_vector,
            }

            # Save after each document
            with open("document_summaries.json", "w") as f:
                json.dump(document_summaries, f)

            print(f"Successfully processed document {doc_id}")

        except Exception as e:
            print(f"Error processing document {doc_id}: {str(e)}")
            # Continue with next document instead of failing
            continue

    return document_summaries


def generate_all_summaries():
    """Synchronous wrapper for async_generate_all_summaries."""
    return run_async(async_generate_all_summaries())


def check_if_upgrade_needed():
    """Check if the upgrade has already been applied or is needed."""
    # Get database connection
    connection = op.get_bind()
    inspector = inspect(connection)

    # First check if the document_info table exists
    if not inspector.has_table("document_info", schema=project_name):
        print(
            f"Migration not needed: '{project_name}.document_info' table doesn't exist yet"
        )
        return False

    # Then check if the columns exist
    existing_columns = [
        col["name"]
        for col in inspector.get_columns("document_info", schema=project_name)
    ]

    needs_upgrade = "summary" not in existing_columns

    if needs_upgrade:
        print(
            "Migration needed: 'summary' column does not exist in document_info table"
        )
    else:
        print(
            "Migration not needed: 'summary' column already exists in document_info table"
        )

    return needs_upgrade
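

# `upgrade` is guarded by `check_if_upgrade_needed`, so re-running this migration
# against a database whose `document_info` table already has a `summary` column
# is effectively a no-op.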
def upgrade() -> None:
    if check_if_upgrade_needed():
        # Load the document summaries
        generate_all_summaries()

        document_summaries = None
        try:
            with open("document_summaries.json", "r") as f:
                document_summaries = json.load(f)
            print(f"Loaded {len(document_summaries)} document summaries")
        except FileNotFoundError:
            print(
                "document_summaries.json not found. Continuing without summaries and/or summary embeddings."
            )
            pass
        except json.JSONDecodeError:
            raise ValueError("Invalid document_summaries.json file") from None

        # Create the vector extension if it doesn't exist
        op.execute("CREATE EXTENSION IF NOT EXISTS vector")

        # Add new columns to document_info
        op.add_column(
            "document_info",
            sa.Column("summary", sa.Text(), nullable=True),
            schema=project_name,
        )
        op.add_column(
            "document_info",
            sa.Column("summary_embedding", Vector, nullable=True),
            schema=project_name,
        )

        # Add generated column for full text search
        op.execute(f"""
            ALTER TABLE {project_name}.document_info
            ADD COLUMN doc_search_vector tsvector
            GENERATED ALWAYS AS (
                setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
                setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
                setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
            ) STORED;
        """)

        # Create index for full text search
        op.execute(f"""
            CREATE INDEX idx_doc_search_{project_name}
            ON {project_name}.document_info
            USING GIN (doc_search_vector);
        """)

        if document_summaries:
            # Update existing documents with summaries and embeddings
            for doc_id, doc_data in document_summaries.items():
                # Convert the embedding array to the PostgreSQL vector format
                embedding_str = (
                    f"[{','.join(str(x) for x in doc_data['embedding'])}]"
                )

                # Use plain SQL with proper escaping for PostgreSQL
                op.execute(f"""
                    UPDATE {project_name}.document_info
                    SET
                        summary = '{doc_data["summary"].replace("'", "''")}',
                        summary_embedding = '{embedding_str}'::vector({dimension})
                    WHERE document_id = '{doc_id}'::uuid;
                """)
        else:
            print(
                "No document summaries found, skipping update of existing documents"
            )


def downgrade() -> None:
    # First drop any dependencies on the columns we want to remove
    op.execute(f"""
        -- Drop the full text search index first
        DROP INDEX IF EXISTS {project_name}.idx_doc_search_{project_name};

        -- Drop the generated column that depends on the summary column
        ALTER TABLE {project_name}.document_info
        DROP COLUMN IF EXISTS doc_search_vector;
    """)

    # Now we can safely drop the summary and embedding columns
    op.drop_column("document_info", "summary_embedding", schema=project_name)
    op.drop_column("document_info", "summary", schema=project_name)
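
# Usage note (assumes the standard Alembic CLI and an alembic.ini already pointed
# at the target database; adapt to however your deployment invokes migrations):
#
#   export R2R_PROJECT_NAME=<project_name from r2r.toml>
#   export R2R_EMBEDDING_DIMENSION=<base_dimension from r2r.toml>
#   export R2R_BASE_URL=http://localhost:7272
#   export R2R_BASE_MODEL=openai/gpt-4o-mini
#   alembic upgrade head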