conftest.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # tests/conftest.py
  2. import os
  3. from uuid import uuid4
  4. import pytest
  5. from core.base import AppConfig, DatabaseConfig, VectorQuantizationType
  6. from core.database.postgres import (
  7. PostgresChunksHandler,
  8. PostgresCollectionsHandler,
  9. PostgresConnectionManager,
  10. PostgresConversationsHandler,
  11. PostgresDatabaseProvider,
  12. PostgresDocumentsHandler,
  13. PostgresGraphsHandler,
  14. PostgresLimitsHandler,
  15. )
  16. from core.database.users import ( # Make sure this import is correct
  17. PostgresUserHandler,
  18. )
  19. from core.providers import NaClCryptoConfig, NaClCryptoProvider
  20. from core.utils import generate_user_id
  # Connection string for the test database. The TEST_DB_CONNECTION_STRING
  # environment variable takes precedence; otherwise fall back to a local
  # Postgres instance.
  TEST_DB_CONNECTION_STRING = os.getenv(
      "TEST_DB_CONNECTION_STRING",
      "postgresql://postgres:postgres@localhost:5432/test_db",
  )
  @pytest.fixture
  async def db_provider():
      """Yield an initialized ``PostgresDatabaseProvider`` for the tests.

      The provider is configured for the "test_project" project against
      ``TEST_DB_CONNECTION_STRING``, with dimension 4 and FP32 quantization.
      After the ``yield`` the provider is closed again.
      """
      crypto = NaClCryptoProvider(NaClCryptoConfig(app={}))
      settings = DatabaseConfig(
          app=AppConfig(project_name="test_project"),
          provider="postgres",
          connection_string=TEST_DB_CONNECTION_STRING,
          postgres_configuration_settings={
              "max_connections": 10,
              "statement_cache_size": 100,
          },
          project_name="test_project",
      )
      provider = PostgresDatabaseProvider(
          settings,
          4,  # dimension
          crypto,
          VectorQuantizationType.FP32,
      )
      await provider.initialize()
      yield provider
      # Teardown: release the provider once the requesting test is done.
      await provider.close()
  @pytest.fixture
  def crypto_provider():
      """Return a NaCl crypto provider for tests that need one on its own."""
      nacl_config = NaClCryptoConfig(app={})
      return NaClCryptoProvider(nacl_config)
  @pytest.fixture
  async def chunks_handler(db_provider):
      """Return a ``PostgresChunksHandler`` built from ``db_provider``.

      Project name, connection manager, dimension and quantization type are
      taken from the provider; ``create_tables()`` is awaited before returning.
      """
      chunks = PostgresChunksHandler(
          project_name=db_provider.project_name,
          connection_manager=db_provider.connection_manager,
          dimension=db_provider.dimension,
          quantization_type=db_provider.quantization_type,
      )
      await chunks.create_tables()
      return chunks
  @pytest.fixture
  async def collections_handler(db_provider):
      """Return a ``PostgresCollectionsHandler`` built from ``db_provider``.

      Project name, connection manager and config are taken from the provider;
      ``create_tables()`` is awaited before returning.
      """
      collections = PostgresCollectionsHandler(
          project_name=db_provider.project_name,
          connection_manager=db_provider.connection_manager,
          config=db_provider.config,
      )
      await collections.create_tables()
      return collections
  @pytest.fixture
  async def conversations_handler(db_provider):
      """Return a ``PostgresConversationsHandler`` after awaiting ``create_tables()``."""
      conversations = PostgresConversationsHandler(
          db_provider.project_name,
          db_provider.connection_manager,
      )
      await conversations.create_tables()
      return conversations
  @pytest.fixture
  async def documents_handler(db_provider):
      """Return a ``PostgresDocumentsHandler`` built from ``db_provider``.

      Project name, connection manager and dimension are taken from the
      provider; ``create_tables()`` is awaited before returning.
      """
      documents = PostgresDocumentsHandler(
          project_name=db_provider.project_name,
          connection_manager=db_provider.connection_manager,
          dimension=db_provider.dimension,
      )
      await documents.create_tables()
      return documents
  @pytest.fixture
  async def graphs_handler(db_provider):
      """Return a ``PostgresGraphsHandler`` built from ``db_provider``.

      Project name, connection manager, dimension and quantization type are
      taken from the provider; ``create_tables()`` is awaited before returning.
      """
      graphs = PostgresGraphsHandler(
          project_name=db_provider.project_name,
          connection_manager=db_provider.connection_manager,
          dimension=db_provider.dimension,
          quantization_type=db_provider.quantization_type,
          # No collections handler is wired in here. If one turns out to be
          # required, request the collections_handler fixture and pass it.
          collections_handler=None,
      )
      await graphs.create_tables()
      return graphs
  @pytest.fixture
  async def limits_handler(db_provider):
      """Return a ``PostgresLimitsHandler`` with an emptied request log.

      The handler is built from the provider's project name, connection
      manager and config. After ``create_tables()`` the ``request_log`` table
      is truncated before the handler is returned.
      """
      conn = db_provider.connection_manager
      limits = PostgresLimitsHandler(
          project_name=db_provider.project_name,
          connection_manager=conn,
          config=db_provider.config,
      )
      await limits.create_tables()
      # Clear previously logged requests.
      request_log_table = limits._get_table_name("request_log")
      await conn.execute_query(f"TRUNCATE {request_log_table};")
      return limits
  @pytest.fixture
  async def users_handler(db_provider, crypto_provider):
      """Return a ``PostgresUserHandler`` with emptied user tables.

      The handler is built from the provider's project name and connection
      manager plus the ``crypto_provider`` fixture. After ``create_tables()``
      the ``users`` and ``users_api_keys`` tables are truncated (CASCADE), in
      that order, before the handler is returned.
      """
      conn = db_provider.connection_manager
      users = PostgresUserHandler(
          project_name=db_provider.project_name,
          connection_manager=conn,
          crypto_provider=crypto_provider,
      )
      await users.create_tables()
      # Clean up the user-related tables before each test.
      for table in ("users", "users_api_keys"):
          await conn.execute_query(
              f"TRUNCATE {users._get_table_name(table)} CASCADE;"
          )
      return users
  147. # # tests/conftest.py
  148. # import pytest
  149. # import os
  150. # from core.database.postgres import (
  151. # PostgresChunksHandler,
  152. # PostgresConnectionManager,
  153. # PostgresDatabaseProvider,
  154. # PostgresCollectionsHandler,
  155. # PostgresConversationsHandler,
  156. # PostgresDocumentsHandler,
  157. # PostgresGraphsHandler,
  158. # PostgresLimitsHandler,
  159. # PostgresUserHandler
  160. # )
  161. # from core.providers import NaClCryptoConfig, NaClCryptoProvider
  162. # from core.base import DatabaseConfig, VectorQuantizationType
  163. # TEST_DB_CONNECTION_STRING = os.environ.get(
  164. # "TEST_DB_CONNECTION_STRING",
  165. # "postgresql://postgres:postgres@localhost:5432/test_db",
  166. # )
  167. # @pytest.fixture
  168. # async def db_provider():
  169. # # Example: a crypto provider needed by the database
  170. # crypto_provider = NaClCryptoProvider(NaClCryptoConfig(app={}))
  171. # db_config = DatabaseConfig(
  172. # app={},
  173. # provider="postgres",
  174. # connection_string=TEST_DB_CONNECTION_STRING,
  175. # # Set these values as appropriate
  176. # postgres_configuration_settings={
  177. # "max_connections": 10,
  178. # "statement_cache_size": 100,
  179. # },
  180. # )
  181. # dimension = 4
  182. # quantization_type = VectorQuantizationType.FP32
  183. # db_provider = PostgresDatabaseProvider(
  184. # db_config, dimension, crypto_provider, quantization_type
  185. # )
  186. # await db_provider.initialize()
  187. # yield db_provider
  188. # # Teardown logic if needed: close pools, drop tables, etc.
  189. # await db_provider.close()
  190. # @pytest.fixture
  191. # async def chunks_handler(db_provider):
  192. # # Assuming project_name and dimension are retrieved from db_provider
  193. # dimension = db_provider.dimension
  194. # quantization_type = db_provider.quantization_type
  195. # project_name = db_provider.project_name
  196. # connection_manager = (
  197. # db_provider.connection_manager
  198. # ) # type: PostgresConnectionManager
  199. # handler = PostgresChunksHandler(
  200. # project_name=project_name,
  201. # connection_manager=connection_manager,
  202. # dimension=dimension,
  203. # quantization_type=quantization_type,
  204. # )
  205. # await handler.create_tables()
  206. # return handler
  207. # @pytest.fixture
  208. # async def collections_handler(db_provider):
  209. # project_name = db_provider.project_name
  210. # connection_manager = db_provider.connection_manager
  211. # config = db_provider.config
  212. # handler = PostgresCollectionsHandler(
  213. # project_name=project_name,
  214. # connection_manager=connection_manager,
  215. # config=config
  216. # )
  217. # await handler.create_tables()
  218. # return handler
  219. # @pytest.fixture
  220. # async def conversations_handler(db_provider):
  221. # project_name = db_provider.project_name
  222. # connection_manager = db_provider.connection_manager
  223. # handler = PostgresConversationsHandler(project_name, connection_manager)
  224. # await handler.create_tables()
  225. # return handler
  226. # @pytest.fixture
  227. # async def documents_handler(db_provider):
  228. # dimension = db_provider.dimension
  229. # project_name = db_provider.project_name
  230. # connection_manager = db_provider.connection_manager
  231. # handler = PostgresDocumentsHandler(
  232. # project_name=project_name,
  233. # connection_manager=connection_manager,
  234. # dimension=dimension,
  235. # )
  236. # await handler.create_tables()
  237. # return handler
  238. # @pytest.fixture
  239. # async def graphs_handler(db_provider):
  240. # project_name = db_provider.project_name
  241. # connection_manager = db_provider.connection_manager
  242. # dimension = db_provider.dimension
  243. # quantization_type = db_provider.quantization_type
  244. # # Constructing graphs handler with required args
  245. # handler = PostgresGraphsHandler(
  246. # project_name=project_name,
  247. # connection_manager=connection_manager,
  248. # dimension=dimension,
  249. # quantization_type=quantization_type,
  250. # collections_handler=None # If needed, you can mock or create a collections_handler
  251. # )
  252. # await handler.create_tables()
  253. # return handler
  254. # @pytest.fixture
  255. # async def limits_handler(db_provider):
  256. # project_name = db_provider.project_name
  257. # connection_manager = db_provider.connection_manager
  258. # config = db_provider.config # This has default limits
  259. # handler = PostgresLimitsHandler(
  260. # project_name=project_name,
  261. # connection_manager=connection_manager,
  262. # config=config,
  263. # )
  264. # await handler.create_tables()
  265. # # Optionally truncate after creation to ensure clean state
  266. # await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('request_log')};")
  267. # return handler
  268. # @pytest.fixture
  269. # async def users_handler(db_provider, crypto_provider):
  270. # project_name = db_provider.project_name
  271. # connection_manager = db_provider.connection_manager
  272. # handler = PostgresUserHandler(
  273. # project_name=project_name,
  274. # connection_manager=connection_manager,
  275. # crypto_provider=crypto_provider,
  276. # )
  277. # await handler.create_tables()
  278. # # Optionally clean up users table before each test
  279. # await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('users')} CASCADE;")
  280. # await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;")
  281. # return handler