- """add_total_tokens_to_documents.
- Revision ID: 3efc7b3b1b3d
- Revises: 7eb70560f406
- Create Date: 2025-01-21 14:59:00.000000
- """

import logging
import math
import os

import sqlalchemy as sa
import tiktoken
from alembic import op
from sqlalchemy import inspect, text

# revision identifiers, used by Alembic.
revision = "3efc7b3b1b3d"
down_revision = "7eb70560f406"
branch_labels = None
depends_on = None

logger = logging.getLogger("alembic.runtime.migration")

# Get project name from environment variable, defaulting to 'r2r_default'
project_name = os.getenv("R2R_PROJECT_NAME", "r2r_default")
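
# NOTE: project_name is interpolated directly into SQL identifiers below, so
# it is assumed to be a trusted schema name from the deployment environment,
# never user input; all values are passed as bound parameters.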


def count_tokens_for_text(content: str, model: str = "gpt-3.5-turbo") -> int:
    """Count the number of tokens in the given text using tiktoken.

    Default model is "gpt-3.5-turbo". Adjust if you prefer a different model.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Fall back to a known encoding if the model is not recognized
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(content))
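
# Rough illustration (token counts depend on the encoding; with cl100k_base,
# which "gpt-3.5-turbo" maps to, "hello world" encodes to two tokens):
#   count_tokens_for_text("hello world")  # -> 2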


def check_if_upgrade_needed() -> bool:
    """Check if the upgrade has already been applied."""
    connection = op.get_bind()
    inspector = inspect(connection)

    # Check if the documents table exists in the target schema
    if not inspector.has_table("documents", schema=project_name):
        logger.info(
            f"Migration not needed: '{project_name}.documents' table doesn't exist"
        )
        return False

    # Check if the total_tokens column already exists
    columns = {
        col["name"]
        for col in inspector.get_columns("documents", schema=project_name)
    }
    if "total_tokens" in columns:
        logger.info(
            "Migration not needed: documents table already has total_tokens column"
        )
        return False

    logger.info("Migration needed: documents table needs total_tokens column")
    return True
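
# Because both the table and the column are checked above, re-running this
# migration against an already-migrated schema is a no-op.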


def upgrade() -> None:
    if not check_if_upgrade_needed():
        return

    connection = op.get_bind()

    # Add the total_tokens column
    logger.info("Adding 'total_tokens' column to 'documents' table...")
    op.add_column(
        "documents",
        sa.Column(
            "total_tokens",
            sa.Integer(),
            nullable=False,
            server_default="0",
        ),
        schema=project_name,
    )
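    # server_default="0" lets the NOT NULL column be added to a table that
    # already contains rows; existing documents start at 0 and are backfilled
    # below.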

    # Process documents in batches
    BATCH_SIZE = 500

    # Count total documents
    logger.info("Determining how many documents need updating...")
    doc_count_query = text(f"SELECT COUNT(*) FROM {project_name}.documents")
    total_docs = connection.execute(doc_count_query).scalar() or 0
    logger.info(f"Total documents found: {total_docs}")

    if total_docs == 0:
        logger.info("No documents found, nothing to update.")
        return

    pages = math.ceil(total_docs / BATCH_SIZE)
    logger.info(
        f"Updating total_tokens in {pages} batches of up to {BATCH_SIZE} documents..."
    )

    default_model = os.getenv("R2R_TOKCOUNT_MODEL", "gpt-3.5-turbo")
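    # The counting model can be overridden per deployment via R2R_TOKCOUNT_MODEL;
    # unrecognized names fall back to cl100k_base in count_tokens_for_text().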
    offset = 0

    for page_idx in range(pages):
        logger.info(
            f"Processing batch {page_idx + 1} / {pages} (OFFSET={offset}, LIMIT={BATCH_SIZE})"
        )
        # Fetch the next batch of document IDs
        batch_docs_query = text(f"""
            SELECT id
            FROM {project_name}.documents
            ORDER BY id
            LIMIT :limit_val
            OFFSET :offset_val
        """)
        batch_docs = connection.execute(
            batch_docs_query, {"limit_val": BATCH_SIZE, "offset_val": offset}
        ).fetchall()
        if not batch_docs:
            break

        # Positional access works on both SQLAlchemy 1.x and 2.x Row objects;
        # string-keyed access (row["id"]) was removed in SQLAlchemy 2.0.
        doc_ids = [row[0] for row in batch_docs]
        offset += BATCH_SIZE
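
        # OFFSET pagination stays consistent here because rows are ordered by
        # id and the updates below change only total_tokens, not the
        # membership or order of the result set.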

        # Process each document in the batch
        for doc_id in doc_ids:
            chunks_query = text(f"""
                SELECT data
                FROM {project_name}.chunks
                WHERE document_id = :doc_id
            """)
            chunk_rows = connection.execute(
                chunks_query, {"doc_id": doc_id}
            ).fetchall()

            total_tokens = 0
            for c_row in chunk_rows:
                # data is the only selected column, so positional access is
                # unambiguous (and safe on SQLAlchemy 2.x)
                chunk_text = c_row[0] or ""
                total_tokens += count_tokens_for_text(
                    chunk_text, model=default_model
                )

            # Update total_tokens for this document
            update_query = text(f"""
                UPDATE {project_name}.documents
                SET total_tokens = :tokcount
                WHERE id = :doc_id
            """)
            connection.execute(
                update_query, {"tokcount": total_tokens, "doc_id": doc_id}
            )

        logger.info(f"Finished batch {page_idx + 1}")

    logger.info("Done updating total_tokens.")


def downgrade() -> None:
    """Remove the total_tokens column on downgrade."""
    logger.info(
        "Dropping column 'total_tokens' from 'documents' table (downgrade)."
    )
    op.drop_column("documents", "total_tokens", schema=project_name)
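
# Typical invocation with the standard Alembic CLI, assuming this file sits in
# the configured versions directory:
#   alembic upgrade head    # applies this migration (runs upgrade())
#   alembic downgrade -1    # reverts the latest migration (runs downgrade())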