123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376 |
- import json
- from typing import Any, Dict, List, Optional
- from uuid import UUID, uuid4
- from core.base import Handler, Message, R2RException
- from shared.api.models.management.responses import (
- ConversationResponse,
- MessageResponse,
- )
- from .base import PostgresConnectionManager
- class PostgresConversationsHandler(Handler):
- def __init__(
- self, project_name: str, connection_manager: PostgresConnectionManager
- ):
- self.project_name = project_name
- self.connection_manager = connection_manager
- async def create_tables(self):
- # Ensure the uuid_generate_v4() extension is available
- # Depending on your environment, you may need a separate call:
- # await self.connection_manager.execute_query("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";")
- create_conversations_query = f"""
- CREATE TABLE IF NOT EXISTS {self._get_table_name("conversations")} (
- id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
- user_id UUID,
- created_at TIMESTAMPTZ DEFAULT NOW(),
- name TEXT
- );
- """
- create_messages_query = f"""
- CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
- id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
- conversation_id UUID NOT NULL,
- parent_id UUID,
- content JSONB,
- metadata JSONB,
- created_at TIMESTAMPTZ DEFAULT NOW(),
- FOREIGN KEY (conversation_id) REFERENCES {self._get_table_name("conversations")}(id),
- FOREIGN KEY (parent_id) REFERENCES {self._get_table_name("messages")}(id)
- );
- """
- await self.connection_manager.execute_query(create_conversations_query)
- await self.connection_manager.execute_query(create_messages_query)
- async def create_conversation(
- self, user_id: Optional[UUID] = None, name: Optional[str] = None
- ) -> ConversationResponse:
- query = f"""
- INSERT INTO {self._get_table_name("conversations")} (user_id, name)
- VALUES ($1, $2)
- RETURNING id, extract(epoch from created_at) as created_at_epoch
- """
- result = await self.connection_manager.fetchrow_query(
- query, [user_id, name]
- )
- if not result:
- raise R2RException(
- status_code=500, message="Failed to create conversation."
- )
- return ConversationResponse(
- id=str(result["id"]),
- created_at=result["created_at_epoch"],
- )
- async def verify_conversation_access(
- self, conversation_id: UUID, user_id: UUID
- ) -> bool:
- query = f"""
- SELECT 1 FROM {self._get_table_name("conversations")}
- WHERE id = $1 AND (user_id IS NULL OR user_id = $2)
- """
- row = await self.connection_manager.fetchrow_query(
- query, [conversation_id, user_id]
- )
- return row is not None
- async def get_conversations_overview(
- self,
- offset: int,
- limit: int,
- user_ids: Optional[UUID | List[UUID]] = None,
- conversation_ids: Optional[List[UUID]] = None,
- ) -> Dict[str, Any]:
- # Construct conditions
- conditions = []
- params = []
- param_index = 1
- if user_ids is not None:
- if isinstance(user_ids, UUID):
- conditions.append(f"user_id = ${param_index}")
- params.append(user_ids)
- param_index += 1
- else:
- # user_ids is a list of UUIDs
- placeholders = ", ".join(
- f"${i+param_index}" for i in range(len(user_ids))
- )
- conditions.append(
- f"user_id = ANY(ARRAY[{placeholders}]::uuid[])"
- )
- params.extend(user_ids)
- param_index += len(user_ids)
- if conversation_ids:
- placeholders = ", ".join(
- f"${i+param_index}" for i in range(len(conversation_ids))
- )
- conditions.append(f"id = ANY(ARRAY[{placeholders}]::uuid[])")
- params.extend(conversation_ids)
- param_index += len(conversation_ids)
- where_clause = ""
- if conditions:
- where_clause = "WHERE " + " AND ".join(conditions)
- limit_clause = ""
- if limit != -1:
- limit_clause = f"LIMIT ${param_index}"
- params.append(limit)
- param_index += 1
- offset_clause = f"OFFSET ${param_index}"
- params.append(offset)
- query = f"""
- WITH conversation_overview AS (
- SELECT id, extract(epoch from created_at) as created_at_epoch, user_id, name
- FROM {self._get_table_name("conversations")}
- {where_clause}
- ),
- counted_overview AS (
- SELECT *,
- COUNT(*) OVER() AS total_entries
- FROM conversation_overview
- )
- SELECT * FROM counted_overview
- ORDER BY created_at_epoch DESC
- {limit_clause} {offset_clause}
- """
- results = await self.connection_manager.fetch_query(query, params)
- if not results:
- return {"results": [], "total_entries": 0}
- total_entries = results[0]["total_entries"]
- conversations = [
- {
- "id": str(row["id"]),
- "created_at": row["created_at_epoch"],
- "user_id": str(row["user_id"]) if row["user_id"] else None,
- "name": row["name"] or None,
- }
- for row in results
- ]
- return {"results": conversations, "total_entries": total_entries}
- async def add_message(
- self,
- conversation_id: UUID,
- content: Message,
- parent_id: Optional[UUID] = None,
- metadata: Optional[dict] = None,
- ) -> MessageResponse:
- # Check if conversation exists
- conv_check_query = f"""
- SELECT 1 FROM {self._get_table_name("conversations")}
- WHERE id = $1
- """
- conv_row = await self.connection_manager.fetchrow_query(
- conv_check_query, [conversation_id]
- )
- if not conv_row:
- raise R2RException(
- status_code=404,
- message=f"Conversation {conversation_id} not found.",
- )
- # Check parent message if provided
- if parent_id:
- parent_check_query = f"""
- SELECT 1 FROM {self._get_table_name("messages")}
- WHERE id = $1 AND conversation_id = $2
- """
- parent_row = await self.connection_manager.fetchrow_query(
- parent_check_query, [parent_id, conversation_id]
- )
- if not parent_row:
- raise R2RException(
- status_code=404,
- message=f"Parent message {parent_id} not found in conversation {conversation_id}.",
- )
- message_id = uuid4()
- content_str = json.dumps(content.model_dump())
- metadata_str = json.dumps(metadata or {})
- query = f"""
- INSERT INTO {self._get_table_name("messages")}
- (id, conversation_id, parent_id, content, created_at, metadata)
- VALUES ($1, $2, $3, $4::jsonb, NOW(), $5::jsonb)
- RETURNING id
- """
- inserted = await self.connection_manager.fetchrow_query(
- query,
- [
- message_id,
- conversation_id,
- parent_id,
- content_str,
- metadata_str,
- ],
- )
- if not inserted:
- raise R2RException(
- status_code=500, message="Failed to insert message."
- )
- return MessageResponse(id=str(message_id), message=content)
- async def edit_message(
- self,
- message_id: UUID,
- new_content: str,
- additional_metadata: dict = {},
- ) -> Dict[str, Any]:
- # Get the original message
- query = f"""
- SELECT conversation_id, parent_id, content, metadata
- FROM {self._get_table_name("messages")}
- WHERE id = $1
- """
- row = await self.connection_manager.fetchrow_query(query, [message_id])
- if not row:
- raise R2RException(
- status_code=404, message=f"Message {message_id} not found."
- )
- old_content = json.loads(row["content"])
- old_metadata = json.loads(row["metadata"])
- # Update the content
- old_message = Message(**old_content)
- edited_message = Message(
- role=old_message.role,
- content=new_content,
- name=old_message.name,
- function_call=old_message.function_call,
- tool_calls=old_message.tool_calls,
- )
- # Merge metadata and mark edited
- new_metadata = {**old_metadata, **additional_metadata, "edited": True}
- # Instead of branching, we'll simply replace the message content and metadata:
- # NOTE: If you prefer versioning or forking behavior, you'd add a new message.
- # For simplicity, we just edit the existing message.
- update_query = f"""
- UPDATE {self._get_table_name("messages")}
- SET content = $1::jsonb, metadata = $2::jsonb, created_at = NOW()
- WHERE id = $3
- RETURNING id
- """
- updated = await self.connection_manager.fetchrow_query(
- update_query,
- [
- json.dumps(edited_message.model_dump()),
- json.dumps(new_metadata),
- message_id,
- ],
- )
- if not updated:
- raise R2RException(
- status_code=500, message="Failed to update message."
- )
- return {
- "id": str(message_id),
- "message": edited_message,
- "metadata": new_metadata,
- }
- async def update_message_metadata(
- self, message_id: UUID, metadata: dict
- ) -> None:
- # Fetch current metadata
- query = f"""
- SELECT metadata FROM {self._get_table_name("messages")}
- WHERE id = $1
- """
- row = await self.connection_manager.fetchrow_query(query, [message_id])
- if not row:
- raise R2RException(
- status_code=404, message=f"Message {message_id} not found."
- )
- current_metadata = row["metadata"] or {}
- updated_metadata = {**current_metadata, **metadata}
- update_query = f"""
- UPDATE {self._get_table_name("messages")}
- SET metadata = $1::jsonb
- WHERE id = $2
- """
- await self.connection_manager.execute_query(
- update_query, [updated_metadata, message_id]
- )
- async def get_conversation(
- self, conversation_id: UUID
- ) -> List[MessageResponse]:
- # Check conversation
- conv_query = f"SELECT extract(epoch from created_at) AS created_at_epoch FROM {self._get_table_name('conversations')} WHERE id = $1"
- conv_row = await self.connection_manager.fetchrow_query(
- conv_query, [conversation_id]
- )
- if not conv_row:
- raise R2RException(
- status_code=404,
- message=f"Conversation {conversation_id} not found.",
- )
- # Retrieve messages in chronological order
- # We'll recursively gather messages based on parent_id = NULL as root.
- # Since no branching, we simply order by created_at.
- msg_query = f"""
- SELECT id, content, metadata
- FROM {self._get_table_name("messages")}
- WHERE conversation_id = $1
- ORDER BY created_at ASC
- """
- results = await self.connection_manager.fetch_query(
- msg_query, [conversation_id]
- )
- print("results = ", results)
- return [
- MessageResponse(
- id=str(row["id"]),
- message=Message(**json.loads(row["content"])),
- metadata=json.loads(row["metadata"]),
- )
- for row in results
- ]
- async def delete_conversation(self, conversation_id: UUID):
- # Check if conversation exists
- conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
- conv_row = await self.connection_manager.fetchrow_query(
- conv_query, [conversation_id]
- )
- if not conv_row:
- raise R2RException(
- status_code=404,
- message=f"Conversation {conversation_id} not found.",
- )
- # Delete all messages
- del_messages_query = f"DELETE FROM {self._get_table_name('messages')} WHERE conversation_id = $1"
- await self.connection_manager.execute_query(
- del_messages_query, [conversation_id]
- )
- # Delete conversation
- del_conv_query = f"DELETE FROM {self._get_table_name('conversations')} WHERE id = $1"
- await self.connection_manager.execute_query(
- del_conv_query, [conversation_id]
- )
|