# conversations.py (13 KB)

  1. import json
  2. from typing import Any, Dict, List, Optional
  3. from uuid import UUID, uuid4
  4. from core.base import Handler, Message, R2RException
  5. from shared.api.models.management.responses import (
  6. ConversationResponse,
  7. MessageResponse,
  8. )
  9. from .base import PostgresConnectionManager
  10. class PostgresConversationsHandler(Handler):
  11. def __init__(
  12. self, project_name: str, connection_manager: PostgresConnectionManager
  13. ):
  14. self.project_name = project_name
  15. self.connection_manager = connection_manager
  16. async def create_tables(self):
  17. # Ensure the uuid_generate_v4() extension is available
  18. # Depending on your environment, you may need a separate call:
  19. # await self.connection_manager.execute_query("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";")
  20. create_conversations_query = f"""
  21. CREATE TABLE IF NOT EXISTS {self._get_table_name("conversations")} (
  22. id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
  23. user_id UUID,
  24. created_at TIMESTAMPTZ DEFAULT NOW(),
  25. name TEXT
  26. );
  27. """
  28. create_messages_query = f"""
  29. CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
  30. id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
  31. conversation_id UUID NOT NULL,
  32. parent_id UUID,
  33. content JSONB,
  34. metadata JSONB,
  35. created_at TIMESTAMPTZ DEFAULT NOW(),
  36. FOREIGN KEY (conversation_id) REFERENCES {self._get_table_name("conversations")}(id),
  37. FOREIGN KEY (parent_id) REFERENCES {self._get_table_name("messages")}(id)
  38. );
  39. """
  40. await self.connection_manager.execute_query(create_conversations_query)
  41. await self.connection_manager.execute_query(create_messages_query)
  42. async def create_conversation(
  43. self, user_id: Optional[UUID] = None, name: Optional[str] = None
  44. ) -> ConversationResponse:
  45. query = f"""
  46. INSERT INTO {self._get_table_name("conversations")} (user_id, name)
  47. VALUES ($1, $2)
  48. RETURNING id, extract(epoch from created_at) as created_at_epoch
  49. """
  50. result = await self.connection_manager.fetchrow_query(
  51. query, [user_id, name]
  52. )
  53. if not result:
  54. raise R2RException(
  55. status_code=500, message="Failed to create conversation."
  56. )
  57. return ConversationResponse(
  58. id=str(result["id"]),
  59. created_at=result["created_at_epoch"],
  60. )
  61. async def verify_conversation_access(
  62. self, conversation_id: UUID, user_id: UUID
  63. ) -> bool:
  64. query = f"""
  65. SELECT 1 FROM {self._get_table_name("conversations")}
  66. WHERE id = $1 AND (user_id IS NULL OR user_id = $2)
  67. """
  68. row = await self.connection_manager.fetchrow_query(
  69. query, [conversation_id, user_id]
  70. )
  71. return row is not None
  72. async def get_conversations_overview(
  73. self,
  74. offset: int,
  75. limit: int,
  76. user_ids: Optional[UUID | List[UUID]] = None,
  77. conversation_ids: Optional[List[UUID]] = None,
  78. ) -> Dict[str, Any]:
  79. # Construct conditions
  80. conditions = []
  81. params = []
  82. param_index = 1
  83. if user_ids is not None:
  84. if isinstance(user_ids, UUID):
  85. conditions.append(f"user_id = ${param_index}")
  86. params.append(user_ids)
  87. param_index += 1
  88. else:
  89. # user_ids is a list of UUIDs
  90. placeholders = ", ".join(
  91. f"${i+param_index}" for i in range(len(user_ids))
  92. )
  93. conditions.append(
  94. f"user_id = ANY(ARRAY[{placeholders}]::uuid[])"
  95. )
  96. params.extend(user_ids)
  97. param_index += len(user_ids)
  98. if conversation_ids:
  99. placeholders = ", ".join(
  100. f"${i+param_index}" for i in range(len(conversation_ids))
  101. )
  102. conditions.append(f"id = ANY(ARRAY[{placeholders}]::uuid[])")
  103. params.extend(conversation_ids)
  104. param_index += len(conversation_ids)
  105. where_clause = ""
  106. if conditions:
  107. where_clause = "WHERE " + " AND ".join(conditions)
  108. limit_clause = ""
  109. if limit != -1:
  110. limit_clause = f"LIMIT ${param_index}"
  111. params.append(limit)
  112. param_index += 1
  113. offset_clause = f"OFFSET ${param_index}"
  114. params.append(offset)
  115. query = f"""
  116. WITH conversation_overview AS (
  117. SELECT id, extract(epoch from created_at) as created_at_epoch, user_id, name
  118. FROM {self._get_table_name("conversations")}
  119. {where_clause}
  120. ),
  121. counted_overview AS (
  122. SELECT *,
  123. COUNT(*) OVER() AS total_entries
  124. FROM conversation_overview
  125. )
  126. SELECT * FROM counted_overview
  127. ORDER BY created_at_epoch DESC
  128. {limit_clause} {offset_clause}
  129. """
  130. results = await self.connection_manager.fetch_query(query, params)
  131. if not results:
  132. return {"results": [], "total_entries": 0}
  133. total_entries = results[0]["total_entries"]
  134. conversations = [
  135. {
  136. "id": str(row["id"]),
  137. "created_at": row["created_at_epoch"],
  138. "user_id": str(row["user_id"]) if row["user_id"] else None,
  139. "name": row["name"] or None,
  140. }
  141. for row in results
  142. ]
  143. return {"results": conversations, "total_entries": total_entries}
  144. async def add_message(
  145. self,
  146. conversation_id: UUID,
  147. content: Message,
  148. parent_id: Optional[UUID] = None,
  149. metadata: Optional[dict] = None,
  150. ) -> MessageResponse:
  151. # Check if conversation exists
  152. conv_check_query = f"""
  153. SELECT 1 FROM {self._get_table_name("conversations")}
  154. WHERE id = $1
  155. """
  156. conv_row = await self.connection_manager.fetchrow_query(
  157. conv_check_query, [conversation_id]
  158. )
  159. if not conv_row:
  160. raise R2RException(
  161. status_code=404,
  162. message=f"Conversation {conversation_id} not found.",
  163. )
  164. # Check parent message if provided
  165. if parent_id:
  166. parent_check_query = f"""
  167. SELECT 1 FROM {self._get_table_name("messages")}
  168. WHERE id = $1 AND conversation_id = $2
  169. """
  170. parent_row = await self.connection_manager.fetchrow_query(
  171. parent_check_query, [parent_id, conversation_id]
  172. )
  173. if not parent_row:
  174. raise R2RException(
  175. status_code=404,
  176. message=f"Parent message {parent_id} not found in conversation {conversation_id}.",
  177. )
  178. message_id = uuid4()
  179. content_str = json.dumps(content.model_dump())
  180. metadata_str = json.dumps(metadata or {})
  181. query = f"""
  182. INSERT INTO {self._get_table_name("messages")}
  183. (id, conversation_id, parent_id, content, created_at, metadata)
  184. VALUES ($1, $2, $3, $4::jsonb, NOW(), $5::jsonb)
  185. RETURNING id
  186. """
  187. inserted = await self.connection_manager.fetchrow_query(
  188. query,
  189. [
  190. message_id,
  191. conversation_id,
  192. parent_id,
  193. content_str,
  194. metadata_str,
  195. ],
  196. )
  197. if not inserted:
  198. raise R2RException(
  199. status_code=500, message="Failed to insert message."
  200. )
  201. return MessageResponse(id=str(message_id), message=content)
  202. async def edit_message(
  203. self,
  204. message_id: UUID,
  205. new_content: str,
  206. additional_metadata: dict = {},
  207. ) -> Dict[str, Any]:
  208. # Get the original message
  209. query = f"""
  210. SELECT conversation_id, parent_id, content, metadata
  211. FROM {self._get_table_name("messages")}
  212. WHERE id = $1
  213. """
  214. row = await self.connection_manager.fetchrow_query(query, [message_id])
  215. if not row:
  216. raise R2RException(
  217. status_code=404, message=f"Message {message_id} not found."
  218. )
  219. old_content = json.loads(row["content"])
  220. old_metadata = json.loads(row["metadata"])
  221. # Update the content
  222. old_message = Message(**old_content)
  223. edited_message = Message(
  224. role=old_message.role,
  225. content=new_content,
  226. name=old_message.name,
  227. function_call=old_message.function_call,
  228. tool_calls=old_message.tool_calls,
  229. )
  230. # Merge metadata and mark edited
  231. new_metadata = {**old_metadata, **additional_metadata, "edited": True}
  232. # Instead of branching, we'll simply replace the message content and metadata:
  233. # NOTE: If you prefer versioning or forking behavior, you'd add a new message.
  234. # For simplicity, we just edit the existing message.
  235. update_query = f"""
  236. UPDATE {self._get_table_name("messages")}
  237. SET content = $1::jsonb, metadata = $2::jsonb, created_at = NOW()
  238. WHERE id = $3
  239. RETURNING id
  240. """
  241. updated = await self.connection_manager.fetchrow_query(
  242. update_query,
  243. [
  244. json.dumps(edited_message.model_dump()),
  245. json.dumps(new_metadata),
  246. message_id,
  247. ],
  248. )
  249. if not updated:
  250. raise R2RException(
  251. status_code=500, message="Failed to update message."
  252. )
  253. return {
  254. "id": str(message_id),
  255. "message": edited_message,
  256. "metadata": new_metadata,
  257. }
  258. async def update_message_metadata(
  259. self, message_id: UUID, metadata: dict
  260. ) -> None:
  261. # Fetch current metadata
  262. query = f"""
  263. SELECT metadata FROM {self._get_table_name("messages")}
  264. WHERE id = $1
  265. """
  266. row = await self.connection_manager.fetchrow_query(query, [message_id])
  267. if not row:
  268. raise R2RException(
  269. status_code=404, message=f"Message {message_id} not found."
  270. )
  271. current_metadata = row["metadata"] or {}
  272. updated_metadata = {**current_metadata, **metadata}
  273. update_query = f"""
  274. UPDATE {self._get_table_name("messages")}
  275. SET metadata = $1::jsonb
  276. WHERE id = $2
  277. """
  278. await self.connection_manager.execute_query(
  279. update_query, [updated_metadata, message_id]
  280. )
  281. async def get_conversation(
  282. self, conversation_id: UUID
  283. ) -> List[MessageResponse]:
  284. # Check conversation
  285. conv_query = f"SELECT extract(epoch from created_at) AS created_at_epoch FROM {self._get_table_name('conversations')} WHERE id = $1"
  286. conv_row = await self.connection_manager.fetchrow_query(
  287. conv_query, [conversation_id]
  288. )
  289. if not conv_row:
  290. raise R2RException(
  291. status_code=404,
  292. message=f"Conversation {conversation_id} not found.",
  293. )
  294. # Retrieve messages in chronological order
  295. # We'll recursively gather messages based on parent_id = NULL as root.
  296. # Since no branching, we simply order by created_at.
  297. msg_query = f"""
  298. SELECT id, content, metadata
  299. FROM {self._get_table_name("messages")}
  300. WHERE conversation_id = $1
  301. ORDER BY created_at ASC
  302. """
  303. results = await self.connection_manager.fetch_query(
  304. msg_query, [conversation_id]
  305. )
  306. print("results = ", results)
  307. return [
  308. MessageResponse(
  309. id=str(row["id"]),
  310. message=Message(**json.loads(row["content"])),
  311. metadata=json.loads(row["metadata"]),
  312. )
  313. for row in results
  314. ]
  315. async def delete_conversation(self, conversation_id: UUID):
  316. # Check if conversation exists
  317. conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
  318. conv_row = await self.connection_manager.fetchrow_query(
  319. conv_query, [conversation_id]
  320. )
  321. if not conv_row:
  322. raise R2RException(
  323. status_code=404,
  324. message=f"Conversation {conversation_id} not found.",
  325. )
  326. # Delete all messages
  327. del_messages_query = f"DELETE FROM {self._get_table_name('messages')} WHERE conversation_id = $1"
  328. await self.connection_manager.execute_query(
  329. del_messages_query, [conversation_id]
  330. )
  331. # Delete conversation
  332. del_conv_query = f"DELETE FROM {self._get_table_name('conversations')} WHERE id = $1"
  333. await self.connection_manager.execute_query(
  334. del_conv_query, [conversation_id]
  335. )