conversations.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. import csv
  2. import json
  3. import tempfile
  4. from typing import IO, Any, Optional
  5. from uuid import UUID, uuid4
  6. from fastapi import HTTPException
  7. from core.base import Handler, Message, R2RException
  8. from shared.api.models.management.responses import (
  9. ConversationResponse,
  10. MessageResponse,
  11. )
  12. from .base import PostgresConnectionManager
  class PostgresConversationsHandler(Handler):
      """Postgres-backed storage for conversations and the messages in them.

      Two tables are managed: ``conversations`` (one row per conversation) and
      ``messages`` (one row per message, referencing its conversation and,
      optionally, a parent message in the same table). Table names are resolved
      through ``self._get_table_name``, which is not defined in this class.
      Message ``content`` and ``metadata`` are stored as JSONB.
      """

      # Ordered "column name -> SELECT expression" maps used by the CSV exports.
      # Key order is the default column order of an export, and the keys are the
      # whitelist of identifiers that may be interpolated into filter SQL.
      _CONVERSATION_EXPORT_COLUMNS: dict[str, str] = {
          "id": "id::text",
          "user_id": "user_id::text",
          "created_at": "to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at",
          "name": "name",
      }
      _MESSAGE_EXPORT_COLUMNS: dict[str, str] = {
          "id": "id::text",
          "conversation_id": "conversation_id::text",
          "parent_id": "parent_id::text",
          "content": "content::text",
          "metadata": "metadata::text",
          "created_at": "to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at",
      }
      # Filter operators accepted by the CSV exports and their SQL spelling.
      _EXPORT_FILTER_OPERATORS: dict[str, str] = {"$eq": "=", "$gt": ">", "$lt": "<"}

      def __init__(
          self, project_name: str, connection_manager: PostgresConnectionManager
      ):
          """Store the project name and the connection manager.

          Args:
              project_name: Used to qualify the project's ``users`` table
                  (``{project_name}.users``) in user-id filters.
              connection_manager: Executes every query issued by this handler.
          """
          self.project_name = project_name
          self.connection_manager = connection_manager

      @staticmethod
      def _decode_json(value: Any) -> Any:
          """Decode a JSON/JSONB column value read from the database.

          Text (``str``/``bytes``) is parsed with ``json.loads``. Any other
          value (``None`` for SQL NULL, or an object the driver already decoded)
          is returned unchanged.
          """
          if isinstance(value, (str, bytes, bytearray)):
              return json.loads(value)
          return value

      async def create_tables(self):
          """Create the conversations and messages tables if they do not exist.

          ``messages.conversation_id`` references ``conversations.id`` and
          ``messages.parent_id`` references ``messages.id``. Both foreign keys
          use the default action, so messages must be deleted before their
          conversation.
          """
          create_conversations_query = f"""
          CREATE TABLE IF NOT EXISTS {self._get_table_name("conversations")} (
              id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
              user_id UUID,
              created_at TIMESTAMPTZ DEFAULT NOW(),
              name TEXT
          );
          """
          create_messages_query = f"""
          CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
              id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
              conversation_id UUID NOT NULL,
              parent_id UUID,
              content JSONB,
              metadata JSONB,
              created_at TIMESTAMPTZ DEFAULT NOW(),
              FOREIGN KEY (conversation_id) REFERENCES {self._get_table_name("conversations")}(id),
              FOREIGN KEY (parent_id) REFERENCES {self._get_table_name("messages")}(id)
          );
          """
          # The conversations table must exist first: messages references it.
          await self.connection_manager.execute_query(create_conversations_query)
          await self.connection_manager.execute_query(create_messages_query)

      async def create_conversation(
          self,
          user_id: Optional[UUID] = None,
          name: Optional[str] = None,
      ) -> ConversationResponse:
          """Insert a new conversation and return it.

          Args:
              user_id: Owner of the conversation, if any.
              name: Optional display name; a falsy name is returned as None.

          Returns:
              The new conversation; ``created_at`` is in epoch seconds.

          Raises:
              HTTPException: 500 if the insert fails for any reason.
          """
          query = f"""
          INSERT INTO {self._get_table_name("conversations")} (user_id, name)
          VALUES ($1, $2)
          RETURNING id, extract(epoch from created_at) as created_at_epoch
          """
          try:
              result = await self.connection_manager.fetchrow_query(
                  query, [user_id, name]
              )
              return ConversationResponse(
                  id=result["id"],
                  created_at=result["created_at_epoch"],
                  user_id=user_id or None,
                  name=name or None,
              )
          except Exception as e:
              raise HTTPException(
                  status_code=500,
                  detail=f"Failed to create conversation: {str(e)}",
              ) from e

      async def get_conversations_overview(
          self,
          offset: int,
          limit: int,
          filter_user_ids: Optional[list[UUID]] = None,
          conversation_ids: Optional[list[UUID]] = None,
      ) -> dict[str, Any]:
          """Return one page of conversations (newest first) plus a total count.

          Args:
              offset: Number of matching conversations to skip.
              limit: Maximum number to return; ``-1`` disables the limit.
              filter_user_ids: If given, keep only conversations whose
                  ``user_id`` is in this list and exists in the project's
                  ``users`` table.
              conversation_ids: If given, keep only these conversation ids.

          Returns:
              ``{"results": [...], "total_entries": int}``. Each result has a
              string ``id``, epoch-seconds ``created_at``, string ``user_id``
              (or None) and ``name`` (or None). ``total_entries`` is counted
              before OFFSET/LIMIT are applied.

              NOTE: when ``offset`` is past the last match, no rows come back,
              so ``total_entries`` is reported as 0.
          """
          conditions = []
          params: list = []
          param_index = 1

          if filter_user_ids:
              conditions.append(
                  f"""
                  c.user_id IN (
                      SELECT id
                      FROM {self.project_name}.users
                      WHERE id = ANY(${param_index})
                  )
                  """
              )
              params.append(filter_user_ids)
              param_index += 1

          if conversation_ids:
              conditions.append(f"c.id = ANY(${param_index})")
              params.append(conversation_ids)
              param_index += 1

          where_clause = (
              "WHERE " + " AND ".join(conditions) if conditions else ""
          )

          # COUNT(*) OVER() is evaluated before OFFSET/LIMIT, so every returned
          # row carries the size of the full filtered result set.
          query = f"""
              WITH conversation_overview AS (
                  SELECT c.id,
                      extract(epoch from c.created_at) as created_at_epoch,
                      c.user_id,
                      c.name
                  FROM {self._get_table_name("conversations")} c
                  {where_clause}
              ),
              counted_overview AS (
                  SELECT *,
                      COUNT(*) OVER() AS total_entries
                  FROM conversation_overview
              )
              SELECT * FROM counted_overview
              ORDER BY created_at_epoch DESC
              OFFSET ${param_index}
          """
          params.append(offset)
          param_index += 1

          if limit != -1:
              query += f" LIMIT ${param_index}"
              params.append(limit)

          results = await self.connection_manager.fetch_query(query, params)

          if not results:
              return {"results": [], "total_entries": 0}

          total_entries = results[0]["total_entries"]
          conversations = [
              {
                  "id": str(row["id"]),
                  "created_at": row["created_at_epoch"],
                  "user_id": str(row["user_id"]) if row["user_id"] else None,
                  "name": row["name"] or None,
              }
              for row in results
          ]
          return {"results": conversations, "total_entries": total_entries}

      async def add_message(
          self,
          conversation_id: UUID,
          content: Message,
          parent_id: Optional[UUID] = None,
          metadata: Optional[dict] = None,
      ) -> MessageResponse:
          """Append a message to a conversation.

          Args:
              conversation_id: Conversation that receives the message.
              content: The message; stored as JSON via ``model_dump()``.
              parent_id: Optional parent message, which must belong to the
                  same conversation.
              metadata: Optional metadata; stored as ``{}`` when omitted.

          Returns:
              The new message, including its generated id and stored metadata.

          Raises:
              R2RException: 404 if the conversation does not exist or the
                  parent is not a message of that conversation; 500 if the
                  insert returns no row.
          """
          conv_check_query = f"""
          SELECT 1 FROM {self._get_table_name("conversations")}
          WHERE id = $1
          """
          conv_row = await self.connection_manager.fetchrow_query(
              conv_check_query, [conversation_id]
          )
          if not conv_row:
              raise R2RException(
                  status_code=404,
                  message=f"Conversation {conversation_id} not found.",
              )

          if parent_id:
              # The parent must live in the same conversation, not merely exist.
              parent_check_query = f"""
              SELECT 1 FROM {self._get_table_name("messages")}
              WHERE id = $1 AND conversation_id = $2
              """
              parent_row = await self.connection_manager.fetchrow_query(
                  parent_check_query, [parent_id, conversation_id]
              )
              if not parent_row:
                  raise R2RException(
                      status_code=404,
                      message=f"Parent message {parent_id} not found in conversation {conversation_id}.",
                  )

          message_id = uuid4()
          stored_metadata = metadata or {}
          content_str = json.dumps(content.model_dump())
          metadata_str = json.dumps(stored_metadata)

          query = f"""
          INSERT INTO {self._get_table_name("messages")}
          (id, conversation_id, parent_id, content, created_at, metadata)
          VALUES ($1, $2, $3, $4::jsonb, NOW(), $5::jsonb)
          RETURNING id
          """
          inserted = await self.connection_manager.fetchrow_query(
              query,
              [
                  message_id,
                  conversation_id,
                  parent_id,
                  content_str,
                  metadata_str,
              ],
          )
          if not inserted:
              raise R2RException(
                  status_code=500, message="Failed to insert message."
              )

          return MessageResponse(
              id=message_id, message=content, metadata=stored_metadata
          )

      async def edit_message(
          self,
          message_id: UUID,
          new_content: str | None = None,
          additional_metadata: dict | None = None,
      ) -> dict[str, Any]:
          """Replace a message's text and/or merge extra metadata into it.

          The role, name, function_call and tool_calls of the original message
          are kept. When ``new_content`` is given, ``metadata["edited"]``
          becomes True; otherwise the previous ``edited`` flag (default False)
          is kept. ``created_at`` is left unchanged.

          Returns:
              ``{"id": str, "message": Message, "metadata": dict}``.

          Raises:
              R2RException: 404 if the message does not exist; 500 if the
                  update returns no row.
          """
          query = f"""
          SELECT conversation_id, parent_id, content, metadata, created_at
          FROM {self._get_table_name("messages")}
          WHERE id = $1
          """
          row = await self.connection_manager.fetchrow_query(query, [message_id])
          if not row:
              raise R2RException(
                  status_code=404,
                  message=f"Message {message_id} not found.",
              )

          old_content = self._decode_json(row["content"])
          # A NULL (or JSON null) metadata column is treated as empty metadata.
          old_metadata = self._decode_json(row["metadata"]) or {}

          if new_content is not None:
              old_message = Message(**old_content)
              edited_message = Message(
                  role=old_message.role,
                  content=new_content,
                  name=old_message.name,
                  function_call=old_message.function_call,
                  tool_calls=old_message.tool_calls,
              )
              content_to_save = edited_message.model_dump()
          else:
              content_to_save = old_content

          additional_metadata = additional_metadata or {}
          new_metadata = {
              **old_metadata,
              **additional_metadata,
              "edited": (
                  True
                  if new_content is not None
                  else old_metadata.get("edited", False)
              ),
          }

          # Write created_at back explicitly so the edit keeps the original
          # timestamp.
          update_query = f"""
          UPDATE {self._get_table_name("messages")}
          SET content = $1::jsonb,
              metadata = $2::jsonb,
              created_at = $3
          WHERE id = $4
          RETURNING id
          """
          updated = await self.connection_manager.fetchrow_query(
              update_query,
              [
                  json.dumps(content_to_save),
                  json.dumps(new_metadata),
                  row["created_at"],
                  message_id,
              ],
          )
          if not updated:
              raise R2RException(
                  status_code=500, message="Failed to update message."
              )

          return {
              "id": str(message_id),
              "message": (
                  Message(**content_to_save)
                  if isinstance(content_to_save, dict)
                  else content_to_save
              ),
              "metadata": new_metadata,
          }

      async def update_message_metadata(
          self, message_id: UUID, metadata: dict
      ) -> None:
          """Shallow-merge ``metadata`` into a message's stored metadata.

          Keys in ``metadata`` overwrite existing keys. A NULL or JSON-null
          stored value is treated as ``{}``.

          Raises:
              R2RException: 404 if the message does not exist.
          """
          query = f"""
          SELECT metadata FROM {self._get_table_name("messages")}
          WHERE id = $1
          """
          row = await self.connection_manager.fetchrow_query(query, [message_id])
          if not row:
              raise R2RException(
                  status_code=404, message=f"Message {message_id} not found."
              )

          current_metadata = self._decode_json(row["metadata"]) or {}
          updated_metadata = {**current_metadata, **metadata}

          update_query = f"""
          UPDATE {self._get_table_name("messages")}
          SET metadata = $1::jsonb
          WHERE id = $2
          """
          await self.connection_manager.execute_query(
              update_query, [json.dumps(updated_metadata), message_id]
          )

      async def get_conversation(
          self,
          conversation_id: UUID,
          filter_user_ids: Optional[list[UUID]] = None,
      ) -> list[MessageResponse]:
          """Return all messages of a conversation, oldest first.

          Args:
              conversation_id: Conversation to read.
              filter_user_ids: If given, the conversation must belong to one of
                  these users (who must exist in the project's ``users``
                  table); otherwise it is reported as not found.

          Raises:
              R2RException: 404 if the conversation does not exist or fails
                  the user filter.
          """
          conditions = ["c.id = $1"]
          params: list = [conversation_id]

          if filter_user_ids:
              param_index = 2
              conditions.append(
                  f"""
                  c.user_id IN (
                      SELECT id
                      FROM {self.project_name}.users
                      WHERE id = ANY(${param_index})
                  )
                  """
              )
              params.append(filter_user_ids)

          query = f"""
              SELECT c.id, extract(epoch from c.created_at) AS created_at_epoch
              FROM {self._get_table_name('conversations')} c
              WHERE {' AND '.join(conditions)}
          """
          conv_row = await self.connection_manager.fetchrow_query(query, params)
          if not conv_row:
              raise R2RException(
                  status_code=404,
                  message=f"Conversation {conversation_id} not found.",
              )

          msg_query = f"""
              SELECT id, content, metadata
              FROM {self._get_table_name("messages")}
              WHERE conversation_id = $1
              ORDER BY created_at ASC
          """
          results = await self.connection_manager.fetch_query(
              msg_query, [conversation_id]
          )

          return [
              MessageResponse(
                  id=row["id"],
                  message=Message(**self._decode_json(row["content"])),
                  # NULL / JSON-null metadata is surfaced as an empty dict.
                  metadata=self._decode_json(row["metadata"]) or {},
              )
              for row in results
          ]

      async def update_conversation(
          self, conversation_id: UUID, name: str
      ) -> ConversationResponse:
          """Rename a conversation and return its updated record.

          Raises:
              R2RException: 404 if no conversation has ``conversation_id``.
              HTTPException: 500 for any other failure.
          """
          try:
              # A single UPDATE ... RETURNING both applies the change and shows
              # whether the row existed, so no separate existence query is needed.
              update_query = f"""
                  UPDATE {self._get_table_name('conversations')}
                  SET name = $1 WHERE id = $2
                  RETURNING user_id, extract(epoch from created_at) as created_at_epoch
              """
              updated_row = await self.connection_manager.fetchrow_query(
                  update_query, [name, conversation_id]
              )
              if not updated_row:
                  raise R2RException(
                      status_code=404,
                      message=f"Conversation {conversation_id} not found.",
                  )
              return ConversationResponse(
                  id=conversation_id,
                  created_at=updated_row["created_at_epoch"],
                  user_id=updated_row["user_id"] or None,
                  name=name,
              )
          except R2RException:
              # Propagate the 404 unchanged instead of rewrapping it as a 500.
              raise
          except Exception as e:
              raise HTTPException(
                  status_code=500,
                  detail=f"Failed to update conversation: {str(e)}",
              ) from e

      async def delete_conversation(
          self,
          conversation_id: UUID,
          filter_user_ids: Optional[list[UUID]] = None,
      ) -> None:
          """Delete a conversation together with all of its messages.

          Args:
              conversation_id: Conversation to delete.
              filter_user_ids: If given, the conversation must belong to one of
                  these users; otherwise it is reported as not found.

          Raises:
              R2RException: 404 if the conversation does not exist or fails
                  the user filter.
          """
          conditions = ["c.id = $1"]
          params: list = [conversation_id]

          if filter_user_ids:
              param_index = 2
              conditions.append(
                  f"""
                  c.user_id IN (
                      SELECT id
                      FROM {self.project_name}.users
                      WHERE id = ANY(${param_index})
                  )
                  """
              )
              params.append(filter_user_ids)

          conv_query = f"""
              SELECT 1
              FROM {self._get_table_name('conversations')} c
              WHERE {' AND '.join(conditions)}
          """
          conv_row = await self.connection_manager.fetchrow_query(
              conv_query, params
          )
          if not conv_row:
              raise R2RException(
                  status_code=404,
                  message=f"Conversation {conversation_id} not found.",
              )

          messages_table = self._get_table_name("messages")
          conversations_table = self._get_table_name("conversations")
          # Messages reference the conversation through a foreign key, so they
          # go first. Both deletes share one transaction so a failure cannot
          # leave a conversation whose messages are already gone.
          async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
              async with conn.transaction():
                  await conn.execute(
                      f"DELETE FROM {messages_table} WHERE conversation_id = $1",
                      conversation_id,
                  )
                  await conn.execute(
                      f"DELETE FROM {conversations_table} WHERE id = $1",
                      conversation_id,
                  )

      @classmethod
      def _build_export_conditions(
          cls, filters: Optional[dict], allowed_columns: dict[str, str]
      ) -> tuple[list[str], list[Any]]:
          """Translate export ``filters`` into SQL conditions and bind params.

          ``filters`` maps a column name either to a plain value (equality) or
          to a dict of ``{"$eq" | "$gt" | "$lt": value}``. Column names missing
          from ``allowed_columns`` and unknown operators are skipped, so only
          whitelisted identifiers reach the SQL text. Values are always bound
          as ``$n`` parameters, numbered from ``$1``.
          """
          conditions: list[str] = []
          params: list[Any] = []
          for field, value in (filters or {}).items():
              if field not in allowed_columns:
                  continue
              operations = (
                  value.items() if isinstance(value, dict) else [("$eq", value)]
              )
              for op, operand in operations:
                  sql_op = cls._EXPORT_FILTER_OPERATORS.get(op)
                  if sql_op is None:
                      continue
                  params.append(operand)
                  conditions.append(f"{field} {sql_op} ${len(params)}")
          return conditions, params

      async def _export_to_csv(
          self,
          table: str,
          column_sql: dict[str, str],
          columns: Optional[list[str]],
          filters: Optional[dict],
          include_header: bool,
      ) -> tuple[str, IO]:
          """Stream rows of ``table`` into a temporary CSV file, newest first.

          The SELECT list is built from ``columns`` in the requested order (or
          from ``column_sql``'s key order when ``columns`` is empty), so the
          header row always matches the data rows.

          Returns:
              ``(path, file)`` for a ``NamedTemporaryFile`` created with
              ``delete=True``; the file is removed when the caller closes it.

          Raises:
              ValueError: If ``columns`` names a column not in ``column_sql``.
              HTTPException: 500 if querying or writing fails.
          """
          if not columns:
              columns = list(column_sql)
          elif invalid_cols := set(columns).difference(column_sql):
              raise ValueError(f"Invalid columns: {invalid_cols}")

          select_list = ", ".join(column_sql[col] for col in columns)
          select_stmt = f"SELECT {select_list} FROM {self._get_table_name(table)}"

          conditions, params = self._build_export_conditions(filters, column_sql)
          if conditions:
              select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
          # When created_at is selected, ORDER BY resolves to its formatted text
          # alias, whose lexical order matches chronological order.
          select_stmt = f"{select_stmt} ORDER BY created_at DESC"

          temp_file = None
          try:
              # newline="" as required by the csv module, so the writer alone
              # controls line endings (quoted embedded newlines stay intact).
              temp_file = tempfile.NamedTemporaryFile(
                  mode="w",
                  delete=True,
                  suffix=".csv",
                  newline="",
                  encoding="utf-8",
              )
              writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
              if include_header:
                  writer.writerow(columns)

              chunk_size = 1000
              async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                  # The cursor is consumed inside a transaction; asyncpg (which
                  # this API appears to be) requires that for cursors.
                  async with conn.transaction():
                      cursor = await conn.cursor(select_stmt, *params)
                      while rows := await cursor.fetch(chunk_size):
                          writer.writerows(rows)

              temp_file.flush()
              return temp_file.name, temp_file

          except Exception as e:
              if temp_file:
                  temp_file.close()
              raise HTTPException(
                  status_code=500,
                  detail=f"Failed to export data: {str(e)}",
              ) from e

      async def export_conversations_to_csv(
          self,
          columns: Optional[list[str]] = None,
          filters: Optional[dict] = None,
          include_header: bool = True,
      ) -> tuple[str, IO]:
          """
          Creates a CSV file from the PostgreSQL data and returns the path to the temp file.

          Exportable columns: id, user_id, created_at, name (the default
          order). ``filters`` accepts plain values or ``$eq``/``$gt``/``$lt``
          operator dicts keyed by those column names.

          Returns:
              ``(path, file)``; closing ``file`` deletes the CSV.

          Raises:
              ValueError: If ``columns`` contains an unknown column.
              HTTPException: 500 if the export fails.
          """
          return await self._export_to_csv(
              "conversations",
              self._CONVERSATION_EXPORT_COLUMNS,
              columns,
              filters,
              include_header,
          )

      async def export_messages_to_csv(
          self,
          columns: Optional[list[str]] = None,
          filters: Optional[dict] = None,
          include_header: bool = True,
      ) -> tuple[str, IO]:
          """
          Creates a CSV file from the PostgreSQL data and returns the path to the temp file.

          Exportable columns: id, conversation_id, parent_id, content,
          metadata, created_at (the default order). ``content`` and
          ``metadata`` are exported as their JSON text. ``filters`` accepts
          plain values or ``$eq``/``$gt``/``$lt`` operator dicts keyed by those
          column names.

          Returns:
              ``(path, file)``; closing ``file`` deletes the CSV.

          Raises:
              ValueError: If ``columns`` contains an unknown column.
              HTTPException: 500 if the export fails.
          """
          return await self._export_to_csv(
              "messages",
              self._MESSAGE_EXPORT_COLUMNS,
              columns,
              filters,
              include_header,
          )
  577. status_code=500,
  578. detail=f"Failed to export data: {str(e)}",
  579. ) from e