test_graphs.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import uuid
  2. import pytest
  3. from r2r import R2RClient, R2RException
  4. @pytest.fixture(scope="session")
  5. def config():
  6. class TestConfig:
  7. base_url = "http://localhost:7272"
  8. superuser_email = "admin@example.com"
  9. superuser_password = "change_me_immediately"
  10. return TestConfig()
  11. @pytest.fixture(scope="session")
  12. def client(config):
  13. """Create a client instance and possibly log in as a superuser."""
  14. client = R2RClient(config.base_url)
  15. client.users.login(config.superuser_email, config.superuser_password)
  16. return client
  17. @pytest.fixture
  18. def test_collection(client):
  19. """Create a test collection (and thus a graph) for testing, then delete it
  20. afterwards."""
  21. collection_id = client.collections.create(
  22. name=f"Test Collection {uuid.uuid4()}",
  23. description="A sample collection for graph tests",
  24. ).results.id
  25. yield collection_id
  26. # Cleanup if needed
  27. # If there's a deletion endpoint for collections, call it here.
  28. client.collections.delete(id=collection_id)
  29. def test_list_graphs(client: R2RClient):
  30. resp = client.graphs.list(limit=5)
  31. assert resp.results is not None, "No results field in list response"
  32. def test_create_and_get_graph(client: R2RClient, test_collection):
  33. # `test_collection` fixture creates a collection and returns ID
  34. collection_id = test_collection
  35. resp = client.graphs.retrieve(collection_id=collection_id).results
  36. assert str(resp.collection_id) == str(collection_id), "Graph ID mismatch"
  37. def test_update_graph(client: R2RClient, test_collection):
  38. collection_id = test_collection
  39. new_name = "Updated Test Graph Name"
  40. new_description = "Updated test description"
  41. resp = client.graphs.update(collection_id=collection_id,
  42. name=new_name,
  43. description=new_description).results
  44. assert resp.name == new_name, "Name not updated correctly"
  45. assert resp.description == new_description, (
  46. "Description not updated correctly")
  47. def test_list_entities(client: R2RClient, test_collection):
  48. collection_id = test_collection
  49. resp = client.graphs.list_entities(collection_id=collection_id,
  50. limit=5).results
  51. assert isinstance(resp, list), "No results array in entities response"
  52. def test_create_and_get_entity(client: R2RClient, test_collection):
  53. collection_id = test_collection
  54. entity_name = "Test Entity"
  55. entity_description = "Test entity description"
  56. create_resp = client.graphs.create_entity(
  57. collection_id=collection_id,
  58. name=entity_name,
  59. description=entity_description,
  60. ).results
  61. entity_id = str(create_resp.id)
  62. resp = client.graphs.get_entity(collection_id=collection_id,
  63. entity_id=entity_id).results
  64. assert resp.name == entity_name, "Entity name mismatch"
  65. def test_list_relationships(client: R2RClient, test_collection):
  66. collection_id = test_collection
  67. resp = client.graphs.list_relationships(collection_id=collection_id,
  68. limit=5).results
  69. assert isinstance(resp, list), "No results array in relationships response"
  70. def test_create_and_get_relationship(client: R2RClient, test_collection):
  71. collection_id = test_collection
  72. # Create two entities
  73. entity1 = client.graphs.create_entity(
  74. collection_id=collection_id,
  75. name="Entity 1",
  76. description="Entity 1 description",
  77. ).results
  78. entity2 = client.graphs.create_entity(
  79. collection_id=collection_id,
  80. name="Entity 2",
  81. description="Entity 2 description",
  82. ).results
  83. # Create relationship
  84. rel_resp = client.graphs.create_relationship(
  85. collection_id=collection_id,
  86. subject="Entity 1",
  87. subject_id=entity1.id,
  88. predicate="related_to",
  89. object="Entity 2",
  90. object_id=entity2.id,
  91. description="Test relationship",
  92. ).results
  93. relationship_id = str(rel_resp.id)
  94. # Get relationship
  95. resp = client.graphs.get_relationship(
  96. collection_id=collection_id, relationship_id=relationship_id).results
  97. assert resp.predicate == "related_to", "Relationship predicate mismatch"
  98. # def test_build_communities(client: R2RClient, test_collection):
  99. # collection_id = test_collection
  100. # # Create two entities
  101. # entity1 = client.graphs.create_entity(
  102. # collection_id=collection_id,
  103. # name="Entity 1",
  104. # description="Entity 1 description",
  105. # ).results
  106. # entity2 = client.graphs.create_entity(
  107. # collection_id=collection_id,
  108. # name="Entity 2",
  109. # description="Entity 2 description",
  110. # ).results
  111. # # Create relationship
  112. # rel_resp = client.graphs.create_relationship(
  113. # collection_id=str(collection_id),
  114. # subject="Entity 1",
  115. # subject_id=entity1.id,
  116. # predicate="related_to",
  117. # object="Entity 2",
  118. # object_id=entity2.id,
  119. # description="Test relationship",
  120. # ).results
  121. # relationship_id = str(rel_resp.id)
  122. # # Build communities
  123. # resp = client.graphs.build(
  124. # collection_id=str(collection_id),
  125. # # graph_enrichment_settings={"use_semantic_clustering": True},
  126. # run_with_orchestration=False,
  127. # ).results
  128. # # After building, list communities
  129. # resp = client.graphs.list_communities(collection_id=str(collection_id),
  130. # limit=5).results
  131. # # We cannot guarantee communities are created if no entities or special conditions apply.
  132. # # If no communities, we may skip this assert or ensure at least no error occurred.
  133. # assert isinstance(resp, list), "No communities array returned."
  134. def test_list_communities(client: R2RClient, test_collection):
  135. collection_id = test_collection
  136. resp = client.graphs.list_communities(collection_id=collection_id,
  137. limit=5).results
  138. assert isinstance(resp, list), "No results array in communities response"
  139. def test_create_and_get_community(client: R2RClient, test_collection):
  140. collection_id = test_collection
  141. community_name = "Test Community"
  142. community_summary = "Test community summary"
  143. create_resp = client.graphs.create_community(
  144. collection_id=collection_id,
  145. name=community_name,
  146. summary=community_summary,
  147. findings=["Finding 1", "Finding 2"],
  148. rating=8,
  149. ).results
  150. community_id = str(create_resp.id)
  151. resp = client.graphs.get_community(collection_id=collection_id,
  152. community_id=community_id).results
  153. assert resp.name == community_name, "Community name mismatch"
  154. def test_update_community(client: R2RClient, test_collection):
  155. collection_id = test_collection
  156. # Create a community to update
  157. create_resp = client.graphs.create_community(
  158. collection_id=collection_id,
  159. name="Community to update",
  160. summary="Original summary",
  161. findings=["Original finding"],
  162. rating=7,
  163. ).results
  164. community_id = str(create_resp.id)
  165. # Update the community
  166. resp = client.graphs.update_community(
  167. collection_id=collection_id,
  168. community_id=community_id,
  169. name="Updated Community",
  170. summary="Updated summary",
  171. findings=["New finding"],
  172. rating=9,
  173. ).results
  174. assert resp.name == "Updated Community", "Community update failed"
  175. def test_pull_operation(client: R2RClient, test_collection):
  176. collection_id = test_collection
  177. resp = client.graphs.pull(collection_id=collection_id).results
  178. assert resp.success is not None, "No success indicator in pull response"
  179. def test_error_handling(client: R2RClient):
  180. # Test retrieving a graph with invalid ID
  181. invalid_id = "not-a-uuid"
  182. with pytest.raises(R2RException) as exc_info:
  183. client.graphs.retrieve(collection_id=invalid_id)
  184. # Expecting a 422 or 404 error. Adjust as per your API's expected response.
  185. assert exc_info.value.status_code in [
  186. 400,
  187. 422,
  188. 404,
  189. ], "Expected an error for invalid ID."