test_graphs.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. import uuid
  2. from enum import Enum
  3. import pytest
  4. from core.base.api.models import GraphResponse
  5. class StoreType(str, Enum):
  6. GRAPHS = "graphs"
  7. DOCUMENTS = "documents"
  8. @pytest.mark.asyncio
  9. async def test_create_graph(graphs_handler):
  10. coll_id = uuid.uuid4()
  11. resp = await graphs_handler.create(
  12. collection_id=coll_id, name="My Graph", description="Test Graph"
  13. )
  14. assert isinstance(resp, GraphResponse)
  15. assert resp.name == "My Graph"
  16. assert resp.collection_id == coll_id
  17. @pytest.mark.asyncio
  18. async def test_add_entities_and_relationships(graphs_handler):
  19. # Create a graph
  20. coll_id = uuid.uuid4()
  21. graph_resp = await graphs_handler.create(
  22. collection_id=coll_id, name="TestGraph"
  23. )
  24. graph_id = graph_resp.id
  25. # Add an entity
  26. entity = await graphs_handler.entities.create(
  27. parent_id=graph_id,
  28. store_type=StoreType.GRAPHS,
  29. name="TestEntity",
  30. category="Person",
  31. description="A test entity",
  32. )
  33. assert entity.name == "TestEntity"
  34. # Add another entity
  35. entity2 = await graphs_handler.entities.create(
  36. parent_id=graph_id,
  37. store_type=StoreType.GRAPHS,
  38. name="AnotherEntity",
  39. category="Place",
  40. description="A test place",
  41. )
  42. # Add a relationship between them
  43. rel = await graphs_handler.relationships.create(
  44. subject="TestEntity",
  45. subject_id=entity.id,
  46. predicate="lives_in",
  47. object="AnotherEntity",
  48. object_id=entity2.id,
  49. parent_id=graph_id,
  50. store_type=StoreType.GRAPHS,
  51. description="Entity lives in AnotherEntity",
  52. )
  53. assert rel.predicate == "lives_in"
  54. # Verify entities retrieval
  55. ents, total_ents = await graphs_handler.get_entities(
  56. parent_id=graph_id, offset=0, limit=10
  57. )
  58. assert total_ents == 2
  59. names = [e.name for e in ents]
  60. assert "TestEntity" in names and "AnotherEntity" in names
  61. # Verify relationships retrieval
  62. rels, total_rels = await graphs_handler.get_relationships(
  63. parent_id=graph_id, offset=0, limit=10
  64. )
  65. assert total_rels == 1
  66. assert rels[0].predicate == "lives_in"
  67. @pytest.mark.asyncio
  68. async def test_delete_entities_and_relationships(graphs_handler):
  69. # Create another graph
  70. coll_id = uuid.uuid4()
  71. graph_resp = await graphs_handler.create(
  72. collection_id=coll_id, name="DeletableGraph"
  73. )
  74. graph_id = graph_resp.id
  75. # Add entities
  76. e1 = await graphs_handler.entities.create(
  77. parent_id=graph_id,
  78. store_type=StoreType.GRAPHS,
  79. name="DeleteMe",
  80. )
  81. e2 = await graphs_handler.entities.create(
  82. parent_id=graph_id,
  83. store_type=StoreType.GRAPHS,
  84. name="DeleteMeToo",
  85. )
  86. # Add relationship
  87. rel = await graphs_handler.relationships.create(
  88. subject="DeleteMe",
  89. subject_id=e1.id,
  90. predicate="related_to",
  91. object="DeleteMeToo",
  92. object_id=e2.id,
  93. parent_id=graph_id,
  94. store_type=StoreType.GRAPHS,
  95. )
  96. # Delete one entity
  97. await graphs_handler.entities.delete(
  98. parent_id=graph_id,
  99. entity_ids=[e1.id],
  100. store_type=StoreType.GRAPHS,
  101. )
  102. ents, count = await graphs_handler.get_entities(
  103. parent_id=graph_id, offset=0, limit=10
  104. )
  105. assert count == 1
  106. assert ents[0].id == e2.id
  107. # Delete the relationship
  108. await graphs_handler.relationships.delete(
  109. parent_id=graph_id,
  110. relationship_ids=[rel.id],
  111. store_type=StoreType.GRAPHS,
  112. )
  113. rels, rel_count = await graphs_handler.get_relationships(
  114. parent_id=graph_id, offset=0, limit=10
  115. )
  116. assert rel_count == 0
  117. @pytest.mark.asyncio
  118. async def test_communities(graphs_handler):
  119. # Insert a community for a collection_id (not strictly related to a graph_id)
  120. coll_id = uuid.uuid4()
  121. await graphs_handler.communities.create(
  122. parent_id=coll_id,
  123. store_type=StoreType.GRAPHS,
  124. name="CommunityOne",
  125. summary="Test community",
  126. findings=["finding1", "finding2"],
  127. rating=4.5,
  128. rating_explanation="Excellent",
  129. description_embedding=[0.1, 0.2, 0.3, 0.4],
  130. )
  131. comms, count = await graphs_handler.communities.get(
  132. parent_id=coll_id,
  133. store_type=StoreType.GRAPHS,
  134. offset=0,
  135. limit=10,
  136. )
  137. assert count == 1
  138. assert comms[0].name == "CommunityOne"
  139. # TODO - Fix code such that these tests pass
  140. # # @pytest.mark.asyncio
  141. # # async def test_delete_graph(graphs_handler):
  142. # # # Create a graph and then delete it
  143. # # coll_id = uuid.uuid4()
  144. # # graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
  145. # # graph_id = graph_resp.id
  146. # # # reset or delete calls are complicated in the code. We'll just call `reset` and `delete`
  147. # # await graphs_handler.reset(graph_id)
  148. # # # This should remove all entities & relationships from the graph_id
  149. # # # Now delete the graph itself
  150. # # # The `delete` method seems to be tied to collection_id rather than graph_id
  151. # # await graphs_handler.delete(collection_id=graph_id, cascade=False)
  152. # # # If the code is structured so that delete requires a collection_id,
  153. # # # ensure `graph_id == collection_id` or adapt the code accordingly.
  154. # # # Try fetching the graph
  155. # # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
  156. # # assert overview["total_entries"] == 0, "Graph should be deleted"
  157. # @pytest.mark.asyncio
  158. # async def test_delete_graph(graphs_handler):
  159. # # Create a graph and then delete it
  160. # coll_id = uuid.uuid4()
  161. # graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
  162. # graph_id = graph_resp.id
  163. # # Reset the graph (remove entities, relationships, communities)
  164. # await graphs_handler.reset(graph_id)
  165. # # Now delete the graph using collection_id (which equals graph_id in this code)
  166. # await graphs_handler.delete(collection_id=coll_id)
  167. # # Verify the graph is deleted
  168. # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[coll_id])
  169. # assert overview["total_entries"] == 0, "Graph should be deleted"
  170. @pytest.mark.asyncio
  171. async def test_create_graph_defaults(graphs_handler):
  172. # Create a graph without specifying name or description
  173. coll_id = uuid.uuid4()
  174. resp = await graphs_handler.create(collection_id=coll_id)
  175. assert resp.collection_id == coll_id
  176. # The code sets a default name, which should be "Graph {coll_id}"
  177. assert resp.name == f"Graph {coll_id}"
  178. # Default description should be empty string as per code
  179. assert resp.description == ""
  180. # @pytest.mark.asyncio
  181. # async def test_list_multiple_graphs(graphs_handler):
  182. # # Create multiple graphs
  183. # coll_id1 = uuid.uuid4()
  184. # coll_id2 = uuid.uuid4()
  185. # graph_resp1 = await graphs_handler.create(collection_id=coll_id1, name="Graph1")
  186. # graph_resp2 = await graphs_handler.create(collection_id=coll_id2, name="Graph2")
  187. # graph_resp3 = await graphs_handler.create(collection_id=coll_id2, name="Graph3")
  188. # # List all graphs without filters
  189. # overview = await graphs_handler.list_graphs(offset=0, limit=10)
  190. # # Ensure at least these three are in there
  191. # found_ids = [g.id for g in overview["results"]]
  192. # assert graph_resp1.id in found_ids
  193. # assert graph_resp2.id in found_ids
  194. # assert graph_resp3.id in found_ids
  195. # # Filter by collection_id = coll_id2 should return Graph2 and Graph3 (the most recent one first if same collection)
  196. # overview_coll2 = await graphs_handler.list_graphs(offset=0, limit=10, filter_collection_id=coll_id2)
  197. # returned_ids = [g.id for g in overview_coll2["results"]]
  198. # # According to the code, we only see the "most recent" graph per collection. Verify this logic.
  199. # # If your code is returning only the most recent graph per collection, we should see only one graph per collection_id here.
  200. # # Adjust test according to actual logic you desire.
  201. # # For this example, let's assume we should only get the latest graph per collection. Graph3 should be newer than Graph2.
  202. # assert len(returned_ids) == 1
  203. # assert graph_resp3.id in returned_ids
  204. @pytest.mark.asyncio
  205. async def test_update_graph(graphs_handler):
  206. coll_id = uuid.uuid4()
  207. graph_resp = await graphs_handler.create(
  208. collection_id=coll_id, name="OldName", description="OldDescription"
  209. )
  210. graph_id = graph_resp.id
  211. # Update name and description
  212. updated_resp = await graphs_handler.update(
  213. collection_id=graph_id, name="NewName", description="NewDescription"
  214. )
  215. assert updated_resp.name == "NewName"
  216. assert updated_resp.description == "NewDescription"
  217. # Retrieve and verify
  218. overview = await graphs_handler.list_graphs(
  219. offset=0, limit=10, filter_graph_ids=[graph_id]
  220. )
  221. assert overview["total_entries"] == 1
  222. fetched_graph = overview["results"][0]
  223. assert fetched_graph.name == "NewName"
  224. assert fetched_graph.description == "NewDescription"
  225. @pytest.mark.asyncio
  226. async def test_bulk_entities(graphs_handler):
  227. coll_id = uuid.uuid4()
  228. graph_resp = await graphs_handler.create(
  229. collection_id=coll_id, name="BulkEntities"
  230. )
  231. graph_id = graph_resp.id
  232. # Add multiple entities
  233. entities_to_add = [
  234. {"name": "EntityA", "category": "CategoryA", "description": "DescA"},
  235. {"name": "EntityB", "category": "CategoryB", "description": "DescB"},
  236. {"name": "EntityC", "category": "CategoryC", "description": "DescC"},
  237. ]
  238. for ent in entities_to_add:
  239. await graphs_handler.entities.create(
  240. parent_id=graph_id,
  241. store_type=StoreType.GRAPHS,
  242. name=ent["name"],
  243. category=ent["category"],
  244. description=ent["description"],
  245. )
  246. ents, total = await graphs_handler.get_entities(
  247. parent_id=graph_id, offset=0, limit=10
  248. )
  249. assert total == 3
  250. fetched_names = [e.name for e in ents]
  251. for ent in entities_to_add:
  252. assert ent["name"] in fetched_names
  253. @pytest.mark.asyncio
  254. async def test_relationship_filtering(graphs_handler):
  255. coll_id = uuid.uuid4()
  256. graph_resp = await graphs_handler.create(
  257. collection_id=coll_id, name="RelFilteringGraph"
  258. )
  259. graph_id = graph_resp.id
  260. # Add entities
  261. e1 = await graphs_handler.entities.create(
  262. parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node1"
  263. )
  264. e2 = await graphs_handler.entities.create(
  265. parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node2"
  266. )
  267. e3 = await graphs_handler.entities.create(
  268. parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node3"
  269. )
  270. # Add different relationships
  271. await graphs_handler.relationships.create(
  272. subject="Node1",
  273. subject_id=e1.id,
  274. predicate="connected_to",
  275. object="Node2",
  276. object_id=e2.id,
  277. parent_id=graph_id,
  278. store_type=StoreType.GRAPHS,
  279. )
  280. await graphs_handler.relationships.create(
  281. subject="Node2",
  282. subject_id=e2.id,
  283. predicate="linked_with",
  284. object="Node3",
  285. object_id=e3.id,
  286. parent_id=graph_id,
  287. store_type=StoreType.GRAPHS,
  288. )
  289. # Get all relationships
  290. all_rels, all_count = await graphs_handler.get_relationships(
  291. parent_id=graph_id, offset=0, limit=10
  292. )
  293. assert all_count == 2
  294. # Filter by relationship_type = ["connected_to"]
  295. filtered_rels, filt_count = await graphs_handler.get_relationships(
  296. parent_id=graph_id,
  297. offset=0,
  298. limit=10,
  299. relationship_types=["connected_to"],
  300. )
  301. assert filt_count == 1
  302. assert filtered_rels[0].predicate == "connected_to"
  303. @pytest.mark.asyncio
  304. async def test_delete_all_entities(graphs_handler):
  305. coll_id = uuid.uuid4()
  306. graph_resp = await graphs_handler.create(
  307. collection_id=coll_id, name="DeleteAllEntities"
  308. )
  309. graph_id = graph_resp.id
  310. # Add some entities
  311. await graphs_handler.entities.create(
  312. parent_id=graph_id, store_type=StoreType.GRAPHS, name="E1"
  313. )
  314. await graphs_handler.entities.create(
  315. parent_id=graph_id, store_type=StoreType.GRAPHS, name="E2"
  316. )
  317. # Delete all entities without specifying IDs
  318. await graphs_handler.entities.delete(
  319. parent_id=graph_id, store_type=StoreType.GRAPHS
  320. )
  321. ents, count = await graphs_handler.get_entities(
  322. parent_id=graph_id, offset=0, limit=10
  323. )
  324. assert count == 0
  325. @pytest.mark.asyncio
  326. async def test_delete_all_relationships(graphs_handler):
  327. coll_id = uuid.uuid4()
  328. graph_resp = await graphs_handler.create(
  329. collection_id=coll_id, name="DeleteAllRels"
  330. )
  331. graph_id = graph_resp.id
  332. # Add two entities and a relationship
  333. e1 = await graphs_handler.entities.create(
  334. parent_id=graph_id, store_type=StoreType.GRAPHS, name="E1"
  335. )
  336. e2 = await graphs_handler.entities.create(
  337. parent_id=graph_id, store_type=StoreType.GRAPHS, name="E2"
  338. )
  339. await graphs_handler.relationships.create(
  340. subject="E1",
  341. subject_id=e1.id,
  342. predicate="connected",
  343. object="E2",
  344. object_id=e2.id,
  345. parent_id=graph_id,
  346. store_type=StoreType.GRAPHS,
  347. )
  348. # Delete all relationships
  349. await graphs_handler.relationships.delete(
  350. parent_id=graph_id, store_type=StoreType.GRAPHS
  351. )
  352. rels, rel_count = await graphs_handler.get_relationships(
  353. parent_id=graph_id, offset=0, limit=10
  354. )
  355. assert rel_count == 0
  356. @pytest.mark.asyncio
  357. async def test_error_handling_invalid_graph_id(graphs_handler):
  358. # Attempt to get a non-existent graph
  359. non_existent_id = uuid.uuid4()
  360. overview = await graphs_handler.list_graphs(
  361. offset=0, limit=10, filter_graph_ids=[non_existent_id]
  362. )
  363. assert overview["total_entries"] == 0
  364. # Attempt to delete a non-existent graph
  365. with pytest.raises(Exception) as exc_info:
  366. await graphs_handler.delete(collection_id=non_existent_id)
  367. # Expect an R2RException or HTTPException (depending on your code)
  368. # Check the message or type if needed
  369. # TODO - Fix code to pass this test.
  370. # @pytest.mark.asyncio
  371. # async def test_delete_graph_cascade(graphs_handler):
  372. # coll_id = uuid.uuid4()
  373. # graph_resp = await graphs_handler.create(collection_id=coll_id, name="CascadeGraph")
  374. # graph_id = graph_resp.id
  375. # # Add entities/relationships here if you have documents attached
  376. # # This test would verify that cascade=True behavior is correct
  377. # # For now, just call delete with cascade=True
  378. # # Depending on your implementation, you might need documents associated with the collection to test fully.
  379. # await graphs_handler.delete(collection_id=coll_id)
  380. # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
  381. # assert overview["total_entries"] == 0