# conversations.py
  1. import csv
  2. import json
  3. import logging
  4. import tempfile
  5. from datetime import datetime
  6. from typing import IO, Any, Optional
  7. from uuid import UUID, uuid4
  8. from fastapi import HTTPException
  9. from core.base import Handler, Message, R2RException
  10. from shared.api.models.management.responses import (
  11. ConversationResponse,
  12. MessageResponse,
  13. )
  14. from .base import PostgresConnectionManager
  15. logger = logging.getLogger(__name__)
  16. def _validate_image_size(
  17. message: Message, max_size_bytes: int = 5 * 1024 * 1024
  18. ) -> None:
  19. """
  20. Validates that images in a message don't exceed the maximum allowed size.
  21. Args:
  22. message: Message object to validate
  23. max_size_bytes: Maximum allowed size for base64-encoded images (default: 5MB)
  24. Raises:
  25. R2RException: If image is too large
  26. """
  27. if (
  28. hasattr(message, "image_data")
  29. and message.image_data
  30. and "data" in message.image_data
  31. ):
  32. base64_data = message.image_data["data"]
  33. # Calculate approximate decoded size (base64 increases size by ~33%)
  34. # The formula is: decoded_size = encoded_size * 3/4
  35. estimated_size_bytes = len(base64_data) * 0.75
  36. if estimated_size_bytes > max_size_bytes:
  37. raise R2RException(
  38. status_code=413, # Payload Too Large
  39. message=f"Image too large: {estimated_size_bytes / 1024 / 1024:.2f}MB exceeds the maximum allowed size of {max_size_bytes / 1024 / 1024:.2f}MB",
  40. )
  41. def _json_default(obj: Any) -> str:
  42. """Default handler for objects not serializable by the standard json
  43. encoder."""
  44. if isinstance(obj, datetime):
  45. # Return ISO8601 string
  46. return obj.isoformat()
  47. elif isinstance(obj, UUID):
  48. # Convert UUID to string
  49. return str(obj)
  50. # If you have other special types, handle them here...
  51. # e.g. decimal.Decimal -> str(obj)
  52. # If we get here, raise an error or just default to string:
  53. raise TypeError(f"Type {type(obj)} not serializable")
def safe_dumps(obj: Any) -> str:
    """Wrap `json.dumps` with a default that serializes UUID and datetime.

    Any other non-JSON-native type still raises TypeError (propagated from
    `_json_default`).
    """
    return json.dumps(obj, default=_json_default)
