2fac23e4d91b_migrate_to_document_search.py

  1. """migrate_to_document_search
  2. Revision ID: 2fac23e4d91b
  3. Revises:
  4. Create Date: 2024-11-11 11:55:49.461015
  5. """
  6. import asyncio
  7. import json
  8. import os
  9. from concurrent.futures import ThreadPoolExecutor
  10. from typing import Sequence, Union
  11. import sqlalchemy as sa
  12. from alembic import op
  13. from openai import AsyncOpenAI
  14. from sqlalchemy import inspect
  15. from sqlalchemy.types import UserDefinedType
  16. from r2r import R2RAsyncClient
  17. # revision identifiers, used by Alembic.
  18. revision: str = "2fac23e4d91b"
  19. down_revision: Union[str, None] = "d342e632358a"
  20. branch_labels: Union[str, Sequence[str], None] = None
  21. depends_on: Union[str, Sequence[str], None] = None
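
# This migration adds document-search support: it creates `summary`,
# `summary_embedding`, and a generated `doc_search_vector` column on the
# `document_info` table, and backfills summaries and summary embeddings by
# calling the running R2R deployment for every existing document.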

project_name = os.getenv("R2R_PROJECT_NAME")
if not project_name:
    raise ValueError(
        "Environment variable `R2R_PROJECT_NAME` must be provided to migrate; it should be set equal to the value of `project_name` in your `r2r.toml`."
    )

dimension = os.getenv("R2R_EMBEDDING_DIMENSION")
if not dimension:
    raise ValueError(
        "Environment variable `R2R_EMBEDDING_DIMENSION` must be provided to migrate; it should be set equal to the value of `base_dimension` in your `r2r.toml`."
    )


class Vector(UserDefinedType):
    def get_col_spec(self, **kw):
        return f"vector({dimension})"
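
# The `Vector` type above is a minimal SQLAlchemy `UserDefinedType` that
# renders pgvector's `vector(N)` column spec, since SQLAlchemy has no built-in
# vector type. It is only used for DDL here; row values are written via raw
# SQL casts (`'[...]'::vector(N)`) further below.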


def run_async(coroutine):
    """Helper function to run async code synchronously"""
    with ThreadPoolExecutor() as pool:
        return pool.submit(asyncio.run, coroutine).result()
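
# Alembic invokes migration functions synchronously, so the async R2R client
# calls are driven through `run_async` above: the coroutine runs under
# `asyncio.run` in a worker thread, which avoids clashing with any event loop
# that may already be running in the calling thread.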


async def async_generate_all_summaries():
    """Asynchronous function to generate summaries"""
    base_url = os.getenv("R2R_BASE_URL")
    if not base_url:
        raise ValueError(
            "Environment variable `R2R_BASE_URL` must be provided; it must point at the R2R deployment you wish to migrate, e.g. `http://localhost:7272`."
        )
    print(f"Using R2R Base URL: {base_url}")

    base_model = os.getenv("R2R_BASE_MODEL")
    if not base_model:
        raise ValueError(
            "Environment variable `R2R_BASE_MODEL` must be provided, e.g. `openai/gpt-4o-mini`; it will be used for generating document summaries during migration."
        )
    print(f"Using R2R Base Model: {base_model}")

    client = R2RAsyncClient(base_url)

    # Page through the document overview until a partial batch signals the end
    offset = 0
    limit = 1_000
    documents = (await client.documents_overview(offset=offset, limit=limit))[
        "results"
    ]
    while len(documents) == offset + limit:
        offset += limit
        documents += (
            await client.documents_overview(offset=offset, limit=limit)
        )["results"]

    # Load existing summaries if they exist
    document_summaries = {}
    if os.path.exists("document_summaries.json"):
        try:
            with open("document_summaries.json", "r") as f:
                document_summaries = json.load(f)
            print(
                f"Loaded {len(document_summaries)} existing document summaries"
            )
        except json.JSONDecodeError:
            print(
                "Existing document_summaries.json was invalid, starting fresh"
            )
            document_summaries = {}

    for document in documents:
        title = document["title"]
        doc_id = str(
            document["id"]
        )  # Convert UUID to string for JSON compatibility

        # Skip if document already has a summary
        if doc_id in document_summaries:
            print(
                f"Skipping document {title} ({doc_id}) - summary already exists"
            )
            continue

        print(f"Processing document: {title} ({doc_id})")

        try:
            document_text = f"Document Title:{title}\n"
            if document["metadata"]:
                metadata = json.dumps(document["metadata"])
                document_text += f"Document Metadata:\n{metadata}\n"

            full_chunks = (
                await client.document_chunks(document["id"], limit=10)
            )["results"]

            document_text += "Document Content:\n"
            for chunk in full_chunks:
                document_text += chunk["text"]

            summary_prompt = """## Task:
Your task is to generate a descriptive summary of the document that follows. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
### Document:
{document}
### Query:
Reminder: Your task is to generate a descriptive summary of the document that was given. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
## Response:"""

            messages = [
                {
                    "role": "user",
                    "content": summary_prompt.format(
                        **{"document": document_text}
                    ),
                }
            ]
            summary = await client.completion(
                messages=messages, generation_config={"model": base_model}
            )
            summary_text = summary["results"]["choices"][0]["message"][
                "content"
            ]

            embedding_vector = await client.embedding(summary_text)
            # embedding_response = await openai_client.embeddings.create(
            #     model=embedding_model, input=summary_text, dimensions=dimension
            # )
            # embedding_vector = embedding_response.data[0].embedding

            # Store in our results dictionary
            document_summaries[doc_id] = {
                "summary": summary_text,
                "embedding": embedding_vector,
            }

            # Save after each document
            with open("document_summaries.json", "w") as f:
                json.dump(document_summaries, f)

            print(f"Successfully processed document {doc_id}")

        except Exception as e:
            print(f"Error processing document {doc_id}: {str(e)}")
            # Continue with next document instead of failing
            continue

    return document_summaries
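
# `document_summaries.json` acts as a checkpoint file so the migration can be
# re-run safely: it maps each document id (as a string) to
# {"summary": <text>, "embedding": [<float>, ...]} and is rewritten after
# every successfully processed document.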


def generate_all_summaries():
    """Synchronous wrapper for async_generate_all_summaries"""
    return run_async(async_generate_all_summaries())


def check_if_upgrade_needed():
    """Check if the upgrade has already been applied by checking for the summary column"""
    # Get database connection
    connection = op.get_bind()
    inspector = inspect(connection)

    # Check if the columns exist
    existing_columns = [
        col["name"]
        for col in inspector.get_columns("document_info", schema=project_name)
    ]

    needs_upgrade = "summary" not in existing_columns

    if needs_upgrade:
        print(
            "Migration needed: 'summary' column does not exist in document_info table"
        )
    else:
        print(
            "Migration not needed: 'summary' column already exists in document_info table"
        )

    return needs_upgrade
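
# The upgrade below proceeds in four steps: (1) generate summaries and summary
# embeddings for every document via the live R2R API, (2) add `summary` and
# `summary_embedding` columns to `document_info`, (3) add a generated
# `doc_search_vector` tsvector column with a GIN index for full-text search,
# and (4) backfill the new columns from the generated summaries.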


def upgrade() -> None:
    if check_if_upgrade_needed():
        # Generate the document summaries via the R2R API
        generate_all_summaries()

        # Load the document summaries
        document_summaries = None
        try:
            with open("document_summaries.json", "r") as f:
                document_summaries = json.load(f)
            print(f"Loaded {len(document_summaries)} document summaries")
        except FileNotFoundError:
            print(
                "document_summaries.json not found. Continuing without summaries and/or summary embeddings."
            )
        except json.JSONDecodeError:
            raise ValueError("Invalid document_summaries.json file")

        # Create the vector extension if it doesn't exist
        op.execute("CREATE EXTENSION IF NOT EXISTS vector")

        # Add new columns to document_info
        op.add_column(
            "document_info",
            sa.Column("summary", sa.Text(), nullable=True),
            schema=project_name,
        )

        op.add_column(
            "document_info",
            sa.Column("summary_embedding", Vector, nullable=True),
            schema=project_name,
        )

        # Add generated column for full text search
        op.execute(
            f"""
            ALTER TABLE {project_name}.document_info
            ADD COLUMN doc_search_vector tsvector
            GENERATED ALWAYS AS (
                setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
                setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
                setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
            ) STORED;
            """
        )

        # Create index for full text search
        op.execute(
            f"""
            CREATE INDEX idx_doc_search_{project_name}
            ON {project_name}.document_info
            USING GIN (doc_search_vector);
            """
        )
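
        # Note: the generated tsvector weights title (A) above summary (B) and
        # any `metadata->>'description'` text (C), so title matches rank
        # highest in full-text queries; the GIN index keeps those queries fast.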

        if document_summaries:
            # Update existing documents with summaries and embeddings
            for doc_id, doc_data in document_summaries.items():
                # Convert the embedding array to the PostgreSQL vector format
                embedding_str = (
                    f"[{','.join(str(x) for x in doc_data['embedding'])}]"
                )

                # Use plain SQL with proper escaping for PostgreSQL
                op.execute(
                    f"""
                    UPDATE {project_name}.document_info
                    SET
                        summary = '{doc_data['summary'].replace("'", "''")}',
                        summary_embedding = '{embedding_str}'::vector({dimension})
                    WHERE document_id = '{doc_id}'::uuid;
                    """
                )
        else:
            print(
                "No document summaries found, skipping update of existing documents"
            )


def downgrade() -> None:
    # First drop any dependencies on the columns we want to remove
    op.execute(
        f"""
        -- Drop the full text search index first
        DROP INDEX IF EXISTS {project_name}.idx_doc_search_{project_name};

        -- Drop the generated column that depends on the summary column
        ALTER TABLE {project_name}.document_info
        DROP COLUMN IF EXISTS doc_search_vector;
        """
    )

    # Now we can safely drop the summary and embedding columns
    op.drop_column("document_info", "summary_embedding", schema=project_name)
    op.drop_column("document_info", "summary", schema=project_name)
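
# Typical usage (a sketch, assuming the standard Alembic workflow): export
# R2R_PROJECT_NAME, R2R_EMBEDDING_DIMENSION, R2R_BASE_URL, and R2R_BASE_MODEL
# to match your `r2r.toml` and deployment, then run `alembic upgrade head`
# from the directory containing your `alembic.ini`; `alembic downgrade -1`
# reverses this revision via `downgrade()` above.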