test_graphs.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import uuid
  2. import pytest
  3. from r2r import R2RClient, R2RException
  4. @pytest.fixture(scope="session")
  5. def config():
  6. class TestConfig:
  7. base_url = "http://localhost:7272"
  8. superuser_email = "admin@example.com"
  9. superuser_password = "change_me_immediately"
  10. return TestConfig()
  11. @pytest.fixture(scope="session")
  12. def client(config):
  13. """Create a client instance and possibly log in as a superuser."""
  14. client = R2RClient(config.base_url)
  15. client.users.login(config.superuser_email, config.superuser_password)
  16. return client
  17. @pytest.fixture
  18. def test_collection(client):
  19. """Create a test collection (and thus a graph) for testing, then delete it afterwards."""
  20. resp = client.collections.create(
  21. name=f"Test Collection {uuid.uuid4()}",
  22. description="A sample collection for graph tests",
  23. )
  24. collection_id = resp["results"]["id"]
  25. yield collection_id
  26. # Cleanup if needed
  27. # If there's a deletion endpoint for collections, call it here.
  28. client.collections.delete(id=collection_id)
  29. def test_list_graphs(client):
  30. resp = client.graphs.list(limit=5)
  31. assert "results" in resp, "No results field in list response"
  32. def test_create_and_get_graph(client, test_collection):
  33. # `test_collection` fixture creates a collection and returns ID
  34. collection_id = test_collection
  35. resp = client.graphs.retrieve(collection_id=collection_id)["results"]
  36. assert resp["collection_id"] == collection_id, "Graph ID mismatch"
  37. def test_update_graph(client, test_collection):
  38. collection_id = test_collection
  39. new_name = "Updated Test Graph Name"
  40. new_description = "Updated test description"
  41. resp = client.graphs.update(
  42. collection_id=collection_id, name=new_name, description=new_description
  43. )["results"]
  44. assert resp["name"] == new_name, "Name not updated correctly"
  45. assert (
  46. resp["description"] == new_description
  47. ), "Description not updated correctly"
  48. def test_list_entities(client, test_collection):
  49. collection_id = test_collection
  50. resp = client.graphs.list_entities(collection_id=collection_id, limit=5)[
  51. "results"
  52. ]
  53. assert isinstance(resp, list), "No results array in entities response"
  54. def test_create_and_get_entity(client, test_collection):
  55. collection_id = test_collection
  56. entity_name = "Test Entity"
  57. entity_description = "Test entity description"
  58. create_resp = client.graphs.create_entity(
  59. collection_id=collection_id,
  60. name=entity_name,
  61. description=entity_description,
  62. )["results"]
  63. entity_id = create_resp["id"]
  64. resp = client.graphs.get_entity(
  65. collection_id=collection_id, entity_id=entity_id
  66. )["results"]
  67. assert resp["name"] == entity_name, "Entity name mismatch"
  68. def test_list_relationships(client, test_collection):
  69. collection_id = test_collection
  70. resp = client.graphs.list_relationships(
  71. collection_id=collection_id, limit=5
  72. )["results"]
  73. assert isinstance(resp, list), "No results array in relationships response"
  74. def test_create_and_get_relationship(client, test_collection):
  75. collection_id = test_collection
  76. # Create two entities
  77. entity1 = client.graphs.create_entity(
  78. collection_id=collection_id,
  79. name="Entity 1",
  80. description="Entity 1 description",
  81. )["results"]
  82. entity2 = client.graphs.create_entity(
  83. collection_id=collection_id,
  84. name="Entity 2",
  85. description="Entity 2 description",
  86. )["results"]
  87. # Create relationship
  88. rel_resp = client.graphs.create_relationship(
  89. collection_id=collection_id,
  90. subject="Entity 1",
  91. subject_id=entity1["id"],
  92. predicate="related_to",
  93. object="Entity 2",
  94. object_id=entity2["id"],
  95. description="Test relationship",
  96. )["results"]
  97. relationship_id = rel_resp["id"]
  98. # Get relationship
  99. resp = client.graphs.get_relationship(
  100. collection_id=collection_id, relationship_id=relationship_id
  101. )["results"]
  102. assert resp["predicate"] == "related_to", "Relationship predicate mismatch"
  103. def test_build_communities(client, test_collection):
  104. collection_id = test_collection
  105. # Create two entities
  106. entity1 = client.graphs.create_entity(
  107. collection_id=collection_id,
  108. name="Entity 1",
  109. description="Entity 1 description",
  110. )["results"]
  111. entity2 = client.graphs.create_entity(
  112. collection_id=collection_id,
  113. name="Entity 2",
  114. description="Entity 2 description",
  115. )["results"]
  116. # Create relationship
  117. rel_resp = client.graphs.create_relationship(
  118. collection_id=collection_id,
  119. subject="Entity 1",
  120. subject_id=entity1["id"],
  121. predicate="related_to",
  122. object="Entity 2",
  123. object_id=entity2["id"],
  124. description="Test relationship",
  125. )["results"]
  126. relationship_id = rel_resp["id"]
  127. # Build communities
  128. # Adjust parameters as needed if `run_type` and `settings` differ.
  129. # The router expects `run_type` and `graph_enrichment_settings`.
  130. resp = client.graphs.build(
  131. collection_id=collection_id,
  132. run_type="run",
  133. # graph_enrichment_settings={"use_semantic_clustering": True},
  134. run_with_orchestration=False,
  135. )["results"]
  136. # After building, list communities
  137. resp = client.graphs.list_communities(
  138. collection_id=collection_id, limit=5
  139. )["results"]
  140. # We cannot guarantee communities are created if no entities or special conditions apply.
  141. # If no communities, we may skip this assert or ensure at least no error occurred.
  142. assert isinstance(resp, list), "No communities array returned."
  143. def test_list_communities(client, test_collection):
  144. collection_id = test_collection
  145. resp = client.graphs.list_communities(
  146. collection_id=collection_id, limit=5
  147. )["results"]
  148. assert isinstance(resp, list), "No results array in communities response"
  149. def test_create_and_get_community(client, test_collection):
  150. collection_id = test_collection
  151. community_name = "Test Community"
  152. community_summary = "Test community summary"
  153. create_resp = client.graphs.create_community(
  154. collection_id=collection_id,
  155. name=community_name,
  156. summary=community_summary,
  157. findings=["Finding 1", "Finding 2"],
  158. rating=8,
  159. )["results"]
  160. community_id = create_resp["id"]
  161. resp = client.graphs.get_community(
  162. collection_id=collection_id, community_id=community_id
  163. )["results"]
  164. assert resp["name"] == community_name, "Community name mismatch"
  165. def test_update_community(client, test_collection):
  166. collection_id = test_collection
  167. # Create a community to update
  168. create_resp = client.graphs.create_community(
  169. collection_id=collection_id,
  170. name="Community to update",
  171. summary="Original summary",
  172. findings=["Original finding"],
  173. rating=7,
  174. )["results"]
  175. community_id = create_resp["id"]
  176. # Update the community
  177. resp = client.graphs.update_community(
  178. collection_id=collection_id,
  179. community_id=community_id,
  180. name="Updated Community",
  181. summary="Updated summary",
  182. findings=["New finding"],
  183. rating=9,
  184. )["results"]
  185. assert resp["name"] == "Updated Community", "Community update failed"
  186. def test_pull_operation(client, test_collection):
  187. collection_id = test_collection
  188. resp = client.graphs.pull(collection_id=collection_id)["results"]
  189. # Depending on your system, the pull might require documents in the collection to actually do something.
  190. # If no documents are present, it might return success=False or a warning. Check behavior and adjust test.
  191. assert "success" in resp, "No success indicator in pull response"
  192. def test_error_handling(client):
  193. # Test retrieving a graph with invalid ID
  194. invalid_id = "not-a-uuid"
  195. with pytest.raises(R2RException) as exc_info:
  196. client.graphs.retrieve(collection_id=invalid_id)
  197. # Expecting a 422 or 404 error. Adjust as per your API's expected response.
  198. assert exc_info.value.status_code in [
  199. 400,
  200. 422,
  201. 404,
  202. ], "Expected an error for invalid ID."