# 3efc7b3b1b3d_add_total_tokens_count.py
"""add_total_tokens_to_documents.

Revision ID: 3efc7b3b1b3d
Revises: 7eb70560f406
Create Date: 2025-01-21 14:59:00.000000
"""
import logging
import math
import os
import sqlalchemy as sa
import tiktoken
from alembic import op
from sqlalchemy import inspect, text

# revision identifiers, used by Alembic.
revision = "3efc7b3b1b3d"
down_revision = "7eb70560f406"
branch_labels = None
depends_on = None

# Alembic's migration logger so output shows up alongside other migration logs.
logger = logging.getLogger("alembic.runtime.migration")

# Get project name from environment variable, defaulting to 'r2r_default'.
# Used as the schema qualifier for the documents/chunks tables below.
project_name = os.getenv("R2R_PROJECT_NAME", "r2r_default")
  21. def count_tokens_for_text(text: str, model: str = "gpt-3.5-turbo") -> int:
  22. """Count the number of tokens in the given text using tiktoken.
  23. Default model is set to "gpt-3.5-turbo". Adjust if you prefer a different
  24. model.
  25. """
  26. try:
  27. encoding = tiktoken.encoding_for_model(model)
  28. except KeyError:
  29. # Fallback to a known encoding if model not recognized
  30. encoding = tiktoken.get_encoding("cl100k_base")
  31. return len(encoding.encode(text))
  32. def check_if_upgrade_needed() -> bool:
  33. """Check if the upgrade has already been applied."""
  34. connection = op.get_bind()
  35. inspector = inspect(connection)
  36. # Check if documents table exists in the correct schema
  37. if not inspector.has_table("documents", schema=project_name):
  38. logger.info(
  39. f"Migration not needed: '{project_name}.documents' table doesn't exist"
  40. )
  41. return False
  42. # Check if total_tokens column already exists
  43. columns = {
  44. col["name"]
  45. for col in inspector.get_columns("documents", schema=project_name)
  46. }
  47. if "total_tokens" in columns:
  48. logger.info(
  49. "Migration not needed: documents table already has total_tokens column"
  50. )
  51. return False
  52. logger.info("Migration needed: documents table needs total_tokens column")
  53. return True
  54. def upgrade() -> None:
  55. if not check_if_upgrade_needed():
  56. return
  57. connection = op.get_bind()
  58. # Add the total_tokens column
  59. logger.info("Adding 'total_tokens' column to 'documents' table...")
  60. op.add_column(
  61. "documents",
  62. sa.Column(
  63. "total_tokens",
  64. sa.Integer(),
  65. nullable=False,
  66. server_default="0",
  67. ),
  68. schema=project_name,
  69. )
  70. # Process documents in batches
  71. BATCH_SIZE = 500
  72. # Count total documents
  73. logger.info("Determining how many documents need updating...")
  74. doc_count_query = text(f"SELECT COUNT(*) FROM {project_name}.documents")
  75. total_docs = connection.execute(doc_count_query).scalar() or 0
  76. logger.info(f"Total documents found: {total_docs}")
  77. if total_docs == 0:
  78. logger.info("No documents found, nothing to update.")
  79. return
  80. pages = math.ceil(total_docs / BATCH_SIZE)
  81. logger.info(
  82. f"Updating total_tokens in {pages} batches of up to {BATCH_SIZE} documents..."
  83. )
  84. default_model = os.getenv("R2R_TOKCOUNT_MODEL", "gpt-3.5-turbo")
  85. offset = 0
  86. for page_idx in range(pages):
  87. logger.info(
  88. f"Processing batch {page_idx + 1} / {pages} (OFFSET={offset}, LIMIT={BATCH_SIZE})"
  89. )
  90. # Fetch next batch of document IDs
  91. batch_docs_query = text(f"""
  92. SELECT id
  93. FROM {project_name}.documents
  94. ORDER BY id
  95. LIMIT :limit_val
  96. OFFSET :offset_val
  97. """)
  98. batch_docs = connection.execute(
  99. batch_docs_query, {"limit_val": BATCH_SIZE, "offset_val": offset}
  100. ).fetchall()
  101. if not batch_docs:
  102. break
  103. doc_ids = [row["id"] for row in batch_docs]
  104. offset += BATCH_SIZE
  105. # Process each document in the batch
  106. for doc_id in doc_ids:
  107. chunks_query = text(f"""
  108. SELECT data
  109. FROM {project_name}.chunks
  110. WHERE document_id = :doc_id
  111. """)
  112. chunk_rows = connection.execute(
  113. chunks_query, {"doc_id": doc_id}
  114. ).fetchall()
  115. total_tokens = 0
  116. for c_row in chunk_rows:
  117. chunk_text = c_row["data"] or ""
  118. total_tokens += count_tokens_for_text(
  119. chunk_text, model=default_model
  120. )
  121. # Update total_tokens for this document
  122. update_query = text(f"""
  123. UPDATE {project_name}.documents
  124. SET total_tokens = :tokcount
  125. WHERE id = :doc_id
  126. """)
  127. connection.execute(
  128. update_query, {"tokcount": total_tokens, "doc_id": doc_id}
  129. )
  130. logger.info(f"Finished batch {page_idx + 1}")
  131. logger.info("Done updating total_tokens.")
  132. def downgrade() -> None:
  133. """Remove the total_tokens column on downgrade."""
  134. logger.info(
  135. "Dropping column 'total_tokens' from 'documents' table (downgrade)."
  136. )
  137. op.drop_column("documents", "total_tokens", schema=project_name)