class PostgresConversationsHandler(Handler):
    """Postgres-backed persistence for conversations and their messages."""

    def __init__(
        self, project_name: str, connection_manager: PostgresConnectionManager
    ):
        # project_name doubles as the schema/namespace in raw queries
        # (e.g. "{project_name}.users"); connection_manager runs all SQL
        # asynchronously.
        self.project_name = project_name
        self.connection_manager = connection_manager
    async def create_tables(self):
        """Create the conversations and messages tables if they don't exist.

        messages.parent_id is a self-referencing FK, so replies form a tree
        inside each conversation; messages.conversation_id references the
        conversations table.
        """
        # NOTE(review): relies on uuid_generate_v4() being available in the
        # database (usually the uuid-ossp extension) — confirm it is created
        # elsewhere during setup.
        create_conversations_query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("conversations")} (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                user_id UUID,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                name TEXT
            );
        """
        create_messages_query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                conversation_id UUID NOT NULL,
                parent_id UUID,
                content JSONB,
                metadata JSONB,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                FOREIGN KEY (conversation_id) REFERENCES {self._get_table_name("conversations")}(id),
                FOREIGN KEY (parent_id) REFERENCES {self._get_table_name("messages")}(id)
            );
        """
        # Conversations first: the messages table's FK depends on it.
        await self.connection_manager.execute_query(create_conversations_query)
        await self.connection_manager.execute_query(create_messages_query)
  86. async def create_conversation(
  87. self,
  88. user_id: Optional[UUID] = None,
  89. name: Optional[str] = None,
  90. ) -> ConversationResponse:
  91. query = f"""
  92. INSERT INTO {self._get_table_name("conversations")} (user_id, name)
  93. VALUES ($1, $2)
  94. RETURNING id, extract(epoch from created_at) as created_at_epoch
  95. """
  96. try:
  97. result = await self.connection_manager.fetchrow_query(
  98. query, [user_id, name]
  99. )
  100. return ConversationResponse(
  101. id=result["id"],
  102. created_at=result["created_at_epoch"],
  103. user_id=user_id or None,
  104. name=name or None,
  105. )
  106. except Exception as e:
  107. raise HTTPException(
  108. status_code=500,
  109. detail=f"Failed to create conversation: {str(e)}",
  110. ) from e
    async def get_conversations_overview(
        self,
        offset: int,
        limit: int,
        filter_user_ids: Optional[list[UUID]] = None,
        conversation_ids: Optional[list[UUID]] = None,
    ) -> dict[str, Any]:
        """Page through conversations, newest first.

        Args:
            offset: Number of rows to skip.
            limit: Page size; -1 means "no limit" (the LIMIT clause is
                omitted entirely).
            filter_user_ids: Restrict to conversations owned by these users.
            conversation_ids: Restrict to this explicit set of ids.

        Returns:
            {"results": [...], "total_entries": N}, where total_entries is
            the pre-pagination count (COUNT(*) OVER() window function).
        """
        # $N placeholders are numbered incrementally as params are appended.
        conditions = []
        params: list = []
        param_index = 1
        if filter_user_ids:
            conditions.append(f"""
                c.user_id IN (
                    SELECT id
                    FROM {self.project_name}.users
                    WHERE id = ANY(${param_index})
                )
            """)
            params.append(filter_user_ids)
            param_index += 1
        if conversation_ids:
            conditions.append(f"c.id = ANY(${param_index})")
            params.append(conversation_ids)
            param_index += 1
        where_clause = (
            "WHERE " + " AND ".join(conditions) if conditions else ""
        )
        query = f"""
            WITH conversation_overview AS (
                SELECT c.id,
                    extract(epoch from c.created_at) as created_at_epoch,
                    c.user_id,
                    c.name
                FROM {self._get_table_name("conversations")} c
                {where_clause}
            ),
            counted_overview AS (
                SELECT *,
                    COUNT(*) OVER() AS total_entries
                FROM conversation_overview
            )
            SELECT * FROM counted_overview
            ORDER BY created_at_epoch DESC
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1
        # limit == -1 is the sentinel for "return everything after offset".
        if limit != -1:
            query += f" LIMIT ${param_index}"
            params.append(limit)
        results = await self.connection_manager.fetch_query(query, params)
        if not results:
            return {"results": [], "total_entries": 0}
        # Every row carries the same window-function total.
        total_entries = results[0]["total_entries"]
        conversations = [
            {
                "id": str(row["id"]),
                "created_at": row["created_at_epoch"],
                "user_id": str(row["user_id"]) if row["user_id"] else None,
                "name": row["name"] or None,
            }
            for row in results
        ]
        return {"results": conversations, "total_entries": total_entries}
  175. async def add_message(
  176. self,
  177. conversation_id: UUID,
  178. content: Message,
  179. parent_id: Optional[UUID] = None,
  180. metadata: Optional[dict] = None,
  181. max_image_size_bytes: int = 5 * 1024 * 1024, # 5MB default
  182. ) -> MessageResponse:
  183. # Validate image size
  184. try:
  185. _validate_image_size(content, max_image_size_bytes)
  186. except R2RException:
  187. # Re-raise validation exceptions
  188. raise
  189. except Exception as e:
  190. # Handle unexpected errors during validation
  191. logger.error(f"Error validating image: {str(e)}")
  192. raise R2RException(
  193. status_code=400, message=f"Invalid image data: {str(e)}"
  194. ) from e
  195. # 1) Validate that conversation and parent exist (existing code)
  196. conv_check_query = f"""
  197. SELECT 1 FROM {self._get_table_name("conversations")}
  198. WHERE id = $1
  199. """
  200. conv_row = await self.connection_manager.fetchrow_query(
  201. conv_check_query, [conversation_id]
  202. )
  203. if not conv_row:
  204. raise R2RException(
  205. status_code=404,
  206. message=f"Conversation {conversation_id} not found.",
  207. )
  208. if parent_id:
  209. parent_check_query = f"""
  210. SELECT 1 FROM {self._get_table_name("messages")}
  211. WHERE id = $1 AND conversation_id = $2
  212. """
  213. parent_row = await self.connection_manager.fetchrow_query(
  214. parent_check_query, [parent_id, conversation_id]
  215. )
  216. if not parent_row:
  217. raise R2RException(
  218. status_code=404,
  219. message=f"Parent message {parent_id} not found in conversation {conversation_id}.",
  220. )
  221. # 2) Add image info to metadata for tracking/analytics if images are present
  222. metadata = metadata or {}
  223. if hasattr(content, "image_url") and content.image_url:
  224. metadata["has_image"] = True
  225. metadata["image_type"] = "url"
  226. elif hasattr(content, "image_data") and content.image_data:
  227. metadata["has_image"] = True
  228. metadata["image_type"] = "base64"
  229. # Don't store the actual base64 data in metadata as it would be redundant
  230. # 3) Convert the content & metadata to JSON strings
  231. message_id = uuid4()
  232. # Using safe_dumps to handle any type of serialization
  233. content_str = safe_dumps(content.model_dump())
  234. metadata_str = safe_dumps(metadata)
  235. # 4) Insert the message (existing code)
  236. query = f"""
  237. INSERT INTO {self._get_table_name("messages")}
  238. (id, conversation_id, parent_id, content, created_at, metadata)
  239. VALUES ($1, $2, $3, $4::jsonb, NOW(), $5::jsonb)
  240. RETURNING id
  241. """
  242. inserted = await self.connection_manager.fetchrow_query(
  243. query,
  244. [
  245. message_id,
  246. conversation_id,
  247. parent_id,
  248. content_str,
  249. metadata_str,
  250. ],
  251. )
  252. if not inserted:
  253. raise R2RException(
  254. status_code=500, message="Failed to insert message."
  255. )
  256. return MessageResponse(id=message_id, message=content)
  257. async def edit_message(
  258. self,
  259. message_id: UUID,
  260. new_content: str | None = None,
  261. additional_metadata: dict | None = None,
  262. ) -> dict[str, Any]:
  263. # Get the original message
  264. query = f"""
  265. SELECT conversation_id, parent_id, content, metadata, created_at
  266. FROM {self._get_table_name("messages")}
  267. WHERE id = $1
  268. """
  269. row = await self.connection_manager.fetchrow_query(query, [message_id])
  270. if not row:
  271. raise R2RException(
  272. status_code=404,
  273. message=f"Message {message_id} not found.",
  274. )
  275. old_content = json.loads(row["content"])
  276. old_metadata = json.loads(row["metadata"])
  277. if new_content is not None:
  278. old_message = Message(**old_content)
  279. edited_message = Message(
  280. role=old_message.role,
  281. content=new_content,
  282. name=old_message.name,
  283. function_call=old_message.function_call,
  284. tool_calls=old_message.tool_calls,
  285. # Preserve image content if it exists
  286. image_url=getattr(old_message, "image_url", None),
  287. image_data=getattr(old_message, "image_data", None),
  288. )
  289. content_to_save = edited_message.model_dump()
  290. else:
  291. content_to_save = old_content
  292. additional_metadata = additional_metadata or {}
  293. new_metadata = {
  294. **old_metadata,
  295. **additional_metadata,
  296. "edited": (
  297. True
  298. if new_content is not None
  299. else old_metadata.get("edited", False)
  300. ),
  301. }
  302. # Update message without changing the timestamp
  303. update_query = f"""
  304. UPDATE {self._get_table_name("messages")}
  305. SET content = $1::jsonb,
  306. metadata = $2::jsonb,
  307. created_at = $3
  308. WHERE id = $4
  309. RETURNING id
  310. """
  311. updated = await self.connection_manager.fetchrow_query(
  312. update_query,
  313. [
  314. json.dumps(content_to_save),
  315. json.dumps(new_metadata),
  316. row["created_at"],
  317. message_id,
  318. ],
  319. )
  320. if not updated:
  321. raise R2RException(
  322. status_code=500, message="Failed to update message."
  323. )
  324. return {
  325. "id": str(message_id),
  326. "message": (
  327. Message(**content_to_save)
  328. if isinstance(content_to_save, dict)
  329. else content_to_save
  330. ),
  331. "metadata": new_metadata,
  332. }
  333. async def update_message_metadata(
  334. self, message_id: UUID, metadata: dict
  335. ) -> None:
  336. # Fetch current metadata
  337. query = f"""
  338. SELECT metadata FROM {self._get_table_name("messages")}
  339. WHERE id = $1
  340. """
  341. row = await self.connection_manager.fetchrow_query(query, [message_id])
  342. if not row:
  343. raise R2RException(
  344. status_code=404, message=f"Message {message_id} not found."
  345. )
  346. current_metadata = json.loads(row["metadata"]) or {}
  347. updated_metadata = {**current_metadata, **metadata}
  348. update_query = f"""
  349. UPDATE {self._get_table_name("messages")}
  350. SET metadata = $1::jsonb
  351. WHERE id = $2
  352. """
  353. await self.connection_manager.execute_query(
  354. update_query, [json.dumps(updated_metadata), message_id]
  355. )
  356. async def get_conversation(
  357. self,
  358. conversation_id: UUID,
  359. filter_user_ids: Optional[list[UUID]] = None,
  360. ) -> list[MessageResponse]:
  361. # Existing validation code remains the same
  362. conditions = ["c.id = $1"]
  363. params: list = [conversation_id]
  364. if filter_user_ids:
  365. param_index = 2
  366. conditions.append(f"""
  367. c.user_id IN (
  368. SELECT id
  369. FROM {self.project_name}.users
  370. WHERE id = ANY(${param_index})
  371. )
  372. """)
  373. params.append(filter_user_ids)
  374. query = f"""
  375. SELECT c.id, extract(epoch from c.created_at) AS created_at_epoch
  376. FROM {self._get_table_name("conversations")} c
  377. WHERE {" AND ".join(conditions)}
  378. """
  379. conv_row = await self.connection_manager.fetchrow_query(query, params)
  380. if not conv_row:
  381. raise R2RException(
  382. status_code=404,
  383. message=f"Conversation {conversation_id} not found.",
  384. )
  385. # Retrieve messages in chronological order
  386. msg_query = f"""
  387. SELECT id, content, metadata
  388. FROM {self._get_table_name("messages")}
  389. WHERE conversation_id = $1
  390. ORDER BY created_at ASC
  391. """
  392. results = await self.connection_manager.fetch_query(
  393. msg_query, [conversation_id]
  394. )
  395. response_messages = []
  396. for row in results:
  397. try:
  398. # Parse the message content
  399. content_json = json.loads(row["content"])
  400. # Create a Message object with the parsed content
  401. message = Message(**content_json)
  402. # Create a MessageResponse
  403. response_messages.append(
  404. MessageResponse(
  405. id=row["id"],
  406. message=message,
  407. metadata=json.loads(row["metadata"]),
  408. )
  409. )
  410. except Exception as e:
  411. # If there's an error parsing the message (e.g., due to version mismatch),
  412. # log it and create a fallback message
  413. logger.warning(f"Error parsing message {row['id']}: {str(e)}")
  414. fallback_content = content_json.get(
  415. "content", "Message could not be loaded"
  416. )
  417. fallback_role = content_json.get("role", "assistant")
  418. # Create a basic fallback message
  419. fallback_message = Message(
  420. role=fallback_role,
  421. content=f"[Message format incompatible: {fallback_content}]",
  422. )
  423. response_messages.append(
  424. MessageResponse(
  425. id=row["id"],
  426. message=fallback_message,
  427. metadata=json.loads(row["metadata"]),
  428. )
  429. )
  430. return response_messages
  431. async def update_conversation(
  432. self, conversation_id: UUID, name: str
  433. ) -> ConversationResponse:
  434. try:
  435. # Check if conversation exists
  436. conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
  437. conv_row = await self.connection_manager.fetchrow_query(
  438. conv_query, [conversation_id]
  439. )
  440. if not conv_row:
  441. raise R2RException(
  442. status_code=404,
  443. message=f"Conversation {conversation_id} not found.",
  444. )
  445. update_query = f"""
  446. UPDATE {self._get_table_name("conversations")}
  447. SET name = $1 WHERE id = $2
  448. RETURNING user_id, extract(epoch from created_at) as created_at_epoch
  449. """
  450. updated_row = await self.connection_manager.fetchrow_query(
  451. update_query, [name, conversation_id]
  452. )
  453. return ConversationResponse(
  454. id=conversation_id,
  455. created_at=updated_row["created_at_epoch"],
  456. user_id=updated_row["user_id"] or None,
  457. name=name,
  458. )
  459. except Exception as e:
  460. raise HTTPException(
  461. status_code=500,
  462. detail=f"Failed to update conversation: {str(e)}",
  463. ) from e
    async def delete_conversation(
        self,
        conversation_id: UUID,
        filter_user_ids: Optional[list[UUID]] = None,
    ) -> None:
        """Delete a conversation and all of its messages.

        Args:
            conversation_id: Conversation to delete.
            filter_user_ids: When given, the conversation must be owned by
                one of these users, otherwise it is treated as not found.

        Raises:
            R2RException: 404 if the conversation is not visible/found.
        """
        conditions = ["c.id = $1"]
        params: list = [conversation_id]
        if filter_user_ids:
            param_index = 2
            conditions.append(f"""
                c.user_id IN (
                    SELECT id
                    FROM {self.project_name}.users
                    WHERE id = ANY(${param_index})
                )
            """)
            params.append(filter_user_ids)
        conv_query = f"""
            SELECT 1
            FROM {self._get_table_name("conversations")} c
            WHERE {" AND ".join(conditions)}
        """
        conv_row = await self.connection_manager.fetchrow_query(
            conv_query, params
        )
        if not conv_row:
            raise R2RException(
                status_code=404,
                message=f"Conversation {conversation_id} not found.",
            )
        # Delete messages first: they hold an FK to the conversation row.
        # NOTE(review): the two deletes are not wrapped in one transaction,
        # so a failure between them would leave an empty conversation row —
        # confirm whether connection_manager offers transactional execution.
        del_messages_query = f"DELETE FROM {self._get_table_name('messages')} WHERE conversation_id = $1"
        await self.connection_manager.execute_query(
            del_messages_query, [conversation_id]
        )
        # Delete the conversation itself.
        del_conv_query = f"DELETE FROM {self._get_table_name('conversations')} WHERE id = $1"
        await self.connection_manager.execute_query(
            del_conv_query, [conversation_id]
        )
    async def export_conversations_to_csv(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Creates a CSV file from the PostgreSQL data and returns the path to
        the temp file.

        Args:
            columns: Subset of exportable columns; defaults to all of them.
            filters: Mapping of column -> value, or column -> one of
                {"$eq": v}, {"$gt": v}, {"$lt": v}.
            include_header: Whether to write a header row.

        Returns:
            (path, open file object). The NamedTemporaryFile is created with
            delete=True, so the file disappears once the handle is closed.

        Raises:
            ValueError: If an unknown column is requested.
            HTTPException: 500 on any export failure.
        """
        valid_columns = {
            "id",
            "user_id",
            "created_at",
            "name",
        }
        if not columns:
            columns = list(valid_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")
        select_stmt = f"""
            SELECT
                id::text,
                user_id::text,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                name
            FROM {self._get_table_name("conversations")}
        """
        conditions = []
        params: list[Any] = []
        param_index = 1
        if filters:
            for field, value in filters.items():
                # Only whitelisted column names are interpolated into SQL;
                # filter values always go through bind parameters.
                if field not in valid_columns:
                    continue
                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1
        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
            # Stream rows via a cursor in chunks so large exports never
            # materialize the whole table in memory.
            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)
                    if include_header:
                        writer.writerow(columns)
                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            # Positional order matches the SELECT above.
                            row_dict = {
                                "id": row[0],
                                "user_id": row[1],
                                "created_at": row[2],
                                "name": row[3],
                            }
                            writer.writerow([row_dict[col] for col in columns])
            temp_file.flush()
            return temp_file.name, temp_file
        except Exception as e:
            # Close (and thereby delete) the temp file on failure.
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e
  592. async def export_messages_to_csv(
  593. self,
  594. columns: Optional[list[str]] = None,
  595. filters: Optional[dict] = None,
  596. include_header: bool = True,
  597. handle_images: str = "metadata_only", # Options: "full", "metadata_only", "exclude"
  598. ) -> tuple[str, IO]:
  599. """
  600. Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
  601. Args:
  602. columns: List of columns to include in export
  603. filters: Filter criteria for messages
  604. include_header: Whether to include header row
  605. handle_images: How to handle image data in exports:
  606. - "full": Include complete image data (warning: may create large files)
  607. - "metadata_only": Replace image data with metadata only
  608. - "exclude": Remove image data completely
  609. """
  610. valid_columns = {
  611. "id",
  612. "conversation_id",
  613. "parent_id",
  614. "content",
  615. "metadata",
  616. "created_at",
  617. "has_image", # New virtual column to indicate image presence
  618. }
  619. if not columns:
  620. columns = list(valid_columns - {"has_image"})
  621. elif invalid_cols := set(columns) - valid_columns:
  622. raise ValueError(f"Invalid columns: {invalid_cols}")
  623. # Add virtual column for image presence
  624. virtual_columns = []
  625. has_image_column = False
  626. if "has_image" in columns:
  627. virtual_columns.append(
  628. "(content->>'image_url' IS NOT NULL OR content->>'image_data' IS NOT NULL) as has_image"
  629. )
  630. columns.remove("has_image")
  631. has_image_column = True
  632. select_stmt = f"""
  633. SELECT
  634. id::text,
  635. conversation_id::text,
  636. parent_id::text,
  637. content::text,
  638. metadata::text,
  639. to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at
  640. {", " + ", ".join(virtual_columns) if virtual_columns else ""}
  641. FROM {self._get_table_name("messages")}
  642. """
  643. # Keep existing filter conditions setup
  644. conditions = []
  645. params: list[Any] = []
  646. param_index = 1
  647. if filters:
  648. for field, value in filters.items():
  649. if field not in valid_columns or field == "has_image":
  650. continue
  651. if isinstance(value, dict):
  652. for op, val in value.items():
  653. if op == "$eq":
  654. conditions.append(f"{field} = ${param_index}")
  655. params.append(val)
  656. param_index += 1
  657. elif op == "$gt":
  658. conditions.append(f"{field} > ${param_index}")
  659. params.append(val)
  660. param_index += 1
  661. elif op == "$lt":
  662. conditions.append(f"{field} < ${param_index}")
  663. params.append(val)
  664. param_index += 1
  665. else:
  666. conditions.append(f"{field} = ${param_index}")
  667. params.append(value)
  668. param_index += 1
  669. # Special filter for has_image
  670. if filters and "has_image" in filters:
  671. if filters["has_image"]:
  672. conditions.append(
  673. "(content->>'image_url' IS NOT NULL OR content->>'image_data' IS NOT NULL)"
  674. )
  675. if conditions:
  676. select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
  677. select_stmt = f"{select_stmt} ORDER BY created_at DESC"
  678. temp_file = None
  679. try:
  680. temp_file = tempfile.NamedTemporaryFile(
  681. mode="w", delete=True, suffix=".csv"
  682. )
  683. writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
  684. # Prepare export columns
  685. export_columns = list(columns)
  686. if has_image_column:
  687. export_columns.append("has_image")
  688. if include_header:
  689. writer.writerow(export_columns)
  690. async with self.connection_manager.pool.get_connection() as conn: # type: ignore
  691. async with conn.transaction():
  692. cursor = await conn.cursor(select_stmt, *params)
  693. chunk_size = 1000
  694. while True:
  695. rows = await cursor.fetch(chunk_size)
  696. if not rows:
  697. break
  698. for row in rows:
  699. row_dict = {
  700. "id": row[0],
  701. "conversation_id": row[1],
  702. "parent_id": row[2],
  703. "content": row[3],
  704. "metadata": row[4],
  705. "created_at": row[5],
  706. }
  707. # Add virtual column if present
  708. if has_image_column:
  709. row_dict["has_image"] = (
  710. "true" if row[6] else "false"
  711. )
  712. # Process image data based on handle_images setting
  713. if (
  714. "content" in columns
  715. and handle_images != "full"
  716. ):
  717. try:
  718. content_json = json.loads(
  719. row_dict["content"]
  720. )
  721. if (
  722. "image_data" in content_json
  723. and content_json["image_data"]
  724. ):
  725. media_type = content_json[
  726. "image_data"
  727. ].get("media_type", "image/jpeg")
  728. if handle_images == "metadata_only":
  729. content_json["image_data"] = {
  730. "media_type": media_type,
  731. "data": "[BASE64_DATA_EXCLUDED_FROM_EXPORT]",
  732. }
  733. elif handle_images == "exclude":
  734. content_json.pop(
  735. "image_data", None
  736. )
  737. row_dict["content"] = json.dumps(
  738. content_json
  739. )
  740. except (json.JSONDecodeError, TypeError) as e:
  741. logger.warning(
  742. f"Error processing message content for export: {e}"
  743. )
  744. writer.writerow(
  745. [row_dict[col] for col in export_columns]
  746. )
  747. temp_file.flush()
  748. return temp_file.name, temp_file
  749. except Exception as e:
  750. if temp_file:
  751. temp_file.close()
  752. raise HTTPException(
  753. status_code=500,
  754. detail=f"Failed to export data: {str(e)}",
  755. ) from e