test_graphs.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. import pytest
  2. import uuid
  3. from uuid import UUID
  4. from enum import Enum
  5. from core.base.abstractions import Entity, Relationship, Community
  6. from core.base.api.models import GraphResponse
  7. class StoreType(str, Enum):
  8. GRAPHS = "graphs"
  9. DOCUMENTS = "documents"
  10. @pytest.mark.asyncio
  11. async def test_create_graph(graphs_handler):
  12. coll_id = uuid.uuid4()
  13. resp = await graphs_handler.create(
  14. collection_id=coll_id, name="My Graph", description="Test Graph"
  15. )
  16. assert isinstance(resp, GraphResponse)
  17. assert resp.name == "My Graph"
  18. assert resp.collection_id == coll_id
  19. @pytest.mark.asyncio
  20. async def test_add_entities_and_relationships(graphs_handler):
  21. # Create a graph
  22. coll_id = uuid.uuid4()
  23. graph_resp = await graphs_handler.create(
  24. collection_id=coll_id, name="TestGraph"
  25. )
  26. graph_id = graph_resp.id
  27. # Add an entity
  28. entity = await graphs_handler.entities.create(
  29. parent_id=graph_id,
  30. store_type=StoreType.GRAPHS.value,
  31. name="TestEntity",
  32. category="Person",
  33. description="A test entity",
  34. )
  35. assert entity.name == "TestEntity"
  36. # Add another entity
  37. entity2 = await graphs_handler.entities.create(
  38. parent_id=graph_id,
  39. store_type=StoreType.GRAPHS.value,
  40. name="AnotherEntity",
  41. category="Place",
  42. description="A test place",
  43. )
  44. # Add a relationship between them
  45. rel = await graphs_handler.relationships.create(
  46. subject="TestEntity",
  47. subject_id=entity.id,
  48. predicate="lives_in",
  49. object="AnotherEntity",
  50. object_id=entity2.id,
  51. parent_id=graph_id,
  52. store_type=StoreType.GRAPHS.value,
  53. description="Entity lives in AnotherEntity",
  54. )
  55. assert rel.predicate == "lives_in"
  56. # Verify entities retrieval
  57. ents, total_ents = await graphs_handler.get_entities(
  58. parent_id=graph_id, offset=0, limit=10
  59. )
  60. assert total_ents == 2
  61. names = [e.name for e in ents]
  62. assert "TestEntity" in names and "AnotherEntity" in names
  63. # Verify relationships retrieval
  64. rels, total_rels = await graphs_handler.get_relationships(
  65. parent_id=graph_id, offset=0, limit=10
  66. )
  67. assert total_rels == 1
  68. assert rels[0].predicate == "lives_in"
  69. @pytest.mark.asyncio
  70. async def test_delete_entities_and_relationships(graphs_handler):
  71. # Create another graph
  72. coll_id = uuid.uuid4()
  73. graph_resp = await graphs_handler.create(
  74. collection_id=coll_id, name="DeletableGraph"
  75. )
  76. graph_id = graph_resp.id
  77. # Add entities
  78. e1 = await graphs_handler.entities.create(
  79. parent_id=graph_id,
  80. store_type=StoreType.GRAPHS.value,
  81. name="DeleteMe",
  82. )
  83. e2 = await graphs_handler.entities.create(
  84. parent_id=graph_id,
  85. store_type=StoreType.GRAPHS.value,
  86. name="DeleteMeToo",
  87. )
  88. # Add relationship
  89. rel = await graphs_handler.relationships.create(
  90. subject="DeleteMe",
  91. subject_id=e1.id,
  92. predicate="related_to",
  93. object="DeleteMeToo",
  94. object_id=e2.id,
  95. parent_id=graph_id,
  96. store_type=StoreType.GRAPHS.value,
  97. )
  98. # Delete one entity
  99. await graphs_handler.entities.delete(
  100. parent_id=graph_id,
  101. entity_ids=[e1.id],
  102. store_type=StoreType.GRAPHS.value,
  103. )
  104. ents, count = await graphs_handler.get_entities(
  105. parent_id=graph_id, offset=0, limit=10
  106. )
  107. assert count == 1
  108. assert ents[0].id == e2.id
  109. # Delete the relationship
  110. await graphs_handler.relationships.delete(
  111. parent_id=graph_id,
  112. relationship_ids=[rel.id],
  113. store_type=StoreType.GRAPHS.value,
  114. )
  115. rels, rel_count = await graphs_handler.get_relationships(
  116. parent_id=graph_id, offset=0, limit=10
  117. )
  118. assert rel_count == 0
  119. @pytest.mark.asyncio
  120. async def test_communities(graphs_handler):
  121. # Insert a community for a collection_id (not strictly related to a graph_id)
  122. coll_id = uuid.uuid4()
  123. await graphs_handler.communities.create(
  124. parent_id=coll_id,
  125. store_type=StoreType.GRAPHS.value,
  126. name="CommunityOne",
  127. summary="Test community",
  128. findings=["finding1", "finding2"],
  129. rating=4.5,
  130. rating_explanation="Excellent",
  131. description_embedding=[0.1, 0.2, 0.3, 0.4],
  132. )
  133. comms, count = await graphs_handler.communities.get(
  134. parent_id=coll_id,
  135. store_type=StoreType.GRAPHS.value,
  136. offset=0,
  137. limit=10,
  138. )
  139. assert count == 1
  140. assert comms[0].name == "CommunityOne"
  141. # TODO - Fix code such that these tests pass
  142. # # @pytest.mark.asyncio
  143. # # async def test_delete_graph(graphs_handler):
  144. # # # Create a graph and then delete it
  145. # # coll_id = uuid.uuid4()
  146. # # graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
  147. # # graph_id = graph_resp.id
  148. # # # reset or delete calls are complicated in the code. We'll just call `reset` and `delete`
  149. # # await graphs_handler.reset(graph_id)
  150. # # # This should remove all entities & relationships from the graph_id
  151. # # # Now delete the graph itself
  152. # # # The `delete` method seems to be tied to collection_id rather than graph_id
  153. # # await graphs_handler.delete(collection_id=graph_id, cascade=False)
  154. # # # If the code is structured so that delete requires a collection_id,
  155. # # # ensure `graph_id == collection_id` or adapt the code accordingly.
  156. # # # Try fetching the graph
  157. # # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
  158. # # assert overview["total_entries"] == 0, "Graph should be deleted"
  159. # @pytest.mark.asyncio
  160. # async def test_delete_graph(graphs_handler):
  161. # # Create a graph and then delete it
  162. # coll_id = uuid.uuid4()
  163. # graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
  164. # graph_id = graph_resp.id
  165. # # Reset the graph (remove entities, relationships, communities)
  166. # await graphs_handler.reset(graph_id)
  167. # # Now delete the graph using collection_id (which equals graph_id in this code)
  168. # await graphs_handler.delete(collection_id=coll_id)
  169. # # Verify the graph is deleted
  170. # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[coll_id])
  171. # assert overview["total_entries"] == 0, "Graph should be deleted"
  172. @pytest.mark.asyncio
  173. async def test_create_graph_defaults(graphs_handler):
  174. # Create a graph without specifying name or description
  175. coll_id = uuid.uuid4()
  176. resp = await graphs_handler.create(collection_id=coll_id)
  177. assert resp.collection_id == coll_id
  178. # The code sets a default name, which should be "Graph {coll_id}"
  179. assert resp.name == f"Graph {coll_id}"
  180. # Default description should be empty string as per code
  181. assert resp.description == ""
  182. # @pytest.mark.asyncio
  183. # async def test_list_multiple_graphs(graphs_handler):
  184. # # Create multiple graphs
  185. # coll_id1 = uuid.uuid4()
  186. # coll_id2 = uuid.uuid4()
  187. # graph_resp1 = await graphs_handler.create(collection_id=coll_id1, name="Graph1")
  188. # graph_resp2 = await graphs_handler.create(collection_id=coll_id2, name="Graph2")
  189. # graph_resp3 = await graphs_handler.create(collection_id=coll_id2, name="Graph3")
  190. # # List all graphs without filters
  191. # overview = await graphs_handler.list_graphs(offset=0, limit=10)
  192. # # Ensure at least these three are in there
  193. # found_ids = [g.id for g in overview["results"]]
  194. # assert graph_resp1.id in found_ids
  195. # assert graph_resp2.id in found_ids
  196. # assert graph_resp3.id in found_ids
  197. # # Filter by collection_id = coll_id2 should return Graph2 and Graph3 (the most recent one first if same collection)
  198. # overview_coll2 = await graphs_handler.list_graphs(offset=0, limit=10, filter_collection_id=coll_id2)
  199. # returned_ids = [g.id for g in overview_coll2["results"]]
  200. # # According to the code, we only see the "most recent" graph per collection. Verify this logic.
  201. # # If your code is returning only the most recent graph per collection, we should see only one graph per collection_id here.
  202. # # Adjust test according to actual logic you desire.
  203. # # For this example, let's assume we should only get the latest graph per collection. Graph3 should be newer than Graph2.
  204. # assert len(returned_ids) == 1
  205. # assert graph_resp3.id in returned_ids
  206. @pytest.mark.asyncio
  207. async def test_update_graph(graphs_handler):
  208. coll_id = uuid.uuid4()
  209. graph_resp = await graphs_handler.create(
  210. collection_id=coll_id, name="OldName", description="OldDescription"
  211. )
  212. graph_id = graph_resp.id
  213. # Update name and description
  214. updated_resp = await graphs_handler.update(
  215. collection_id=graph_id, name="NewName", description="NewDescription"
  216. )
  217. assert updated_resp.name == "NewName"
  218. assert updated_resp.description == "NewDescription"
  219. # Retrieve and verify
  220. overview = await graphs_handler.list_graphs(
  221. offset=0, limit=10, filter_graph_ids=[graph_id]
  222. )
  223. assert overview["total_entries"] == 1
  224. fetched_graph = overview["results"][0]
  225. assert fetched_graph.name == "NewName"
  226. assert fetched_graph.description == "NewDescription"
  227. @pytest.mark.asyncio
  228. async def test_bulk_entities(graphs_handler):
  229. coll_id = uuid.uuid4()
  230. graph_resp = await graphs_handler.create(
  231. collection_id=coll_id, name="BulkEntities"
  232. )
  233. graph_id = graph_resp.id
  234. # Add multiple entities
  235. entities_to_add = [
  236. {"name": "EntityA", "category": "CategoryA", "description": "DescA"},
  237. {"name": "EntityB", "category": "CategoryB", "description": "DescB"},
  238. {"name": "EntityC", "category": "CategoryC", "description": "DescC"},
  239. ]
  240. for ent in entities_to_add:
  241. await graphs_handler.entities.create(
  242. parent_id=graph_id,
  243. store_type=StoreType.GRAPHS.value,
  244. name=ent["name"],
  245. category=ent["category"],
  246. description=ent["description"],
  247. )
  248. ents, total = await graphs_handler.get_entities(
  249. parent_id=graph_id, offset=0, limit=10
  250. )
  251. assert total == 3
  252. fetched_names = [e.name for e in ents]
  253. for ent in entities_to_add:
  254. assert ent["name"] in fetched_names
  255. @pytest.mark.asyncio
  256. async def test_relationship_filtering(graphs_handler):
  257. coll_id = uuid.uuid4()
  258. graph_resp = await graphs_handler.create(
  259. collection_id=coll_id, name="RelFilteringGraph"
  260. )
  261. graph_id = graph_resp.id
  262. # Add entities
  263. e1 = await graphs_handler.entities.create(
  264. parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node1"
  265. )
  266. e2 = await graphs_handler.entities.create(
  267. parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node2"
  268. )
  269. e3 = await graphs_handler.entities.create(
  270. parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node3"
  271. )
  272. # Add different relationships
  273. await graphs_handler.relationships.create(
  274. subject="Node1",
  275. subject_id=e1.id,
  276. predicate="connected_to",
  277. object="Node2",
  278. object_id=e2.id,
  279. parent_id=graph_id,
  280. store_type=StoreType.GRAPHS.value,
  281. )
  282. await graphs_handler.relationships.create(
  283. subject="Node2",
  284. subject_id=e2.id,
  285. predicate="linked_with",
  286. object="Node3",
  287. object_id=e3.id,
  288. parent_id=graph_id,
  289. store_type=StoreType.GRAPHS.value,
  290. )
  291. # Get all relationships
  292. all_rels, all_count = await graphs_handler.get_relationships(
  293. parent_id=graph_id, offset=0, limit=10
  294. )
  295. assert all_count == 2
  296. # Filter by relationship_type = ["connected_to"]
  297. filtered_rels, filt_count = await graphs_handler.get_relationships(
  298. parent_id=graph_id,
  299. offset=0,
  300. limit=10,
  301. relationship_types=["connected_to"],
  302. )
  303. assert filt_count == 1
  304. assert filtered_rels[0].predicate == "connected_to"
  305. @pytest.mark.asyncio
  306. async def test_delete_all_entities(graphs_handler):
  307. coll_id = uuid.uuid4()
  308. graph_resp = await graphs_handler.create(
  309. collection_id=coll_id, name="DeleteAllEntities"
  310. )
  311. graph_id = graph_resp.id
  312. # Add some entities
  313. await graphs_handler.entities.create(
  314. parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E1"
  315. )
  316. await graphs_handler.entities.create(
  317. parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E2"
  318. )
  319. # Delete all entities without specifying IDs
  320. await graphs_handler.entities.delete(
  321. parent_id=graph_id, store_type=StoreType.GRAPHS.value
  322. )
  323. ents, count = await graphs_handler.get_entities(
  324. parent_id=graph_id, offset=0, limit=10
  325. )
  326. assert count == 0
  327. @pytest.mark.asyncio
  328. async def test_delete_all_relationships(graphs_handler):
  329. coll_id = uuid.uuid4()
  330. graph_resp = await graphs_handler.create(
  331. collection_id=coll_id, name="DeleteAllRels"
  332. )
  333. graph_id = graph_resp.id
  334. # Add two entities and a relationship
  335. e1 = await graphs_handler.entities.create(
  336. parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E1"
  337. )
  338. e2 = await graphs_handler.entities.create(
  339. parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E2"
  340. )
  341. await graphs_handler.relationships.create(
  342. subject="E1",
  343. subject_id=e1.id,
  344. predicate="connected",
  345. object="E2",
  346. object_id=e2.id,
  347. parent_id=graph_id,
  348. store_type=StoreType.GRAPHS.value,
  349. )
  350. # Delete all relationships
  351. await graphs_handler.relationships.delete(
  352. parent_id=graph_id, store_type=StoreType.GRAPHS.value
  353. )
  354. rels, rel_count = await graphs_handler.get_relationships(
  355. parent_id=graph_id, offset=0, limit=10
  356. )
  357. assert rel_count == 0
  358. @pytest.mark.asyncio
  359. async def test_error_handling_invalid_graph_id(graphs_handler):
  360. # Attempt to get a non-existent graph
  361. non_existent_id = uuid.uuid4()
  362. overview = await graphs_handler.list_graphs(
  363. offset=0, limit=10, filter_graph_ids=[non_existent_id]
  364. )
  365. assert overview["total_entries"] == 0
  366. # Attempt to delete a non-existent graph
  367. with pytest.raises(Exception) as exc_info:
  368. await graphs_handler.delete(collection_id=non_existent_id)
  369. # Expect an R2RException or HTTPException (depending on your code)
  370. # Check the message or type if needed
  371. # TODO - Fix code to pass this test.
  372. # @pytest.mark.asyncio
  373. # async def test_delete_graph_cascade(graphs_handler):
  374. # coll_id = uuid.uuid4()
  375. # graph_resp = await graphs_handler.create(collection_id=coll_id, name="CascadeGraph")
  376. # graph_id = graph_resp.id
  377. # # Add entities/relationships here if you have documents attached
  378. # # This test would verify that cascade=True behavior is correct
  379. # # For now, just call delete with cascade=True
  380. # # Depending on your implementation, you might need documents associated with the collection to test fully.
  381. # await graphs_handler.delete(collection_id=coll_id)
  382. # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
  383. # assert overview["total_entries"] == 0