- """migrate_to_document_search
- Revision ID: 2fac23e4d91b
- Revises:
- Create Date: 2024-11-11 11:55:49.461015
- """

import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from openai import AsyncOpenAI
from sqlalchemy import inspect
from sqlalchemy.types import UserDefinedType

from r2r import R2RAsyncClient

# revision identifiers, used by Alembic.
revision: str = "2fac23e4d91b"
down_revision: Union[str, None] = "d342e632358a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

project_name = os.getenv("R2R_PROJECT_NAME")
if not project_name:
    raise ValueError(
        "Environment variable `R2R_PROJECT_NAME` must be provided to migrate; it should be set equal to the value of `project_name` in your `r2r.toml`."
    )

dimension = os.getenv("R2R_EMBEDDING_DIMENSION")
if not dimension:
    raise ValueError(
        "Environment variable `R2R_EMBEDDING_DIMENSION` must be provided to migrate; it should be set equal to the value of `base_dimension` in your `r2r.toml`."
    )
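

# Minimal SQLAlchemy column type that renders as pgvector's `vector(N)` in
# DDL, with N taken from `R2R_EMBEDDING_DIMENSION`.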
class Vector(UserDefinedType):
    def get_col_spec(self, **kw):
        return f"vector({dimension})"
def run_async(coroutine):
    """Helper function to run async code synchronously."""
    with ThreadPoolExecutor() as pool:
        return pool.submit(asyncio.run, coroutine).result()


async def async_generate_all_summaries():
    """Asynchronous function to generate summaries."""
    base_url = os.getenv("R2R_BASE_URL")
    if not base_url:
        raise ValueError(
            "Environment variable `R2R_BASE_URL` must be provided; it must point at the R2R deployment you wish to migrate, e.g. `http://localhost:7272`."
        )
    print(f"Using R2R Base URL: {base_url}")

    base_model = os.getenv("R2R_BASE_MODEL")
    if not base_model:
        raise ValueError(
            "Environment variable `R2R_BASE_MODEL` must be provided, e.g. `openai/gpt-4o-mini`; it will be used for generating document summaries during migration."
        )
    print(f"Using R2R Base Model: {base_model}")

    client = R2RAsyncClient(base_url)

    # Page through the documents overview: a full page of `limit` results
    # means there may be more, so advance the offset and fetch again.
    offset = 0
    limit = 1_000
    documents = []
    while True:
        batch = (await client.documents_overview(offset=offset, limit=limit))[
            "results"
        ]
        documents += batch
        if len(batch) < limit:
            break
        offset += limit

    # Load existing summaries if they exist
    document_summaries = {}
    if os.path.exists("document_summaries.json"):
        try:
            with open("document_summaries.json", "r") as f:
                document_summaries = json.load(f)
            print(
                f"Loaded {len(document_summaries)} existing document summaries"
            )
        except json.JSONDecodeError:
            print(
                "Existing document_summaries.json was invalid, starting fresh"
            )
            document_summaries = {}
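
    # document_summaries.json doubles as a checkpoint: it is rewritten after
    # every document and already-summarized documents are skipped below, so an
    # interrupted migration can be re-run and will resume where it left off.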
    for document in documents:
        title = document["title"]
        doc_id = str(
            document["id"]
        )  # Convert UUID to string for JSON compatibility

        # Skip if document already has a summary
        if doc_id in document_summaries:
            print(
                f"Skipping document {title} ({doc_id}) - summary already exists"
            )
            continue

        print(f"Processing document: {title} ({doc_id})")
        try:
            document_text = f"Document Title: {title}\n"
            if document["metadata"]:
                metadata = json.dumps(document["metadata"])
                document_text += f"Document Metadata:\n{metadata}\n"

            full_chunks = (
                await client.document_chunks(document["id"], limit=10)
            )["results"]

            document_text += "Document Content:\n"
            for chunk in full_chunks:
                document_text += chunk["text"]

            summary_prompt = """## Task:
Your task is to generate a descriptive summary of the document that follows. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
### Document:
{document}
### Query:
Reminder: Your task is to generate a descriptive summary of the document that was given. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
## Response:"""

            messages = [
                {
                    "role": "user",
                    "content": summary_prompt.format(document=document_text),
                }
            ]

            summary = await client.completion(
                messages=messages, generation_config={"model": base_model}
            )
            summary_text = summary["results"]["choices"][0]["message"][
                "content"
            ]

            embedding_vector = await client.embedding(summary_text)
            # Alternative: embed directly with the OpenAI SDK (this assumes an
            # `openai_client = AsyncOpenAI()` instance and an `embedding_model`
            # name, neither of which is defined in this migration):
            # embedding_response = await openai_client.embeddings.create(
            #     model=embedding_model, input=summary_text, dimensions=dimension
            # )
            # embedding_vector = embedding_response.data[0].embedding

            # Store in our results dictionary
            document_summaries[doc_id] = {
                "summary": summary_text,
                "embedding": embedding_vector,
            }

            # Save after each document so progress survives interruptions
            with open("document_summaries.json", "w") as f:
                json.dump(document_summaries, f)

            print(f"Successfully processed document {doc_id}")
        except Exception as e:
            print(f"Error processing document {doc_id}: {e}")
            # Continue with the next document instead of failing the migration
            continue

    return document_summaries


def generate_all_summaries():
    """Synchronous wrapper for async_generate_all_summaries."""
    return run_async(async_generate_all_summaries())


def check_if_upgrade_needed():
    """Check whether the upgrade has already been applied by looking for the summary column."""
    # Get database connection
    connection = op.get_bind()
    inspector = inspect(connection)

    # Check if the columns exist
    existing_columns = [
        col["name"]
        for col in inspector.get_columns("document_info", schema=project_name)
    ]

    needs_upgrade = "summary" not in existing_columns
    if needs_upgrade:
        print(
            "Migration needed: 'summary' column does not exist in document_info table"
        )
    else:
        print(
            "Migration not needed: 'summary' column already exists in document_info table"
        )
    return needs_upgrade
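

# Because `upgrade` first calls `check_if_upgrade_needed`, re-running this
# migration against a database that already has the summary column is a no-op.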
def upgrade() -> None:
    if check_if_upgrade_needed():
        # Generate (or resume generating) the document summaries
        generate_all_summaries()

        document_summaries = None
        try:
            with open("document_summaries.json", "r") as f:
                document_summaries = json.load(f)
            print(f"Loaded {len(document_summaries)} document summaries")
        except FileNotFoundError:
            print(
                "document_summaries.json not found. Continuing without summaries and/or summary embeddings."
            )
        except json.JSONDecodeError:
            raise ValueError("Invalid document_summaries.json file")

        # Create the vector extension if it doesn't exist
        op.execute("CREATE EXTENSION IF NOT EXISTS vector")

        # Add new columns to document_info
        op.add_column(
            "document_info",
            sa.Column("summary", sa.Text(), nullable=True),
            schema=project_name,
        )
        op.add_column(
            "document_info",
            sa.Column("summary_embedding", Vector, nullable=True),
            schema=project_name,
        )

        # Add generated column for full text search
        op.execute(
            f"""
            ALTER TABLE {project_name}.document_info
            ADD COLUMN doc_search_vector tsvector
            GENERATED ALWAYS AS (
                setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
                setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
                setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
            ) STORED;
            """
        )
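
        # The setweight labels rank matches by source when results are ordered
        # with ts_rank: title (A) outranks summary (B), which outranks the
        # metadata description (C). A hypothetical ranked query (illustrative
        # only, not part of this migration) might look like:
        #
        #   SELECT document_id, ts_rank(doc_search_vector, query) AS rank
        #   FROM <project>.document_info,
        #        websearch_to_tsquery('english', 'search terms') AS query
        #   WHERE doc_search_vector @@ query
        #   ORDER BY rank DESC;
        #
        # The GIN index created below is what makes the `@@` match efficient.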

        # Create index for full text search
        op.execute(
            f"""
            CREATE INDEX idx_doc_search_{project_name}
            ON {project_name}.document_info
            USING GIN (doc_search_vector);
            """
        )

        if document_summaries:
            # Update existing documents with summaries and embeddings
            for doc_id, doc_data in document_summaries.items():
                # Convert the embedding array to the PostgreSQL vector format
                embedding_str = (
                    f"[{','.join(str(x) for x in doc_data['embedding'])}]"
                )
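
                # A parameterized variant of the UPDATE below (an alternative
                # sketch, not what this migration does) would avoid the manual
                # quote escaping:
                #
                #   op.get_bind().execute(
                #       sa.text(
                #           f"UPDATE {project_name}.document_info "
                #           "SET summary = :summary, "
                #           f"summary_embedding = (:emb)::vector({dimension}) "
                #           "WHERE document_id = (:doc_id)::uuid"
                #       ),
                #       {
                #           "summary": doc_data["summary"],
                #           "emb": embedding_str,
                #           "doc_id": doc_id,
                #       },
                #   )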
                # Use plain SQL with proper escaping for PostgreSQL
                op.execute(
                    f"""
                    UPDATE {project_name}.document_info
                    SET
                        summary = '{doc_data['summary'].replace("'", "''")}',
                        summary_embedding = '{embedding_str}'::vector({dimension})
                    WHERE document_id = '{doc_id}'::uuid;
                    """
                )
        else:
            print(
                "No document summaries found, skipping update of existing documents"
            )


def downgrade() -> None:
    # First drop any dependencies on the columns we want to remove
    op.execute(
        f"""
        -- Drop the full text search index first
        DROP INDEX IF EXISTS {project_name}.idx_doc_search_{project_name};

        -- Drop the generated column that depends on the summary column
        ALTER TABLE {project_name}.document_info
        DROP COLUMN IF EXISTS doc_search_vector;
        """
    )

    # Now we can safely drop the summary and embedding columns
    op.drop_column("document_info", "summary_embedding", schema=project_name)
    op.drop_column("document_info", "summary", schema=project_name)
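
# A typical invocation (assuming a standard Alembic setup for this project):
#
#   R2R_PROJECT_NAME=<project_name> R2R_EMBEDDING_DIMENSION=<base_dimension> \
#   R2R_BASE_URL=http://localhost:7272 R2R_BASE_MODEL=openai/gpt-4o-mini \
#   alembic upgrade head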