test_conversations.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import json
  2. import uuid
  3. from uuid import UUID
  4. import pytest
  5. from core.base import Message, R2RException
  6. from shared.api.models.management.responses import (
  7. ConversationResponse,
  8. MessageResponse,
  9. )
  10. @pytest.mark.asyncio
  11. async def test_create_conversation(conversations_handler):
  12. resp = await conversations_handler.create_conversation()
  13. assert isinstance(resp, ConversationResponse)
  14. assert resp.id is not None
  15. assert resp.created_at is not None
  16. @pytest.mark.asyncio
  17. async def test_create_conversation_with_user_and_name(conversations_handler):
  18. user_id = uuid.uuid4()
  19. resp = await conversations_handler.create_conversation(
  20. user_id=user_id, name="Test Conv"
  21. )
  22. assert resp.id is not None
  23. assert resp.created_at is not None
  24. # There's no direct field for user_id in ConversationResponse,
  25. # but we can verify by fetch:
  26. # Just trust it for now since the handler doesn't return user_id directly.
  27. @pytest.mark.asyncio
  28. async def test_add_message(conversations_handler):
  29. conv = await conversations_handler.create_conversation()
  30. conv_id = conv.id
  31. msg = Message(role="user", content="Hello!")
  32. resp = await conversations_handler.add_message(conv_id, msg)
  33. assert isinstance(resp, MessageResponse)
  34. assert resp.id is not None
  35. assert resp.message.content == "Hello!"
  36. @pytest.mark.asyncio
  37. async def test_add_message_with_parent(conversations_handler):
  38. conv = await conversations_handler.create_conversation()
  39. conv_id = conv.id
  40. parent_msg = Message(role="user", content="Parent message")
  41. parent_resp = await conversations_handler.add_message(conv_id, parent_msg)
  42. parent_id = parent_resp.id
  43. child_msg = Message(role="assistant", content="Child reply")
  44. child_resp = await conversations_handler.add_message(
  45. conv_id, child_msg, parent_id=parent_id
  46. )
  47. assert child_resp.id is not None
  48. assert child_resp.message.content == "Child reply"
  49. @pytest.mark.asyncio
  50. async def test_edit_message(conversations_handler):
  51. conv = await conversations_handler.create_conversation()
  52. conv_id = conv.id
  53. original_msg = Message(role="user", content="Original")
  54. resp = await conversations_handler.add_message(conv_id, original_msg)
  55. msg_id = resp.id
  56. updated = await conversations_handler.edit_message(
  57. msg_id, "Edited content"
  58. )
  59. assert updated["message"].content == "Edited content"
  60. assert updated["metadata"]["edited"] is True
  61. @pytest.mark.asyncio
  62. async def test_update_message_metadata(conversations_handler):
  63. conv = await conversations_handler.create_conversation()
  64. conv_id = conv.id
  65. msg = Message(role="user", content="Meta-test")
  66. resp = await conversations_handler.add_message(conv_id, msg)
  67. msg_id = resp.id
  68. await conversations_handler.update_message_metadata(
  69. msg_id, {"test_key": "test_value"}
  70. )
  71. # Verify metadata updated
  72. full_conversation = await conversations_handler.get_conversation(conv_id)
  73. for m in full_conversation:
  74. if m.id == str(msg_id):
  75. assert m.metadata["test_key"] == "test_value"
  76. break
  77. @pytest.mark.asyncio
  78. async def test_get_conversation(conversations_handler):
  79. conv = await conversations_handler.create_conversation()
  80. conv_id = conv.id
  81. msg1 = Message(role="user", content="Msg1")
  82. msg2 = Message(role="assistant", content="Msg2")
  83. await conversations_handler.add_message(conv_id, msg1)
  84. await conversations_handler.add_message(conv_id, msg2)
  85. messages = await conversations_handler.get_conversation(conv_id)
  86. assert len(messages) == 2
  87. assert messages[0].message.content == "Msg1"
  88. assert messages[1].message.content == "Msg2"
  89. @pytest.mark.asyncio
  90. async def test_delete_conversation(conversations_handler):
  91. conv = await conversations_handler.create_conversation()
  92. conv_id = conv.id
  93. msg = Message(role="user", content="To be deleted")
  94. await conversations_handler.add_message(conv_id, msg)
  95. await conversations_handler.delete_conversation(conv_id)
  96. with pytest.raises(R2RException) as exc:
  97. await conversations_handler.get_conversation(conv_id)
  98. assert (
  99. exc.value.status_code == 404
  100. ), "Conversation should be deleted and not found"