# prompts_handler.py
  1. import json
  2. import logging
  3. import os
  4. from abc import abstractmethod
  5. from dataclasses import dataclass
  6. from datetime import datetime, timedelta
  7. from pathlib import Path
  8. from typing import Any, Generic, Optional, TypeVar
  9. import yaml
  10. from core.base import Handler, generate_default_prompt_id
  11. from .base import PostgresConnectionManager
  12. logger = logging.getLogger(__name__)
  13. T = TypeVar("T")
  14. @dataclass
  15. class CacheEntry(Generic[T]):
  16. """Represents a cached item with metadata"""
  17. value: T
  18. created_at: datetime
  19. last_accessed: datetime
  20. access_count: int = 0
  21. class Cache(Generic[T]):
  22. """A generic cache implementation with TTL and LRU-like features"""
  23. def __init__(
  24. self,
  25. ttl: Optional[timedelta] = None,
  26. max_size: Optional[int] = 1000,
  27. cleanup_interval: timedelta = timedelta(hours=1),
  28. ):
  29. self._cache: dict[str, CacheEntry[T]] = {}
  30. self._ttl = ttl
  31. self._max_size = max_size
  32. self._cleanup_interval = cleanup_interval
  33. self._last_cleanup = datetime.now()
  34. def get(self, key: str) -> Optional[T]:
  35. """Retrieve an item from cache"""
  36. self._maybe_cleanup()
  37. if key not in self._cache:
  38. return None
  39. entry = self._cache[key]
  40. if self._ttl and datetime.now() - entry.created_at > self._ttl:
  41. del self._cache[key]
  42. return None
  43. entry.last_accessed = datetime.now()
  44. entry.access_count += 1
  45. return entry.value
  46. def set(self, key: str, value: T) -> None:
  47. """Store an item in cache"""
  48. self._maybe_cleanup()
  49. now = datetime.now()
  50. self._cache[key] = CacheEntry(
  51. value=value, created_at=now, last_accessed=now
  52. )
  53. if self._max_size and len(self._cache) > self._max_size:
  54. self._evict_lru()
  55. def invalidate(self, key: str) -> None:
  56. """Remove an item from cache"""
  57. self._cache.pop(key, None)
  58. def clear(self) -> None:
  59. """Clear all cached items"""
  60. self._cache.clear()
  61. def _maybe_cleanup(self) -> None:
  62. """Periodically clean up expired entries"""
  63. now = datetime.now()
  64. if now - self._last_cleanup > self._cleanup_interval:
  65. self._cleanup()
  66. self._last_cleanup = now
  67. def _cleanup(self) -> None:
  68. """Remove expired entries"""
  69. if not self._ttl:
  70. return
  71. now = datetime.now()
  72. expired = [
  73. k for k, v in self._cache.items() if now - v.created_at > self._ttl
  74. ]
  75. for k in expired:
  76. del self._cache[k]
  77. def _evict_lru(self) -> None:
  78. """Remove least recently used item"""
  79. if not self._cache:
  80. return
  81. lru_key = min(
  82. self._cache.keys(), key=lambda k: self._cache[k].last_accessed
  83. )
  84. del self._cache[lru_key]
  85. class CacheablePromptHandler(Handler):
  86. """Abstract base class that adds caching capabilities to prompt handlers"""
  87. def __init__(
  88. self,
  89. cache_ttl: Optional[timedelta] = timedelta(hours=1),
  90. max_cache_size: Optional[int] = 1000,
  91. ):
  92. self._prompt_cache = Cache[str](ttl=cache_ttl, max_size=max_cache_size)
  93. self._template_cache = Cache[dict](
  94. ttl=cache_ttl, max_size=max_cache_size
  95. )
  96. def _cache_key(
  97. self, prompt_name: str, inputs: Optional[dict] = None
  98. ) -> str:
  99. """Generate a cache key for a prompt request"""
  100. if inputs:
  101. # Sort dict items for consistent keys
  102. sorted_inputs = sorted(inputs.items())
  103. return f"{prompt_name}:{sorted_inputs}"
  104. return prompt_name
  105. async def get_cached_prompt(
  106. self,
  107. prompt_name: str,
  108. inputs: Optional[dict[str, Any]] = None,
  109. prompt_override: Optional[str] = None,
  110. bypass_cache: bool = False,
  111. ) -> str:
  112. """Get a prompt with caching support"""
  113. if prompt_override:
  114. if inputs:
  115. try:
  116. return prompt_override.format(**inputs)
  117. except KeyError:
  118. return prompt_override
  119. return prompt_override
  120. cache_key = self._cache_key(prompt_name, inputs)
  121. if not bypass_cache:
  122. cached = self._prompt_cache.get(cache_key)
  123. if cached is not None:
  124. logger.debug(f"Cache hit for prompt: {cache_key}")
  125. return cached
  126. result = await self._get_prompt_impl(prompt_name, inputs)
  127. self._prompt_cache.set(cache_key, result)
  128. return result
  129. async def get_prompt( # type: ignore
  130. self,
  131. name: str,
  132. inputs: Optional[dict] = None,
  133. prompt_override: Optional[str] = None,
  134. ) -> dict:
  135. query = f"""
  136. SELECT id, name, template, input_types, created_at, updated_at
  137. FROM {self._get_table_name("prompts")}
  138. WHERE name = $1;
  139. """
  140. result = await self.connection_manager.fetchrow_query(query, [name])
  141. if not result:
  142. raise ValueError(f"Prompt template '{name}' not found")
  143. input_types = result["input_types"]
  144. if isinstance(input_types, str):
  145. input_types = json.loads(input_types)
  146. return {
  147. "id": result["id"],
  148. "name": result["name"],
  149. "template": result["template"],
  150. "input_types": input_types,
  151. "created_at": result["created_at"],
  152. "updated_at": result["updated_at"],
  153. }
  154. @abstractmethod
  155. async def _get_prompt_impl(
  156. self, prompt_name: str, inputs: Optional[dict[str, Any]] = None
  157. ) -> str:
  158. """Implementation of prompt retrieval logic"""
  159. pass
  160. async def update_prompt(
  161. self,
  162. name: str,
  163. template: Optional[str] = None,
  164. input_types: Optional[dict[str, str]] = None,
  165. ) -> None:
  166. """Public method to update a prompt with proper cache invalidation"""
  167. # First invalidate all caches for this prompt
  168. self._template_cache.invalidate(name)
  169. cache_keys_to_invalidate = [
  170. key
  171. for key in self._prompt_cache._cache.keys()
  172. if key.startswith(f"{name}:") or key == name
  173. ]
  174. for key in cache_keys_to_invalidate:
  175. self._prompt_cache.invalidate(key)
  176. # Perform the update
  177. await self._update_prompt_impl(name, template, input_types)
  178. # Force refresh template cache
  179. template_info = await self._get_template_info(name)
  180. if template_info:
  181. self._template_cache.set(name, template_info)
  182. @abstractmethod
  183. async def _update_prompt_impl(
  184. self,
  185. name: str,
  186. template: Optional[str] = None,
  187. input_types: Optional[dict[str, str]] = None,
  188. ) -> None:
  189. """Implementation of prompt update logic"""
  190. pass
  191. @abstractmethod
  192. async def _get_template_info(self, prompt_name: str) -> Optional[dict]:
  193. """Get template info with caching"""
  194. pass
  195. class PostgresPromptsHandler(CacheablePromptHandler):
  196. """PostgreSQL implementation of the CacheablePromptHandler."""
  197. def __init__(
  198. self,
  199. project_name: str,
  200. connection_manager: PostgresConnectionManager,
  201. prompt_directory: Optional[Path] = None,
  202. **cache_options,
  203. ):
  204. super().__init__(**cache_options)
  205. self.prompt_directory = (
  206. prompt_directory or Path(os.path.dirname(__file__)) / "prompts"
  207. )
  208. self.connection_manager = connection_manager
  209. self.project_name = project_name
  210. self.prompts: dict[str, dict[str, str | dict[str, str]]] = {}
  211. async def _load_prompts(self) -> None:
  212. """Load prompts from both database and YAML files."""
  213. # First load from database
  214. await self._load_prompts_from_database()
  215. # Then load from YAML files, potentially overriding unmodified database entries
  216. await self._load_prompts_from_yaml_directory()
  217. async def _load_prompts_from_database(self) -> None:
  218. """Load prompts from the database."""
  219. query = f"""
  220. SELECT id, name, template, input_types, created_at, updated_at
  221. FROM {self._get_table_name("prompts")};
  222. """
  223. try:
  224. results = await self.connection_manager.fetch_query(query)
  225. for row in results:
  226. logger.info(f"Loading saved prompt: {row['name']}")
  227. # Ensure input_types is a dictionary
  228. input_types = row["input_types"]
  229. if isinstance(input_types, str):
  230. input_types = json.loads(input_types)
  231. self.prompts[row["name"]] = {
  232. "id": row["id"],
  233. "template": row["template"],
  234. "input_types": input_types,
  235. "created_at": row["created_at"],
  236. "updated_at": row["updated_at"],
  237. }
  238. # Pre-populate the template cache
  239. self._template_cache.set(
  240. row["name"],
  241. {
  242. "id": row["id"],
  243. "template": row["template"],
  244. "input_types": input_types,
  245. },
  246. )
  247. logger.debug(f"Loaded {len(results)} prompts from database")
  248. except Exception as e:
  249. logger.error(f"Failed to load prompts from database: {e}")
  250. raise
  251. async def _load_prompts_from_yaml_directory(self) -> None:
  252. """Load prompts from YAML files in the specified directory."""
  253. if not self.prompt_directory.is_dir():
  254. logger.warning(
  255. f"Prompt directory not found: {self.prompt_directory}"
  256. )
  257. return
  258. logger.info(f"Loading prompts from {self.prompt_directory}")
  259. for yaml_file in self.prompt_directory.glob("*.yaml"):
  260. logger.debug(f"Processing {yaml_file}")
  261. try:
  262. with open(yaml_file, "r") as file:
  263. data = yaml.safe_load(file)
  264. if not isinstance(data, dict):
  265. raise ValueError(
  266. f"Invalid format in YAML file {yaml_file}"
  267. )
  268. for name, prompt_data in data.items():
  269. should_modify = True
  270. if name in self.prompts:
  271. # Only modify if the prompt hasn't been updated since creation
  272. existing = self.prompts[name]
  273. should_modify = (
  274. existing["created_at"]
  275. == existing["updated_at"]
  276. )
  277. if should_modify:
  278. logger.info(f"Loading default prompt: {name}")
  279. await self.add_prompt(
  280. name=name,
  281. template=prompt_data["template"],
  282. input_types=prompt_data.get("input_types", {}),
  283. preserve_existing=(not should_modify),
  284. )
  285. except Exception as e:
  286. logger.error(f"Error loading {yaml_file}: {e}")
  287. continue
  288. def _get_table_name(self, base_name: str) -> str:
  289. """Get the fully qualified table name."""
  290. return f"{self.project_name}.{base_name}"
  291. # Implementation of abstract methods from CacheablePromptHandler
  292. async def _get_prompt_impl(
  293. self, prompt_name: str, inputs: Optional[dict[str, Any]] = None
  294. ) -> str:
  295. """Implementation of database prompt retrieval"""
  296. template_info = await self._get_template_info(prompt_name)
  297. if not template_info:
  298. raise ValueError(f"Prompt template '{prompt_name}' not found")
  299. template, input_types = (
  300. template_info["template"],
  301. template_info["input_types"],
  302. )
  303. if inputs:
  304. # Validate input types
  305. for key, value in inputs.items():
  306. expected_type = input_types.get(key)
  307. if not expected_type:
  308. raise ValueError(
  309. f"Unexpected input key: {key} expected input types: {input_types}"
  310. )
  311. return template.format(**inputs)
  312. return template
  313. async def _get_template_info(self, prompt_name: str) -> Optional[dict]: # type: ignore
  314. """Get template info with caching"""
  315. cached = self._template_cache.get(prompt_name)
  316. if cached is not None:
  317. return cached
  318. query = f"""
  319. SELECT template, input_types
  320. FROM {self._get_table_name("prompts")}
  321. WHERE name = $1;
  322. """
  323. result = await self.connection_manager.fetchrow_query(
  324. query, [prompt_name]
  325. )
  326. if result:
  327. # Ensure input_types is a dictionary
  328. input_types = result["input_types"]
  329. if isinstance(input_types, str):
  330. input_types = json.loads(input_types)
  331. template_info = {
  332. "template": result["template"],
  333. "input_types": input_types,
  334. }
  335. self._template_cache.set(prompt_name, template_info)
  336. return template_info
  337. return None
  338. async def _update_prompt_impl(
  339. self,
  340. name: str,
  341. template: Optional[str] = None,
  342. input_types: Optional[dict[str, str]] = None,
  343. ) -> None:
  344. """Implementation of database prompt update with proper connection handling"""
  345. if not template and not input_types:
  346. return
  347. # Clear caches first
  348. self._template_cache.invalidate(name)
  349. for key in list(self._prompt_cache._cache.keys()):
  350. if key.startswith(f"{name}:"):
  351. self._prompt_cache.invalidate(key)
  352. # Build update query
  353. set_clauses = []
  354. params = [name] # First parameter is always the name
  355. param_index = 2 # Start from 2 since $1 is name
  356. if template:
  357. set_clauses.append(f"template = ${param_index}")
  358. params.append(template)
  359. param_index += 1
  360. if input_types:
  361. set_clauses.append(f"input_types = ${param_index}")
  362. params.append(json.dumps(input_types))
  363. param_index += 1
  364. set_clauses.append("updated_at = CURRENT_TIMESTAMP")
  365. query = f"""
  366. UPDATE {self._get_table_name("prompts")}
  367. SET {', '.join(set_clauses)}
  368. WHERE name = $1
  369. RETURNING id, template, input_types;
  370. """
  371. try:
  372. # Execute update and get returned values
  373. result = await self.connection_manager.fetchrow_query(
  374. query, params
  375. )
  376. if not result:
  377. raise ValueError(f"Prompt template '{name}' not found")
  378. # Update in-memory state
  379. if name in self.prompts:
  380. if template:
  381. self.prompts[name]["template"] = template
  382. if input_types:
  383. self.prompts[name]["input_types"] = input_types
  384. self.prompts[name]["updated_at"] = datetime.now().isoformat()
  385. except Exception as e:
  386. logger.error(f"Failed to update prompt {name}: {str(e)}")
  387. raise
  388. async def create_tables(self):
  389. """Create the necessary tables for storing prompts."""
  390. query = f"""
  391. CREATE TABLE IF NOT EXISTS {self._get_table_name("prompts")} (
  392. id UUID PRIMARY KEY,
  393. name VARCHAR(255) NOT NULL UNIQUE,
  394. template TEXT NOT NULL,
  395. input_types JSONB NOT NULL,
  396. created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
  397. updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
  398. );
  399. CREATE OR REPLACE FUNCTION {self.project_name}.update_updated_at_column()
  400. RETURNS TRIGGER AS $$
  401. BEGIN
  402. NEW.updated_at = CURRENT_TIMESTAMP;
  403. RETURN NEW;
  404. END;
  405. $$ language 'plpgsql';
  406. DROP TRIGGER IF EXISTS update_prompts_updated_at
  407. ON {self._get_table_name("prompts")};
  408. CREATE TRIGGER update_prompts_updated_at
  409. BEFORE UPDATE ON {self._get_table_name("prompts")}
  410. FOR EACH ROW
  411. EXECUTE FUNCTION {self.project_name}.update_updated_at_column();
  412. """
  413. await self.connection_manager.execute_query(query)
  414. await self._load_prompts()
  415. async def add_prompt(
  416. self,
  417. name: str,
  418. template: str,
  419. input_types: dict[str, str],
  420. preserve_existing: bool = False,
  421. ) -> None:
  422. """Add or update a prompt."""
  423. if preserve_existing and name in self.prompts:
  424. return
  425. id = generate_default_prompt_id(name)
  426. # Ensure input_types is properly serialized
  427. input_types_json = (
  428. json.dumps(input_types)
  429. if isinstance(input_types, dict)
  430. else input_types
  431. )
  432. query = f"""
  433. INSERT INTO {self._get_table_name("prompts")} (id, name, template, input_types)
  434. VALUES ($1, $2, $3, $4)
  435. ON CONFLICT (name) DO UPDATE
  436. SET template = EXCLUDED.template,
  437. input_types = EXCLUDED.input_types,
  438. updated_at = CURRENT_TIMESTAMP
  439. RETURNING id, created_at, updated_at;
  440. """
  441. result = await self.connection_manager.fetchrow_query(
  442. query, [id, name, template, input_types_json]
  443. )
  444. self.prompts[name] = {
  445. "id": result["id"],
  446. "template": template,
  447. "input_types": input_types,
  448. "created_at": result["created_at"],
  449. "updated_at": result["updated_at"],
  450. }
  451. # Update template cache
  452. self._template_cache.set(
  453. name,
  454. {
  455. "id": id,
  456. "template": template,
  457. "input_types": input_types,
  458. }, # Store as dict in cache
  459. )
  460. # Invalidate any cached formatted prompts
  461. for key in list(self._prompt_cache._cache.keys()):
  462. if key.startswith(f"{name}:"):
  463. self._prompt_cache.invalidate(key)
  464. async def get_all_prompts(self) -> dict[str, Any]:
  465. """Retrieve all stored prompts."""
  466. query = f"""
  467. SELECT id, name, template, input_types, created_at, updated_at, COUNT(*) OVER() AS total_entries
  468. FROM {self._get_table_name("prompts")};
  469. """
  470. results = await self.connection_manager.fetch_query(query)
  471. if not results:
  472. return {"results": [], "total_entries": 0}
  473. total_entries = results[0]["total_entries"] if results else 0
  474. prompts = [
  475. {
  476. "name": row["name"],
  477. "id": row["id"],
  478. "template": row["template"],
  479. "input_types": (
  480. json.loads(row["input_types"])
  481. if isinstance(row["input_types"], str)
  482. else row["input_types"]
  483. ),
  484. "created_at": row["created_at"],
  485. "updated_at": row["updated_at"],
  486. }
  487. for row in results
  488. ]
  489. return {"results": prompts, "total_entries": total_entries}
  490. async def delete_prompt(self, name: str) -> None:
  491. """Delete a prompt template."""
  492. query = f"""
  493. DELETE FROM {self._get_table_name("prompts")}
  494. WHERE name = $1;
  495. """
  496. result = await self.connection_manager.execute_query(query, [name])
  497. if result == "DELETE 0":
  498. raise ValueError(f"Prompt template '{name}' not found")
  499. # Invalidate caches
  500. self._template_cache.invalidate(name)
  501. for key in list(self._prompt_cache._cache.keys()):
  502. if key.startswith(f"{name}:"):
  503. self._prompt_cache.invalidate(key)
  504. async def get_message_payload(
  505. self,
  506. system_prompt_name: Optional[str] = None,
  507. system_role: str = "system",
  508. system_inputs: dict = {},
  509. system_prompt_override: Optional[str] = None,
  510. task_prompt_name: Optional[str] = None,
  511. task_role: str = "user",
  512. task_inputs: dict = {},
  513. task_prompt_override: Optional[str] = None,
  514. ) -> list[dict]:
  515. """Create a message payload from system and task prompts."""
  516. if system_prompt_override:
  517. system_prompt = system_prompt_override
  518. else:
  519. system_prompt = await self.get_cached_prompt(
  520. system_prompt_name or "default_system",
  521. system_inputs,
  522. prompt_override=system_prompt_override,
  523. )
  524. task_prompt = await self.get_cached_prompt(
  525. task_prompt_name or "default_rag",
  526. task_inputs,
  527. prompt_override=task_prompt_override,
  528. )
  529. return [
  530. {
  531. "role": system_role,
  532. "content": system_prompt,
  533. },
  534. {
  535. "role": task_role,
  536. "content": task_prompt,
  537. },
  538. ]