123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344 |
- # tests/conftest.py
- import os
- from uuid import uuid4
- import pytest
- from core.base import AppConfig, DatabaseConfig, VectorQuantizationType
- from core.database.postgres import (
- PostgresChunksHandler,
- PostgresCollectionsHandler,
- PostgresConnectionManager,
- PostgresConversationsHandler,
- PostgresDatabaseProvider,
- PostgresDocumentsHandler,
- PostgresGraphsHandler,
- PostgresLimitsHandler,
- )
- from core.database.users import ( # Make sure this import is correct
- PostgresUserHandler,
- )
- from core.providers import NaClCryptoConfig, NaClCryptoProvider
- from core.utils import generate_user_id
# Postgres connection string used by the ``db_provider`` fixture.
# Override it by setting the TEST_DB_CONNECTION_STRING environment variable.
TEST_DB_CONNECTION_STRING = os.environ.get(
    "TEST_DB_CONNECTION_STRING",
    "postgresql://postgres:postgres@localhost:5432/test_db",
)
@pytest.fixture
async def db_provider():
    """Yield an initialized ``PostgresDatabaseProvider`` for one test.

    The provider is configured for project ``"test_project"`` against
    ``TEST_DB_CONNECTION_STRING``, with vector dimension 4 and FP32
    quantization. It is initialized before the test and closed in the
    teardown that pytest runs after the test body.
    """
    # NOTE(review): this builds its own NaClCryptoProvider instead of using
    # the ``crypto_provider`` fixture, so ``users_handler`` ends up with a
    # different crypto instance than the one held by this provider.
    crypto = NaClCryptoProvider(NaClCryptoConfig(app={}))
    db_config = DatabaseConfig(
        app=AppConfig(project_name="test_project"),
        provider="postgres",
        connection_string=TEST_DB_CONNECTION_STRING,
        postgres_configuration_settings={
            "max_connections": 10,
            "statement_cache_size": 100,
        },
        project_name="test_project",
    )
    dimension = 4
    quantization_type = VectorQuantizationType.FP32
    # Named ``provider`` so the local does not shadow this fixture's own name.
    provider = PostgresDatabaseProvider(
        db_config, dimension, crypto, quantization_type
    )
    await provider.initialize()
    yield provider
    # Teardown: executed by pytest once the test using this fixture finishes.
    await provider.close()
@pytest.fixture
def crypto_provider():
    """Return a fresh ``NaClCryptoProvider`` built with an empty app config."""
    return NaClCryptoProvider(NaClCryptoConfig(app={}))
@pytest.fixture
async def chunks_handler(db_provider):
    """Return a ``PostgresChunksHandler`` whose tables have been created.

    Project name, connection manager, vector dimension and quantization
    type are all taken from ``db_provider``.
    """
    handler = PostgresChunksHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
        quantization_type=db_provider.quantization_type,
    )
    await handler.create_tables()
    return handler
@pytest.fixture
async def collections_handler(db_provider):
    """Return a ``PostgresCollectionsHandler`` whose tables have been created.

    The handler shares ``db_provider``'s project name, connection manager
    and config.
    """
    handler = PostgresCollectionsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        config=db_provider.config,
    )
    await handler.create_tables()
    return handler
@pytest.fixture
async def conversations_handler(db_provider):
    """Return a ``PostgresConversationsHandler`` whose tables have been created.

    The handler is constructed positionally from ``db_provider``'s project
    name and connection manager.
    """
    handler = PostgresConversationsHandler(
        db_provider.project_name, db_provider.connection_manager
    )
    await handler.create_tables()
    return handler
@pytest.fixture
async def documents_handler(db_provider):
    """Return a ``PostgresDocumentsHandler`` whose tables have been created.

    Project name, connection manager and vector dimension come from
    ``db_provider``.
    """
    handler = PostgresDocumentsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
    )
    await handler.create_tables()
    return handler
@pytest.fixture
async def graphs_handler(db_provider):
    """Return a ``PostgresGraphsHandler`` whose tables have been created.

    Project name, connection manager, vector dimension and quantization
    type come from ``db_provider``. ``collections_handler`` is passed as
    ``None``.
    """
    # NOTE(review): passing None presumes the graphs handler does not need a
    # collections handler for the code paths these tests exercise — confirm.
    # If one is needed, add the ``collections_handler`` fixture as a
    # parameter of this fixture (pytest injects it; fixtures are not awaited)
    # and pass it through here.
    handler = PostgresGraphsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
        quantization_type=db_provider.quantization_type,
        collections_handler=None,
    )
    await handler.create_tables()
    return handler
@pytest.fixture
async def limits_handler(db_provider):
    """Return a ``PostgresLimitsHandler`` with tables created and an empty request log.

    The handler shares ``db_provider``'s project name, connection manager
    and config.
    """
    connection_manager = db_provider.connection_manager
    handler = PostgresLimitsHandler(
        project_name=db_provider.project_name,
        connection_manager=connection_manager,
        config=db_provider.config,
    )
    await handler.create_tables()
    # Always clear request_log: this fixture is function-scoped (pytest's
    # default), so every test starts without rows left by earlier tests.
    await connection_manager.execute_query(
        f"TRUNCATE {handler._get_table_name('request_log')};"
    )
    return handler
