d342e632358a_migrate_to_asyncpg.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. """migrate_to_asyncpg.
  2. Revision ID: d342e632358a
  3. Revises:
  4. Create Date: 2024-10-22 11:55:49.461015
  5. """
  6. import os
  7. from typing import Sequence, Union
  8. import sqlalchemy as sa
  9. from alembic import op
  10. from sqlalchemy import inspect
  11. from sqlalchemy.dialects import postgresql
  12. from sqlalchemy.types import UserDefinedType
  13. # revision identifiers, used by Alembic.
  14. revision: str = "d342e632358a"
  15. down_revision: Union[str, None] = None
  16. branch_labels: Union[str, Sequence[str], None] = None
  17. depends_on: Union[str, Sequence[str], None] = None
  18. project_name = os.getenv("R2R_PROJECT_NAME") or "r2r_default"
  19. new_vector_table_name = "vectors"
  20. old_vector_table_name = project_name
  21. class Vector(UserDefinedType):
  22. def get_col_spec(self, **kw):
  23. return "vector"
  24. def check_if_upgrade_needed():
  25. """Check if the upgrade has already been applied or is needed."""
  26. connection = op.get_bind()
  27. inspector = inspect(connection)
  28. # First check if the old table exists - if it doesn't, we don't need this migration
  29. has_old_table = inspector.has_table(
  30. old_vector_table_name, schema=project_name
  31. )
  32. if not has_old_table:
  33. print(
  34. f"Migration not needed: Original '{old_vector_table_name}' table doesn't exist"
  35. )
  36. # Skip this migration since we're starting from a newer state
  37. return False
  38. # Only if the old table exists, check if we need to migrate it
  39. has_new_table = inspector.has_table(
  40. new_vector_table_name, schema=project_name
  41. )
  42. if has_new_table:
  43. print(
  44. f"Migration not needed: '{new_vector_table_name}' table already exists"
  45. )
  46. return False
  47. print(
  48. f"Migration needed: Need to migrate from '{old_vector_table_name}' to '{new_vector_table_name}'"
  49. )
  50. return True
  51. def upgrade() -> None:
  52. if check_if_upgrade_needed():
  53. # Create required extensions
  54. op.execute("CREATE EXTENSION IF NOT EXISTS vector")
  55. op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
  56. op.execute("CREATE EXTENSION IF NOT EXISTS btree_gin")
  57. # KG table migrations
  58. op.execute(
  59. f"ALTER TABLE IF EXISTS {project_name}.entity_raw RENAME TO chunk_entity"
  60. )
  61. op.execute(
  62. f"ALTER TABLE IF EXISTS {project_name}.triple_raw RENAME TO chunk_triple"
  63. )
  64. op.execute(
  65. f"ALTER TABLE IF EXISTS {project_name}.entity_embedding RENAME TO document_entity"
  66. )
  67. op.execute(
  68. f"ALTER TABLE IF EXISTS {project_name}.community RENAME TO community_info"
  69. )
  70. # Create the new table
  71. op.create_table(
  72. new_vector_table_name,
  73. sa.Column("extraction_id", postgresql.UUID(), nullable=False),
  74. sa.Column("document_id", postgresql.UUID(), nullable=False),
  75. sa.Column("user_id", postgresql.UUID(), nullable=False),
  76. sa.Column(
  77. "collection_ids",
  78. postgresql.ARRAY(postgresql.UUID()),
  79. server_default="{}",
  80. ),
  81. sa.Column("vec", Vector), # This will be handled as a vector type
  82. sa.Column("text", sa.Text(), nullable=True),
  83. sa.Column(
  84. "fts",
  85. postgresql.TSVECTOR,
  86. nullable=False,
  87. server_default=sa.text(
  88. "to_tsvector('english'::regconfig, '')"
  89. ),
  90. ),
  91. sa.Column(
  92. "metadata",
  93. postgresql.JSONB(),
  94. server_default="{}",
  95. nullable=False,
  96. ),
  97. sa.PrimaryKeyConstraint("extraction_id"),
  98. schema=project_name,
  99. )
  100. # Create indices
  101. op.create_index(
  102. "idx_vectors_document_id",
  103. new_vector_table_name,
  104. ["document_id"],
  105. schema=project_name,
  106. )
  107. op.create_index(
  108. "idx_vectors_user_id",
  109. new_vector_table_name,
  110. ["user_id"],
  111. schema=project_name,
  112. )
  113. op.create_index(
  114. "idx_vectors_collection_ids",
  115. new_vector_table_name,
  116. ["collection_ids"],
  117. schema=project_name,
  118. postgresql_using="gin",
  119. )
  120. op.create_index(
  121. "idx_vectors_fts",
  122. new_vector_table_name,
  123. ["fts"],
  124. schema=project_name,
  125. postgresql_using="gin",
  126. )
  127. # Migrate data from old table (assuming old table name is 'old_vectors')
  128. # Note: You'll need to replace 'old_schema' and 'old_vectors' with your actual names
  129. op.execute(f"""
  130. INSERT INTO {project_name}.{new_vector_table_name}
  131. (extraction_id, document_id, user_id, collection_ids, vec, text, metadata)
  132. SELECT
  133. extraction_id,
  134. document_id,
  135. user_id,
  136. collection_ids,
  137. vec,
  138. text,
  139. metadata
  140. FROM {project_name}.{old_vector_table_name}
  141. """)
  142. # Verify data migration
  143. op.execute(f"""
  144. SELECT COUNT(*) old_count FROM {project_name}.{old_vector_table_name};
  145. SELECT COUNT(*) new_count FROM {project_name}.{new_vector_table_name};
  146. """)
  147. # If we get here, migration was successful, so drop the old table
  148. op.execute(f"""
  149. DROP TABLE IF EXISTS {project_name}.{old_vector_table_name};
  150. """)
  151. def downgrade() -> None:
  152. # Drop all indices
  153. op.drop_index("idx_vectors_fts", schema=project_name)
  154. op.drop_index("idx_vectors_collection_ids", schema=project_name)
  155. op.drop_index("idx_vectors_user_id", schema=project_name)
  156. op.drop_index("idx_vectors_document_id", schema=project_name)
  157. # Drop the new table
  158. op.drop_table(new_vector_table_name, schema=project_name)
  159. # Revert KG table migrations
  160. op.execute(
  161. f"ALTER TABLE IF EXISTS {project_name}.chunk_entity RENAME TO entity_raw"
  162. )
  163. op.execute(
  164. f"ALTER TABLE IF EXISTS {project_name}.chunk_relationship RENAME TO relationship_raw"
  165. )
  166. op.execute(
  167. f"ALTER TABLE IF EXISTS {project_name}.document_entity RENAME TO entity_embedding"
  168. )
  169. op.execute(
  170. f"ALTER TABLE IF EXISTS {project_name}.community_info RENAME TO community"
  171. )