conversations.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. import json
  2. from typing import Any, Optional
  3. from uuid import UUID, uuid4
  4. from fastapi import HTTPException
  5. from core.base import Handler, Message, R2RException
  6. from shared.api.models.management.responses import (
  7. ConversationResponse,
  8. MessageResponse,
  9. )
  10. from .base import PostgresConnectionManager
  11. class PostgresConversationsHandler(Handler):
  def __init__(
      self, project_name: str, connection_manager: PostgresConnectionManager
  ):
      """Bind this handler to a project and a Postgres connection manager.

      Args:
          project_name: Kept on the instance; this handler uses it as the
              schema prefix of the ``users`` table in user-filter subqueries.
          connection_manager: Object through which every query issued by
              this handler is executed.
      """
      # Plain attribute wiring only -- no database access happens here.
      self.connection_manager = connection_manager
      self.project_name = project_name
  async def create_tables(self):
      """Create the ``conversations`` and ``messages`` tables if they are missing.

      ``messages.conversation_id`` references ``conversations.id`` and
      ``messages.parent_id`` references ``messages.id``, so the conversations
      DDL has to run first. ``IF NOT EXISTS`` makes each statement a no-op for
      a table that already exists (an existing definition is not altered).
      """
      conversations_table = self._get_table_name("conversations")
      messages_table = self._get_table_name("messages")

      # NOTE(review): uuid_generate_v4() is provided by the uuid-ossp
      # extension; assumed to be installed elsewhere -- confirm.
      ddl_in_dependency_order = (
          f"""
          CREATE TABLE IF NOT EXISTS {conversations_table} (
              id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
              user_id UUID,
              created_at TIMESTAMPTZ DEFAULT NOW(),
              name TEXT
          );
          """,
          f"""
          CREATE TABLE IF NOT EXISTS {messages_table} (
              id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
              conversation_id UUID NOT NULL,
              parent_id UUID,
              content JSONB,
              metadata JSONB,
              created_at TIMESTAMPTZ DEFAULT NOW(),
              FOREIGN KEY (conversation_id) REFERENCES {conversations_table}(id),
              FOREIGN KEY (parent_id) REFERENCES {messages_table}(id)
          );
          """,
      )
      for ddl in ddl_in_dependency_order:
          await self.connection_manager.execute_query(ddl)
  async def create_conversation(
      self,
      user_id: Optional[UUID] = None,
      name: Optional[str] = None,
  ) -> ConversationResponse:
      """Insert a new conversation row and describe it.

      Args:
          user_id: Optional owner written to ``conversations.user_id``.
          name: Optional display name. A falsy value (e.g. ``""``) is reported
              as ``None`` in the response; the stored value is left as given.

      Returns:
          ConversationResponse with the database-generated id and
          ``created_at`` expressed as epoch seconds.

      Raises:
          HTTPException: status 500 wrapping any failure while inserting the
              row or building the response.
      """
      insert_sql = f"""
          INSERT INTO {self._get_table_name("conversations")} (user_id, name)
          VALUES ($1, $2)
          RETURNING id, extract(epoch from created_at) as created_at_epoch
      """
      try:
          row = await self.connection_manager.fetchrow_query(
              insert_sql, [user_id, name]
          )
          return ConversationResponse(
              id=row["id"],
              created_at=row["created_at_epoch"],
              user_id=user_id if user_id else None,
              name=name if name else None,
          )
      except Exception as exc:
          raise HTTPException(
              status_code=500,
              detail=f"Failed to create conversation: {str(exc)}",
          ) from exc
  async def get_conversations_overview(
      self,
      offset: int,
      limit: int,
      filter_user_ids: Optional[list[UUID]] = None,
      conversation_ids: Optional[list[UUID]] = None,
  ) -> dict[str, Any]:
      """List conversations newest-first with pagination and optional filters.

      Args:
          offset: Number of matching rows to skip.
          limit: Maximum number of rows to return; ``-1`` means no limit.
          filter_user_ids: If non-empty, keep only conversations whose
              ``user_id`` is in this list and exists in
              ``{project_name}.users``.
          conversation_ids: If non-empty, keep only these conversation ids.

      Returns:
          ``{"results": [...], "total_entries": int}``. Each result has a
          string ``id``, epoch-seconds ``created_at``, string-or-None
          ``user_id`` and ``name``. ``total_entries`` counts every matching
          conversation regardless of ``offset``/``limit`` -- including when
          the requested page lies past the end.
      """
      conditions: list[str] = []
      params: list = []
      param_index = 1

      if filter_user_ids:
          conditions.append(
              f"""
              c.user_id IN (
                  SELECT id
                  FROM {self.project_name}.users
                  WHERE id = ANY(${param_index})
              )
              """
          )
          params.append(filter_user_ids)
          param_index += 1

      if conversation_ids:
          conditions.append(f"c.id = ANY(${param_index})")
          params.append(conversation_ids)
          param_index += 1

      where_clause = (
          "WHERE " + " AND ".join(conditions) if conditions else ""
      )
      # Filter-only parameters, snapshotted before OFFSET/LIMIT are appended,
      # so the fallback count query below can reuse the same placeholders.
      filter_params = list(params)

      # ``id`` breaks ties between rows sharing a created_at (NOW() is fixed
      # for a whole transaction), keeping page boundaries deterministic.
      query = f"""
          WITH conversation_overview AS (
              SELECT c.id,
                  extract(epoch from c.created_at) as created_at_epoch,
                  c.user_id,
                  c.name
              FROM {self._get_table_name("conversations")} c
              {where_clause}
          ),
          counted_overview AS (
              SELECT *,
                  COUNT(*) OVER() AS total_entries
              FROM conversation_overview
          )
          SELECT * FROM counted_overview
          ORDER BY created_at_epoch DESC, id DESC
          OFFSET ${param_index}
      """
      params.append(offset)
      param_index += 1

      if limit != -1:
          query += f" LIMIT ${param_index}"
          params.append(limit)

      results = await self.connection_manager.fetch_query(query, params)

      if not results:
          total_entries = 0
          if offset > 0:
              # The windowed count only travels on returned rows; a page past
              # the end returns none, so count the matches explicitly.
              count_query = f"""
                  SELECT COUNT(*) AS total_entries
                  FROM {self._get_table_name("conversations")} c
                  {where_clause}
              """
              count_row = await self.connection_manager.fetchrow_query(
                  count_query, filter_params
              )
              if count_row:
                  total_entries = count_row["total_entries"]
          return {"results": [], "total_entries": total_entries}

      total_entries = results[0]["total_entries"]
      conversations = [
          {
              "id": str(row["id"]),
              "created_at": row["created_at_epoch"],
              "user_id": str(row["user_id"]) if row["user_id"] else None,
              "name": row["name"] or None,
          }
          for row in results
      ]
      return {"results": conversations, "total_entries": total_entries}
  async def add_message(
      self,
      conversation_id: UUID,
      content: Message,
      parent_id: Optional[UUID] = None,
      metadata: Optional[dict] = None,
  ) -> MessageResponse:
      """Append a message to an existing conversation.

      Args:
          conversation_id: Conversation that must already exist.
          content: Message serialized with ``model_dump()`` into the JSONB
              ``content`` column.
          parent_id: Optional parent message; it must belong to the same
              conversation.
          metadata: Optional dict stored in the JSONB ``metadata`` column
              (``{}`` when omitted).

      Returns:
          MessageResponse with the new message id, the message, and the
          metadata exactly as stored.

      Raises:
          R2RException: 404 if the conversation or the parent message is
              missing; 500 if the insert returns no row.
      """
      # Check if conversation exists
      conv_check_query = f"""
          SELECT 1 FROM {self._get_table_name("conversations")}
          WHERE id = $1
      """
      conv_row = await self.connection_manager.fetchrow_query(
          conv_check_query, [conversation_id]
      )
      if not conv_row:
          raise R2RException(
              status_code=404,
              message=f"Conversation {conversation_id} not found.",
          )

      # The parent must live in the same conversation, not merely exist.
      if parent_id:
          parent_check_query = f"""
              SELECT 1 FROM {self._get_table_name("messages")}
              WHERE id = $1 AND conversation_id = $2
          """
          parent_row = await self.connection_manager.fetchrow_query(
              parent_check_query, [parent_id, conversation_id]
          )
          if not parent_row:
              raise R2RException(
                  status_code=404,
                  message=f"Parent message {parent_id} not found in conversation {conversation_id}.",
              )

      message_id = uuid4()
      # Normalize once so the stored JSON and the returned response agree.
      stored_metadata = metadata or {}
      content_str = json.dumps(content.model_dump())
      metadata_str = json.dumps(stored_metadata)

      query = f"""
          INSERT INTO {self._get_table_name("messages")}
          (id, conversation_id, parent_id, content, created_at, metadata)
          VALUES ($1, $2, $3, $4::jsonb, NOW(), $5::jsonb)
          RETURNING id
      """
      inserted = await self.connection_manager.fetchrow_query(
          query,
          [
              message_id,
              conversation_id,
              parent_id,
              content_str,
              metadata_str,
          ],
      )
      if not inserted:
          raise R2RException(
              status_code=500, message="Failed to insert message."
          )

      return MessageResponse(
          id=message_id, message=content, metadata=stored_metadata
      )
  async def edit_message(
      self,
      message_id: UUID,
      new_content: str | None = None,
      additional_metadata: dict | None = None,
  ) -> dict[str, Any]:
      """Replace a message's text and/or merge extra metadata into it.

      Args:
          message_id: Message to edit.
          new_content: Replacement text. When given, a new Message is built
              keeping the old role, name, function_call and tool_calls, and
              the metadata flag ``edited`` becomes True.
          additional_metadata: Keys shallow-merged over the stored metadata.
              The ``edited`` key is always recomputed and cannot be set here.

      Returns:
          ``{"id": str, "message": Message, "metadata": dict}`` describing the
          saved state.

      Raises:
          R2RException: 404 if the message does not exist; 500 if the update
              returns no row.
      """
      # Get the original message
      query = f"""
          SELECT conversation_id, parent_id, content, metadata, created_at
          FROM {self._get_table_name("messages")}
          WHERE id = $1
      """
      row = await self.connection_manager.fetchrow_query(query, [message_id])
      if not row:
          raise R2RException(
              status_code=404,
              message=f"Message {message_id} not found.",
          )

      old_content = json.loads(row["content"])
      # ``messages.metadata`` is a nullable JSONB column: map SQL NULL to "{}"
      # (json.loads(None) raises TypeError) and a stored JSON null to {} so
      # the unpacking and .get() below always operate on a dict.
      old_metadata = json.loads(row["metadata"] or "{}") or {}

      if new_content is not None:
          old_message = Message(**old_content)
          edited_message = Message(
              role=old_message.role,
              content=new_content,
              name=old_message.name,
              function_call=old_message.function_call,
              tool_calls=old_message.tool_calls,
          )
          content_to_save = edited_message.model_dump()
      else:
          content_to_save = old_content

      additional_metadata = additional_metadata or {}
      new_metadata = {
          **old_metadata,
          **additional_metadata,
          # Listed last so additional_metadata cannot override it.
          "edited": (
              True
              if new_content is not None
              else old_metadata.get("edited", False)
          ),
      }

      # created_at is written back with the value read above, so the edit
      # keeps the message's place in created_at-ordered listings.
      update_query = f"""
          UPDATE {self._get_table_name("messages")}
          SET content = $1::jsonb,
              metadata = $2::jsonb,
              created_at = $3
          WHERE id = $4
          RETURNING id
      """
      updated = await self.connection_manager.fetchrow_query(
          update_query,
          [
              json.dumps(content_to_save),
              json.dumps(new_metadata),
              row["created_at"],
              message_id,
          ],
      )
      if not updated:
          raise R2RException(
              status_code=500, message="Failed to update message."
          )

      return {
          "id": str(message_id),
          "message": (
              Message(**content_to_save)
              if isinstance(content_to_save, dict)
              else content_to_save
          ),
          "metadata": new_metadata,
      }
  async def update_message_metadata(
      self, message_id: UUID, metadata: dict
  ) -> None:
      """Shallow-merge ``metadata`` into a message's stored metadata.

      Keys in ``metadata`` overwrite existing keys of the same name; all other
      existing keys are kept. A missing (SQL NULL) or JSON-null stored value
      is treated as an empty dict.

      Raises:
          R2RException: 404 if the message does not exist.
      """
      # Fetch current metadata
      query = f"""
          SELECT metadata FROM {self._get_table_name("messages")}
          WHERE id = $1
      """
      row = await self.connection_manager.fetchrow_query(query, [message_id])
      if not row:
          raise R2RException(
              status_code=404, message=f"Message {message_id} not found."
          )

      # ``messages.metadata`` is nullable. json.loads(None) raises TypeError
      # before an ``or {}`` fallback could apply, so map SQL NULL to "{}"
      # first; the trailing ``or {}`` still covers a stored JSON null.
      current_metadata = json.loads(row["metadata"] or "{}") or {}
      updated_metadata = {**current_metadata, **metadata}

      # NOTE(review): read-modify-write across two statements; concurrent
      # updates to the same message may overwrite each other.
      update_query = f"""
          UPDATE {self._get_table_name("messages")}
          SET metadata = $1::jsonb
          WHERE id = $2
      """
      await self.connection_manager.execute_query(
          update_query, [json.dumps(updated_metadata), message_id]
      )
  async def get_conversation(
      self,
      conversation_id: UUID,
      filter_user_ids: Optional[list[UUID]] = None,
  ) -> list[MessageResponse]:
      """Return every message of a conversation, oldest first.

      Args:
          conversation_id: Conversation to read.
          filter_user_ids: If non-empty, the conversation must belong to one
              of these users (who must exist in ``{project_name}.users``);
              otherwise it is treated as not found.

      Returns:
          MessageResponse objects ordered by ``created_at`` ascending. A
          message whose stored metadata is SQL NULL or JSON null is reported
          with ``{}``.

      Raises:
          R2RException: 404 if no visible conversation matches.
      """
      conditions = ["c.id = $1"]
      params: list = [conversation_id]
      if filter_user_ids:
          param_index = 2
          conditions.append(
              f"""
              c.user_id IN (
                  SELECT id
                  FROM {self.project_name}.users
                  WHERE id = ANY(${param_index})
              )
              """
          )
          params.append(filter_user_ids)

      query = f"""
          SELECT c.id, extract(epoch from c.created_at) AS created_at_epoch
          FROM {self._get_table_name('conversations')} c
          WHERE {' AND '.join(conditions)}
      """
      conv_row = await self.connection_manager.fetchrow_query(query, params)
      if not conv_row:
          raise R2RException(
              status_code=404,
              message=f"Conversation {conversation_id} not found.",
          )

      # Retrieve messages in chronological order
      msg_query = f"""
          SELECT id, content, metadata
          FROM {self._get_table_name("messages")}
          WHERE conversation_id = $1
          ORDER BY created_at ASC
      """
      results = await self.connection_manager.fetch_query(
          msg_query, [conversation_id]
      )
      return [
          MessageResponse(
              id=row["id"],
              message=Message(**json.loads(row["content"])),
              # Nullable JSONB column: SQL NULL -> "{}" (json.loads(None)
              # would raise TypeError), JSON null -> {}.
              metadata=json.loads(row["metadata"] or "{}") or {},
          )
          for row in results
      ]
  async def update_conversation(
      self, conversation_id: UUID, name: str
  ) -> ConversationResponse:
      """Rename a conversation.

      Args:
          conversation_id: Conversation to rename.
          name: New name, stored as given.

      Returns:
          ConversationResponse with the conversation's owner, its creation
          time in epoch seconds, and the new name.

      Raises:
          R2RException: 404 if the conversation does not exist.
          HTTPException: status 500 wrapping any other failure.
      """
      try:
          # RETURNING doubles as the existence check: no row means no such
          # conversation, so no separate SELECT (and no check/update race).
          update_query = f"""
              UPDATE {self._get_table_name('conversations')}
              SET name = $1 WHERE id = $2
              RETURNING user_id, extract(epoch from created_at) as created_at_epoch
          """
          updated_row = await self.connection_manager.fetchrow_query(
              update_query, [name, conversation_id]
          )
          if not updated_row:
              raise R2RException(
                  status_code=404,
                  message=f"Conversation {conversation_id} not found.",
              )
          return ConversationResponse(
              id=conversation_id,
              created_at=updated_row["created_at_epoch"],
              user_id=updated_row["user_id"] or None,
              name=name,
          )
      except R2RException:
          # Propagate the 404 unchanged; without this clause the generic
          # handler below would convert it into a 500.
          raise
      except Exception as e:
          raise HTTPException(
              status_code=500,
              detail=f"Failed to update conversation: {str(e)}",
          ) from e
  async def delete_conversation(
      self,
      conversation_id: UUID,
      filter_user_ids: Optional[list[UUID]] = None,
  ) -> None:
      """Delete a conversation together with all of its messages.

      Args:
          conversation_id: Conversation to remove.
          filter_user_ids: If non-empty, the conversation must belong to one
              of these users (who must exist in ``{project_name}.users``);
              otherwise it is treated as not found.

      Raises:
          R2RException: 404 if no visible conversation matches.
      """
      where_parts = ["c.id = $1"]
      query_args: list = [conversation_id]
      if filter_user_ids:
          where_parts.append(
              f"""
              c.user_id IN (
                  SELECT id
                  FROM {self.project_name}.users
                  WHERE id = ANY($2)
              )
              """
          )
          query_args.append(filter_user_ids)

      conversations_table = self._get_table_name('conversations')
      visibility_check = f"""
          SELECT 1
          FROM {conversations_table} c
          WHERE {' AND '.join(where_parts)}
      """
      visible = await self.connection_manager.fetchrow_query(
          visibility_check, query_args
      )
      if not visible:
          raise R2RException(
              status_code=404,
              message=f"Conversation {conversation_id} not found.",
          )

      # Children before parent: messages.conversation_id references the
      # conversations table without ON DELETE CASCADE.
      # NOTE(review): the two deletes are separate calls; unless the caller
      # wraps them in a transaction, a failure in between can leave an empty
      # conversation behind.
      delete_statements = (
          f"DELETE FROM {self._get_table_name('messages')} WHERE conversation_id = $1",
          f"DELETE FROM {conversations_table} WHERE id = $1",
      )
      for statement in delete_statements:
          await self.connection_manager.execute_query(
              statement, [conversation_id]
          )