@pytest.fixture
async def users_handler(db_provider, crypto_provider):
    """Return a ``PostgresUserHandler`` with tables created and emptied.

    The handler uses ``db_provider``'s project name and connection manager
    together with the separate ``crypto_provider`` fixture.
    """
    connection_manager = db_provider.connection_manager
    handler = PostgresUserHandler(
        project_name=db_provider.project_name,
        connection_manager=connection_manager,
        crypto_provider=crypto_provider,
    )
    await handler.create_tables()
    # Always empty both tables (CASCADE also clears rows in tables that
    # reference them) so each function-scoped test starts from a clean state.
    await connection_manager.execute_query(
        f"TRUNCATE {handler._get_table_name('users')} CASCADE;"
    )
    await connection_manager.execute_query(
        f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;"
    )
    return handler
- # # tests/conftest.py
- # import pytest
- # import os
- # from core.database.postgres import (
- # PostgresChunksHandler,
- # PostgresConnectionManager,
- # PostgresDatabaseProvider,
- # PostgresCollectionsHandler,
- # PostgresConversationsHandler,
- # PostgresDocumentsHandler,
- # PostgresGraphsHandler,
- # PostgresLimitsHandler,
- # PostgresUserHandler
- # )
- # from core.providers import NaClCryptoConfig, NaClCryptoProvider
- # from core.base import DatabaseConfig, VectorQuantizationType
- # TEST_DB_CONNECTION_STRING = os.environ.get(
- # "TEST_DB_CONNECTION_STRING",
- # "postgresql://postgres:postgres@localhost:5432/test_db",
- # )
- # @pytest.fixture
- # async def db_provider():
- # # Example: a crypto provider needed by the database
- # crypto_provider = NaClCryptoProvider(NaClCryptoConfig(app={}))
- # db_config = DatabaseConfig(
- # app={},
- # provider="postgres",
- # connection_string=TEST_DB_CONNECTION_STRING,
- # # Set these values as appropriate
- # postgres_configuration_settings={
- # "max_connections": 10,
- # "statement_cache_size": 100,
- # },
- # )
- # dimension = 4
- # quantization_type = VectorQuantizationType.FP32
- # db_provider = PostgresDatabaseProvider(
- # db_config, dimension, crypto_provider, quantization_type
- # )
- # await db_provider.initialize()
- # yield db_provider
- # # Teardown logic if needed: close pools, drop tables, etc.
- # await db_provider.close()
- # @pytest.fixture
- # async def chunks_handler(db_provider):
- # # Assuming project_name and dimension are retrieved from db_provider
- # dimension = db_provider.dimension
- # quantization_type = db_provider.quantization_type
- # project_name = db_provider.project_name
- # connection_manager = (
- # db_provider.connection_manager
- # ) # type: PostgresConnectionManager
- # handler = PostgresChunksHandler(
- # project_name=project_name,
- # connection_manager=connection_manager,
- # dimension=dimension,
- # quantization_type=quantization_type,
- # )
- # await handler.create_tables()
- # return handler
- # @pytest.fixture
- # async def collections_handler(db_provider):
- # project_name = db_provider.project_name
- # connection_manager = db_provider.connection_manager
- # config = db_provider.config
- # handler = PostgresCollectionsHandler(
- # project_name=project_name,
- # connection_manager=connection_manager,
- # config=config
- # )
- # await handler.create_tables()
- # return handler
- # @pytest.fixture
- # async def conversations_handler(db_provider):
- # project_name = db_provider.project_name
- # connection_manager = db_provider.connection_manager
- # handler = PostgresConversationsHandler(project_name, connection_manager)
- # await handler.create_tables()
- # return handler
- # @pytest.fixture
- # async def documents_handler(db_provider):
- # dimension = db_provider.dimension
- # project_name = db_provider.project_name
- # connection_manager = db_provider.connection_manager
- # handler = PostgresDocumentsHandler(
- # project_name=project_name,
- # connection_manager=connection_manager,
- # dimension=dimension,
- # )
- # await handler.create_tables()
- # return handler
- # @pytest.fixture
- # async def graphs_handler(db_provider):
- # project_name = db_provider.project_name
- # connection_manager = db_provider.connection_manager
- # dimension = db_provider.dimension
- # quantization_type = db_provider.quantization_type
- # # Constructing graphs handler with required args
- # handler = PostgresGraphsHandler(
- # project_name=project_name,
- # connection_manager=connection_manager,
- # dimension=dimension,
- # quantization_type=quantization_type,
- # collections_handler=None # If needed, you can mock or create a collections_handler
- # )
- # await handler.create_tables()
- # return handler
- # @pytest.fixture
- # async def limits_handler(db_provider):
- # project_name = db_provider.project_name
- # connection_manager = db_provider.connection_manager
- # config = db_provider.config # This has default limits
- # handler = PostgresLimitsHandler(
- # project_name=project_name,
- # connection_manager=connection_manager,
- # config=config,
- # )
- # await handler.create_tables()
- # # Optionally truncate after creation to ensure clean state
- # await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('request_log')};")
- # return handler
- # @pytest.fixture
- # async def users_handler(db_provider, crypto_provider):
- # project_name = db_provider.project_name
- # connection_manager = db_provider.connection_manager
- # handler = PostgresUserHandler(
- # project_name=project_name,
- # connection_manager=connection_manager,
- # crypto_provider=crypto_provider,
- # )
- # await handler.create_tables()
- # # Optionally clean up users table before each test
- # await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('users')} CASCADE;")
- # await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;")
- # return handler
|