test_graphs.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. import uuid
  2. from enum import Enum
  3. import pytest
  4. from core.base.api.models import GraphResponse
  5. class StoreType(str, Enum):
  6. GRAPHS = "graphs"
  7. DOCUMENTS = "documents"
  8. @pytest.mark.asyncio
  9. async def test_create_graph(graphs_handler):
  10. coll_id = uuid.uuid4()
  11. resp = await graphs_handler.create(collection_id=coll_id,
  12. name="My Graph",
  13. description="Test Graph")
  14. assert isinstance(resp, GraphResponse)
  15. assert resp.name == "My Graph"
  16. assert resp.collection_id == coll_id
  17. @pytest.mark.asyncio
  18. async def test_add_entities_and_relationships(graphs_handler):
  19. # Create a graph
  20. coll_id = uuid.uuid4()
  21. graph_resp = await graphs_handler.create(collection_id=coll_id,
  22. name="TestGraph")
  23. graph_id = graph_resp.id
  24. # Add an entity
  25. entity = await graphs_handler.entities.create(
  26. parent_id=graph_id,
  27. store_type=StoreType.GRAPHS,
  28. name="TestEntity",
  29. category="Person",
  30. description="A test entity",
  31. )
  32. assert entity.name == "TestEntity"
  33. # Add another entity
  34. entity2 = await graphs_handler.entities.create(
  35. parent_id=graph_id,
  36. store_type=StoreType.GRAPHS,
  37. name="AnotherEntity",
  38. category="Place",
  39. description="A test place",
  40. )
  41. # Add a relationship between them
  42. rel = await graphs_handler.relationships.create(
  43. subject="TestEntity",
  44. subject_id=entity.id,
  45. predicate="lives_in",
  46. object="AnotherEntity",
  47. object_id=entity2.id,
  48. parent_id=graph_id,
  49. store_type=StoreType.GRAPHS,
  50. description="Entity lives in AnotherEntity",
  51. )
  52. assert rel.predicate == "lives_in"
  53. # Verify entities retrieval
  54. ents, total_ents = await graphs_handler.get_entities(parent_id=graph_id,
  55. offset=0,
  56. limit=10)
  57. assert total_ents == 2
  58. names = [e.name for e in ents]
  59. assert "TestEntity" in names and "AnotherEntity" in names
  60. # Verify relationships retrieval
  61. rels, total_rels = await graphs_handler.get_relationships(
  62. parent_id=graph_id, offset=0, limit=10)
  63. assert total_rels == 1
  64. assert rels[0].predicate == "lives_in"
  65. @pytest.mark.asyncio
  66. async def test_delete_entities_and_relationships(graphs_handler):
  67. # Create another graph
  68. coll_id = uuid.uuid4()
  69. graph_resp = await graphs_handler.create(collection_id=coll_id,
  70. name="DeletableGraph")
  71. graph_id = graph_resp.id
  72. # Add entities
  73. e1 = await graphs_handler.entities.create(
  74. parent_id=graph_id,
  75. store_type=StoreType.GRAPHS,
  76. name="DeleteMe",
  77. )
  78. e2 = await graphs_handler.entities.create(
  79. parent_id=graph_id,
  80. store_type=StoreType.GRAPHS,
  81. name="DeleteMeToo",
  82. )
  83. # Add relationship
  84. rel = await graphs_handler.relationships.create(
  85. subject="DeleteMe",
  86. subject_id=e1.id,
  87. predicate="related_to",
  88. object="DeleteMeToo",
  89. object_id=e2.id,
  90. parent_id=graph_id,
  91. store_type=StoreType.GRAPHS,
  92. )
  93. # Delete one entity
  94. await graphs_handler.entities.delete(
  95. parent_id=graph_id,
  96. entity_ids=[e1.id],
  97. store_type=StoreType.GRAPHS,
  98. )
  99. ents, count = await graphs_handler.get_entities(parent_id=graph_id,
  100. offset=0,
  101. limit=10)
  102. assert count == 1
  103. assert ents[0].id == e2.id
  104. # Delete the relationship
  105. await graphs_handler.relationships.delete(
  106. parent_id=graph_id,
  107. relationship_ids=[rel.id],
  108. store_type=StoreType.GRAPHS,
  109. )
  110. rels, rel_count = await graphs_handler.get_relationships(
  111. parent_id=graph_id, offset=0, limit=10)
  112. assert rel_count == 0
  113. @pytest.mark.asyncio
  114. async def test_communities(graphs_handler):
  115. # Insert a community for a collection_id (not strictly related to a graph_id)
  116. coll_id = uuid.uuid4()
  117. await graphs_handler.communities.create(
  118. parent_id=coll_id,
  119. store_type=StoreType.GRAPHS,
  120. name="CommunityOne",
  121. summary="Test community",
  122. findings=["finding1", "finding2"],
  123. rating=4.5,
  124. rating_explanation="Excellent",
  125. description_embedding=[0.1, 0.2, 0.3, 0.4],
  126. )
  127. comms, count = await graphs_handler.communities.get(
  128. parent_id=coll_id,
  129. store_type=StoreType.GRAPHS,
  130. offset=0,
  131. limit=10,
  132. )
  133. assert count == 1
  134. assert comms[0].name == "CommunityOne"
  135. # TODO - Fix code such that these tests pass
  136. # # @pytest.mark.asyncio
  137. # # async def test_delete_graph(graphs_handler):
  138. # # # Create a graph and then delete it
  139. # # coll_id = uuid.uuid4()
  140. # # graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
  141. # # graph_id = graph_resp.id
  142. # # # reset or delete calls are complicated in the code. We'll just call `reset` and `delete`
  143. # # await graphs_handler.reset(graph_id)
  144. # # # This should remove all entities & relationships from the graph_id
  145. # # # Now delete the graph itself
  146. # # # The `delete` method seems to be tied to collection_id rather than graph_id
  147. # # await graphs_handler.delete(collection_id=graph_id, cascade=False)
  148. # # # If the code is structured so that delete requires a collection_id,
  149. # # # ensure `graph_id == collection_id` or adapt the code accordingly.
  150. # # # Try fetching the graph
  151. # # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
  152. # # assert overview["total_entries"] == 0, "Graph should be deleted"
  153. # @pytest.mark.asyncio
  154. # async def test_delete_graph(graphs_handler):
  155. # # Create a graph and then delete it
  156. # coll_id = uuid.uuid4()
  157. # graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
  158. # graph_id = graph_resp.id
  159. # # Reset the graph (remove entities, relationships, communities)
  160. # await graphs_handler.reset(graph_id)
  161. # # Now delete the graph using collection_id (which equals graph_id in this code)
  162. # await graphs_handler.delete(collection_id=coll_id)
  163. # # Verify the graph is deleted
  164. # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[coll_id])
  165. # assert overview["total_entries"] == 0, "Graph should be deleted"
  166. @pytest.mark.asyncio
  167. async def test_create_graph_defaults(graphs_handler):
  168. # Create a graph without specifying name or description
  169. coll_id = uuid.uuid4()
  170. resp = await graphs_handler.create(collection_id=coll_id)
  171. assert resp.collection_id == coll_id
  172. # The code sets a default name, which should be "Graph {coll_id}"
  173. assert resp.name == f"Graph {coll_id}"
  174. # Default description should be empty string as per code
  175. assert resp.description == ""
  176. # @pytest.mark.asyncio
  177. # async def test_list_multiple_graphs(graphs_handler):
  178. # # Create multiple graphs
  179. # coll_id1 = uuid.uuid4()
  180. # coll_id2 = uuid.uuid4()
  181. # graph_resp1 = await graphs_handler.create(collection_id=coll_id1, name="Graph1")
  182. # graph_resp2 = await graphs_handler.create(collection_id=coll_id2, name="Graph2")
  183. # graph_resp3 = await graphs_handler.create(collection_id=coll_id2, name="Graph3")
  184. # # List all graphs without filters
  185. # overview = await graphs_handler.list_graphs(offset=0, limit=10)
  186. # # Ensure at least these three are in there
  187. # found_ids = [g.id for g in overview["results"]]
  188. # assert graph_resp1.id in found_ids
  189. # assert graph_resp2.id in found_ids
  190. # assert graph_resp3.id in found_ids
  191. # # Filter by collection_id = coll_id2 should return Graph2 and Graph3 (the most recent one first if same collection)
  192. # overview_coll2 = await graphs_handler.list_graphs(offset=0, limit=10, filter_collection_id=coll_id2)
  193. # returned_ids = [g.id for g in overview_coll2["results"]]
  194. # # According to the code, we only see the "most recent" graph per collection. Verify this logic.
  195. # # If your code is returning only the most recent graph per collection, we should see only one graph per collection_id here.
  196. # # Adjust test according to actual logic you desire.
  197. # # For this example, let's assume we should only get the latest graph per collection. Graph3 should be newer than Graph2.
  198. # assert len(returned_ids) == 1
  199. # assert graph_resp3.id in returned_ids
  200. @pytest.mark.asyncio
  201. async def test_update_graph(graphs_handler):
  202. coll_id = uuid.uuid4()
  203. graph_resp = await graphs_handler.create(collection_id=coll_id,
  204. name="OldName",
  205. description="OldDescription")
  206. graph_id = graph_resp.id
  207. # Update name and description
  208. updated_resp = await graphs_handler.update(collection_id=graph_id,
  209. name="NewName",
  210. description="NewDescription")
  211. assert updated_resp.name == "NewName"
  212. assert updated_resp.description == "NewDescription"
  213. # Retrieve and verify
  214. overview = await graphs_handler.list_graphs(offset=0,
  215. limit=10,
  216. filter_graph_ids=[graph_id])
  217. assert overview["total_entries"] == 1
  218. fetched_graph = overview["results"][0]
  219. assert fetched_graph.name == "NewName"
  220. assert fetched_graph.description == "NewDescription"
  221. @pytest.mark.asyncio
  222. async def test_bulk_entities(graphs_handler):
  223. coll_id = uuid.uuid4()
  224. graph_resp = await graphs_handler.create(collection_id=coll_id,
  225. name="BulkEntities")
  226. graph_id = graph_resp.id
  227. # Add multiple entities
  228. entities_to_add = [
  229. {
  230. "name": "EntityA",
  231. "category": "CategoryA",
  232. "description": "DescA"
  233. },
  234. {
  235. "name": "EntityB",
  236. "category": "CategoryB",
  237. "description": "DescB"
  238. },
  239. {
  240. "name": "EntityC",
  241. "category": "CategoryC",
  242. "description": "DescC"
  243. },
  244. ]
  245. for ent in entities_to_add:
  246. await graphs_handler.entities.create(
  247. parent_id=graph_id,
  248. store_type=StoreType.GRAPHS,
  249. name=ent["name"],
  250. category=ent["category"],
  251. description=ent["description"],
  252. )
  253. ents, total = await graphs_handler.get_entities(parent_id=graph_id,
  254. offset=0,
  255. limit=10)
  256. assert total == 3
  257. fetched_names = [e.name for e in ents]
  258. for ent in entities_to_add:
  259. assert ent["name"] in fetched_names
  260. @pytest.mark.asyncio
  261. async def test_relationship_filtering(graphs_handler):
  262. coll_id = uuid.uuid4()
  263. graph_resp = await graphs_handler.create(collection_id=coll_id,
  264. name="RelFilteringGraph")
  265. graph_id = graph_resp.id
  266. # Add entities
  267. e1 = await graphs_handler.entities.create(parent_id=graph_id,
  268. store_type=StoreType.GRAPHS,
  269. name="Node1")
  270. e2 = await graphs_handler.entities.create(parent_id=graph_id,
  271. store_type=StoreType.GRAPHS,
  272. name="Node2")
  273. e3 = await graphs_handler.entities.create(parent_id=graph_id,
  274. store_type=StoreType.GRAPHS,
  275. name="Node3")
  276. # Add different relationships
  277. await graphs_handler.relationships.create(
  278. subject="Node1",
  279. subject_id=e1.id,
  280. predicate="connected_to",
  281. object="Node2",
  282. object_id=e2.id,
  283. parent_id=graph_id,
  284. store_type=StoreType.GRAPHS,
  285. )
  286. await graphs_handler.relationships.create(
  287. subject="Node2",
  288. subject_id=e2.id,
  289. predicate="linked_with",
  290. object="Node3",
  291. object_id=e3.id,
  292. parent_id=graph_id,
  293. store_type=StoreType.GRAPHS,
  294. )
  295. # Get all relationships
  296. all_rels, all_count = await graphs_handler.get_relationships(
  297. parent_id=graph_id, offset=0, limit=10)
  298. assert all_count == 2
  299. # Filter by relationship_type = ["connected_to"]
  300. filtered_rels, filt_count = await graphs_handler.get_relationships(
  301. parent_id=graph_id,
  302. offset=0,
  303. limit=10,
  304. relationship_types=["connected_to"],
  305. )
  306. assert filt_count == 1
  307. assert filtered_rels[0].predicate == "connected_to"
  308. @pytest.mark.asyncio
  309. async def test_delete_all_entities(graphs_handler):
  310. coll_id = uuid.uuid4()
  311. graph_resp = await graphs_handler.create(collection_id=coll_id,
  312. name="DeleteAllEntities")
  313. graph_id = graph_resp.id
  314. # Add some entities
  315. await graphs_handler.entities.create(parent_id=graph_id,
  316. store_type=StoreType.GRAPHS,
  317. name="E1")
  318. await graphs_handler.entities.create(parent_id=graph_id,
  319. store_type=StoreType.GRAPHS,
  320. name="E2")
  321. # Delete all entities without specifying IDs
  322. await graphs_handler.entities.delete(parent_id=graph_id,
  323. store_type=StoreType.GRAPHS)
  324. ents, count = await graphs_handler.get_entities(parent_id=graph_id,
  325. offset=0,
  326. limit=10)
  327. assert count == 0
  328. @pytest.mark.asyncio
  329. async def test_delete_all_relationships(graphs_handler):
  330. coll_id = uuid.uuid4()
  331. graph_resp = await graphs_handler.create(collection_id=coll_id,
  332. name="DeleteAllRels")
  333. graph_id = graph_resp.id
  334. # Add two entities and a relationship
  335. e1 = await graphs_handler.entities.create(parent_id=graph_id,
  336. store_type=StoreType.GRAPHS,
  337. name="E1")
  338. e2 = await graphs_handler.entities.create(parent_id=graph_id,
  339. store_type=StoreType.GRAPHS,
  340. name="E2")
  341. await graphs_handler.relationships.create(
  342. subject="E1",
  343. subject_id=e1.id,
  344. predicate="connected",
  345. object="E2",
  346. object_id=e2.id,
  347. parent_id=graph_id,
  348. store_type=StoreType.GRAPHS,
  349. )
  350. # Delete all relationships
  351. await graphs_handler.relationships.delete(parent_id=graph_id,
  352. store_type=StoreType.GRAPHS)
  353. rels, rel_count = await graphs_handler.get_relationships(
  354. parent_id=graph_id, offset=0, limit=10)
  355. assert rel_count == 0
  356. @pytest.mark.asyncio
  357. async def test_error_handling_invalid_graph_id(graphs_handler):
  358. # Attempt to get a non-existent graph
  359. non_existent_id = uuid.uuid4()
  360. overview = await graphs_handler.list_graphs(
  361. offset=0, limit=10, filter_graph_ids=[non_existent_id])
  362. assert overview["total_entries"] == 0
  363. # Attempt to delete a non-existent graph
  364. with pytest.raises(Exception) as exc_info:
  365. await graphs_handler.delete(collection_id=non_existent_id)
  366. # Expect an R2RException or HTTPException (depending on your code)
  367. # Check the message or type if needed
  368. @pytest.mark.asyncio
  369. async def test_filter_by_collection_ids_in_entities(graphs_handler):
  370. # 1) Create a row in "graphs" so it can be referenced by entities
  371. some_parent_id = uuid.uuid4()
  372. some_collection_id = uuid.uuid4()
  373. insert_graph_sql = f"""
  374. INSERT INTO "{graphs_handler.project_name}"."graphs"
  375. (id, collection_id, name, description, status)
  376. VALUES ($1, $2, $3, $4, $5)
  377. """
  378. await graphs_handler.connection_manager.execute_query(
  379. insert_graph_sql,
  380. [
  381. some_parent_id,
  382. some_collection_id,
  383. "MyTestGraph",
  384. "Graph for unit test",
  385. "pending",
  386. ],
  387. )
  388. # 2) Insert a row in "graphs_entities" that references parent_id = some_parent_id
  389. row_id = uuid.uuid4()
  390. insert_entity_sql = f"""
  391. INSERT INTO "{graphs_handler.project_name}"."graphs_entities"
  392. (id, name, parent_id, metadata)
  393. VALUES ($1, $2, $3, $4)
  394. """
  395. await graphs_handler.connection_manager.execute_query(
  396. insert_entity_sql, [row_id, "TestEntity", some_parent_id, None])
  397. # 3) Now run your actual test search
  398. filter_dict = {"collection_ids": {"$in": [str(some_parent_id)]}}
  399. results = []
  400. async for row in graphs_handler.graph_search(
  401. query="anything",
  402. search_type="entities",
  403. filters=filter_dict,
  404. limit=10,
  405. use_fulltext_search=False,
  406. use_hybrid_search=False,
  407. query_embedding=[0, 0, 0, 0],
  408. ):
  409. results.append(row)
  410. assert len(results) == 1, f"Expected 1 matching entity, got {len(results)}"
  411. assert results[0]["name"] == "TestEntity"
  412. # 4) Cleanup if needed
  413. delete_entity_sql = f"""
  414. DELETE FROM "{graphs_handler.project_name}"."graphs_entities" WHERE id = $1
  415. """
  416. await graphs_handler.connection_manager.execute_query(
  417. delete_entity_sql, [row_id])
  418. delete_graph_sql = f"""
  419. DELETE FROM "{graphs_handler.project_name}"."graphs" WHERE id = $1
  420. """
  421. await graphs_handler.connection_manager.execute_query(
  422. delete_graph_sql, [some_parent_id])
  423. # # TODO - Fix code to pass this test.
  424. # # @pytest.mark.asyncio
  425. # # async def test_delete_graph_cascade(graphs_handler):
  426. # # coll_id = uuid.uuid4()
  427. # # graph_resp = await graphs_handler.create(collection_id=coll_id, name="CascadeGraph")
  428. # # graph_id = graph_resp.id
  429. # # # Add entities/relationships here if you have documents attached
  430. # # # This test would verify that cascade=True behavior is correct
  431. # # # For now, just call delete with cascade=True
  432. # # # Depending on your implementation, you might need documents associated with the collection to test fully.
  433. # # await graphs_handler.delete(collection_id=coll_id)
  434. # # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
  435. # # assert overview["total_entries"] == 0
  436. # # tests/test_graph_filters.py
  437. # import pytest
  438. # import uuid
  439. # from core.providers.database.postgres import PostgresGraphsHandler
  440. # @pytest.mark.asyncio
  441. # async def test_filter_by_collection_ids_in_entities(graphs_handler: PostgresGraphsHandler):
  442. # # Suppose we want to test an entity row whose parent_id=some_uuid
  443. # some_parent_id = uuid.uuid4()
  444. # row_id = uuid.uuid4()
  445. # # Insert an entity row manually for the test
  446. # insert_sql = f"""
  447. # INSERT INTO "{graphs_handler.project_name}"."graphs_entities"
  448. # (id, name, parent_id, metadata)
  449. # VALUES ($1, $2, $3, $4)
  450. # """
  451. # await graphs_handler.connection_manager.execute_query(
  452. # insert_sql,
  453. # [row_id, "TestEntity", some_parent_id, None]
  454. # )
  455. # # Now do a search with "collection_ids": { "$in": [some_parent_id] }
  456. # filter_dict = {
  457. # "collection_ids": { "$in": [str(some_parent_id)] }
  458. # }
  459. # # graph_search with search_type='entities' triggers the logic
  460. # results = []
  461. # async for row in graphs_handler.graph_search(
  462. # query="anything",
  463. # search_type="entities",
  464. # filters=filter_dict,
  465. # limit=10,
  466. # use_fulltext_search=False,
  467. # use_hybrid_search=False,
  468. # query_embedding=[0.0,0.0,0.0,0.0], # placeholder
  469. # ):
  470. # results.append(row)
  471. # assert len(results) == 1, f"Expected 1 matching entity, got {len(results)}"
  472. # assert results[0]["name"] == "TestEntity"
  473. # # cleanup
  474. # delete_sql = f"""
  475. # DELETE FROM "{graphs_handler.project_name}"."graphs_entities" WHERE id = $1
  476. # """
  477. # await graphs_handler.connection_manager.execute_query(delete_sql, [row_id])