# prompts_handler.py

import json
import logging
import os
from abc import abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Generic, Optional, TypeVar

import yaml

from core.base import Handler, generate_default_prompt_id

from .base import PostgresConnectionManager

logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass
class CacheEntry(Generic[T]):
    """Represents a cached item with metadata."""

    value: T
    created_at: datetime
    last_accessed: datetime
    access_count: int = 0


class Cache(Generic[T]):
    """A generic cache implementation with TTL and LRU-like features."""

    def __init__(
        self,
        ttl: Optional[timedelta] = None,
        max_size: Optional[int] = 1000,
        cleanup_interval: timedelta = timedelta(hours=1),
    ):
        self._cache: dict[str, CacheEntry[T]] = {}
        self._ttl = ttl
        self._max_size = max_size
        self._cleanup_interval = cleanup_interval
        self._last_cleanup = datetime.now()

    def get(self, key: str) -> Optional[T]:
        """Retrieve an item from cache."""
        self._maybe_cleanup()
        if key not in self._cache:
            return None

        entry = self._cache[key]
        if self._ttl and datetime.now() - entry.created_at > self._ttl:
            del self._cache[key]
            return None

        entry.last_accessed = datetime.now()
        entry.access_count += 1
        return entry.value

    def set(self, key: str, value: T) -> None:
        """Store an item in cache."""
        self._maybe_cleanup()
        now = datetime.now()
        self._cache[key] = CacheEntry(
            value=value, created_at=now, last_accessed=now
        )
        if self._max_size and len(self._cache) > self._max_size:
            self._evict_lru()

    def invalidate(self, key: str) -> None:
        """Remove an item from cache."""
        self._cache.pop(key, None)

    def clear(self) -> None:
        """Clear all cached items."""
        self._cache.clear()

    def _maybe_cleanup(self) -> None:
        """Periodically clean up expired entries."""
        now = datetime.now()
        if now - self._last_cleanup > self._cleanup_interval:
            self._cleanup()
            self._last_cleanup = now

    def _cleanup(self) -> None:
        """Remove expired entries."""
        if not self._ttl:
            return
        now = datetime.now()
        expired = [
            k for k, v in self._cache.items() if now - v.created_at > self._ttl
        ]
        for k in expired:
            del self._cache[k]

    def _evict_lru(self) -> None:
        """Remove least recently used item."""
        if not self._cache:
            return
        lru_key = min(
            self._cache.keys(), key=lambda k: self._cache[k].last_accessed
        )
        del self._cache[lru_key]
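

# Illustrative sketch (not part of the original module): how the Cache above
# behaves with TTL and LRU eviction. Kept as comments so the library file
# itself stays side-effect free.
#
#     cache: Cache[str] = Cache(ttl=timedelta(minutes=5), max_size=2)
#     cache.set("greeting", "hello")
#     cache.get("greeting")   # hit; refreshes last_accessed
#     cache.set("a", "1")
#     cache.set("b", "2")     # now 3 entries > max_size=2, so the entry with
#                             # the oldest last_accessed ("greeting") is evicted
#     cache.get("greeting")   # -> None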


class CacheablePromptHandler(Handler):
    """Abstract base class that adds caching capabilities to prompt
    handlers."""

    def __init__(
        self,
        cache_ttl: Optional[timedelta] = timedelta(hours=1),
        max_cache_size: Optional[int] = 1000,
    ):
        self._prompt_cache = Cache[str](ttl=cache_ttl, max_size=max_cache_size)
        self._template_cache = Cache[dict](
            ttl=cache_ttl, max_size=max_cache_size
        )

    def _cache_key(
        self, prompt_name: str, inputs: Optional[dict] = None
    ) -> str:
        """Generate a cache key for a prompt request."""
        if inputs:
            # Sort dict items for consistent keys
            sorted_inputs = sorted(inputs.items())
            return f"{prompt_name}:{sorted_inputs}"
        return prompt_name
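
    # Illustrative note (not in the original source): the key for a formatted
    # prompt embeds the sorted input items, e.g.
    #
    #     _cache_key("rag", {"query": "q", "context": "c"})
    #     -> "rag:[('context', 'c'), ('query', 'q')]"
    #
    # so identical inputs always map to the same cache entry regardless of
    # dict insertion order.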

    async def get_cached_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
        bypass_cache: bool = False,
    ) -> str:
        if prompt_override:
            # If the user gave us a direct override, use it.
            if inputs:
                try:
                    return prompt_override.format(**inputs)
                except KeyError:
                    return prompt_override
            return prompt_override

        cache_key = self._cache_key(prompt_name, inputs)

        # If not bypassing, try returning from the prompt-level cache
        if not bypass_cache:
            cached = self._prompt_cache.get(cache_key)
            if cached is not None:
                logger.debug(f"Prompt cache hit: {cache_key}")
                return cached

        logger.debug(
            "Prompt cache miss or bypass. Retrieving from DB or template cache."
        )
        # Notice the new parameter `bypass_template_cache` below
        result = await self._get_prompt_impl(
            prompt_name, inputs, bypass_template_cache=bypass_cache
        )
        self._prompt_cache.set(cache_key, result)
        return result

    async def get_prompt(  # type: ignore
        self,
        name: str,
        inputs: Optional[dict] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        query = f"""
        SELECT id, name, template, input_types, created_at, updated_at
        FROM {self._get_table_name("prompts")}
        WHERE name = $1;
        """
        result = await self.connection_manager.fetchrow_query(query, [name])

        if not result:
            raise ValueError(f"Prompt template '{name}' not found")

        input_types = result["input_types"]
        if isinstance(input_types, str):
            input_types = json.loads(input_types)

        return {
            "id": result["id"],
            "name": result["name"],
            "template": result["template"],
            "input_types": input_types,
            "created_at": result["created_at"],
            "updated_at": result["updated_at"],
        }

    def _format_prompt(
        self,
        template: str,
        inputs: Optional[dict[str, Any]],
        input_types: dict[str, str],
    ) -> str:
        if inputs:
            # optional input validation if needed
            for k, _v in inputs.items():
                if k not in input_types:
                    raise ValueError(
                        f"Unexpected input '{k}' for prompt with input types {input_types}"
                    )
            return template.format(**inputs)
        return template

    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> None:
        """Public method to update a prompt with proper cache invalidation."""
        # First invalidate all caches for this prompt
        self._template_cache.invalidate(name)
        cache_keys_to_invalidate = [
            key
            for key in self._prompt_cache._cache.keys()
            if key.startswith(f"{name}:") or key == name
        ]
        for key in cache_keys_to_invalidate:
            self._prompt_cache.invalidate(key)

        # Perform the update
        await self._update_prompt_impl(name, template, input_types)

        # Force refresh template cache
        template_info = await self._get_template_info(name)
        if template_info:
            self._template_cache.set(name, template_info)

    @abstractmethod
    async def _update_prompt_impl(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> None:
        """Implementation of prompt update logic."""
        pass

    @abstractmethod
    async def _get_template_info(self, prompt_name: str) -> Optional[dict]:
        """Get template info with caching."""
        pass

    @abstractmethod
    async def _get_prompt_impl(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        bypass_template_cache: bool = False,
    ) -> str:
        """Implementation of prompt retrieval logic."""
        pass


class PostgresPromptsHandler(CacheablePromptHandler):
    """PostgreSQL implementation of the CacheablePromptHandler."""

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        prompt_directory: Optional[Path] = None,
        **cache_options,
    ):
        super().__init__(**cache_options)
        self.prompt_directory = (
            prompt_directory or Path(os.path.dirname(__file__)) / "prompts"
        )
        self.connection_manager = connection_manager
        self.project_name = project_name
        self.prompts: dict[str, dict[str, str | dict[str, str]]] = {}

    async def _load_prompts(self) -> None:
        """Load prompts from both database and YAML files."""
        # First load from database
        await self._load_prompts_from_database()

        # Then load from YAML files, potentially overriding unmodified database entries
        await self._load_prompts_from_yaml_directory()

    async def _load_prompts_from_database(self) -> None:
        """Load prompts from the database."""
        query = f"""
        SELECT id, name, template, input_types, created_at, updated_at
        FROM {self._get_table_name("prompts")};
        """
        try:
            results = await self.connection_manager.fetch_query(query)
            for row in results:
                logger.info(f"Loading saved prompt: {row['name']}")

                # Ensure input_types is a dictionary
                input_types = row["input_types"]
                if isinstance(input_types, str):
                    input_types = json.loads(input_types)

                self.prompts[row["name"]] = {
                    "id": row["id"],
                    "template": row["template"],
                    "input_types": input_types,
                    "created_at": row["created_at"],
                    "updated_at": row["updated_at"],
                }
                # Pre-populate the template cache
                self._template_cache.set(
                    row["name"],
                    {
                        "id": row["id"],
                        "template": row["template"],
                        "input_types": input_types,
                    },
                )
            logger.debug(f"Loaded {len(results)} prompts from database")
        except Exception as e:
            logger.error(f"Failed to load prompts from database: {e}")
            raise

    async def _load_prompts_from_yaml_directory(
        self, default_overwrite_on_diff: bool = False
    ) -> None:
        """Load prompts from YAML files in the specified directory.

        :param default_overwrite_on_diff: If a YAML prompt does not specify
            'overwrite_on_diff', we use this default.
        """
        if not self.prompt_directory.is_dir():
            logger.warning(
                f"Prompt directory not found: {self.prompt_directory}"
            )
            return

        logger.info(f"Loading prompts from {self.prompt_directory}")
        for yaml_file in self.prompt_directory.glob("*.yaml"):
            logger.debug(f"Processing {yaml_file}")
            try:
                with open(yaml_file, "r", encoding="utf-8") as file:
                    data = yaml.safe_load(file)
                    if not isinstance(data, dict):
                        raise ValueError(
                            f"Invalid format in YAML file {yaml_file}"
                        )

                    for name, prompt_data in data.items():
                        # Attempt to parse the relevant prompt fields
                        template = prompt_data.get("template")
                        input_types = prompt_data.get("input_types", {})

                        # Decide on per-prompt overwrite behavior (or fallback)
                        overwrite_on_diff = prompt_data.get(
                            "overwrite_on_diff", default_overwrite_on_diff
                        )
                        # Some logic to determine if we *should* modify.
                        # For instance, preserve only if it has never been updated
                        # (i.e., created_at == updated_at).
                        should_modify = True
                        if name in self.prompts:
                            existing = self.prompts[name]
                            should_modify = (
                                existing["created_at"]
                                == existing["updated_at"]
                            )

                        # If should_modify is True, the default logic is
                        # preserve_existing = False, so we can pass that in.
                        # Otherwise, preserve_existing=True effectively means
                        # we skip the update.
                        logger.info(
                            f"Loading default prompt: {name} from {yaml_file}."
                        )

                        await self.add_prompt(
                            name=name,
                            template=template,
                            input_types=input_types,
                            preserve_existing=False,
                            overwrite_on_diff=overwrite_on_diff,
                        )
            except Exception as e:
                logger.error(f"Error loading {yaml_file}: {e}")
                continue
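
    # Illustrative note (not in the original source): each *.yaml file in
    # prompt_directory is expected to parse into a mapping of prompt name to
    # prompt fields, i.e. yaml.safe_load(...) should yield something like:
    #
    #     {
    #         "rag": {
    #             "template": "Answer using the context: {context}\nQuery: {query}",
    #             "input_types": {"context": "str", "query": "str"},
    #             "overwrite_on_diff": True,  # optional; falls back to the default
    #         },
    #     }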

    def _get_table_name(self, base_name: str) -> str:
        """Get the fully qualified table name."""
        return f"{self.project_name}.{base_name}"

    # Implementation of abstract methods from CacheablePromptHandler
    async def _get_prompt_impl(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        bypass_template_cache: bool = False,
    ) -> str:
        """Implementation of database prompt retrieval."""
        # If we're bypassing the template cache, skip the cache lookup
        if not bypass_template_cache:
            template_info = self._template_cache.get(prompt_name)
            if template_info is not None:
                logger.debug(f"Template cache hit: {prompt_name}")
                # use that
                return self._format_prompt(
                    template_info["template"],
                    inputs,
                    template_info["input_types"],
                )

        # If we get here, either no cache was found or bypass_cache is True
        query = f"""
        SELECT template, input_types
        FROM {self._get_table_name("prompts")}
        WHERE name = $1;
        """
        result = await self.connection_manager.fetchrow_query(
            query, [prompt_name]
        )

        if not result:
            raise ValueError(f"Prompt template '{prompt_name}' not found")

        template = result["template"]
        input_types = result["input_types"]
        if isinstance(input_types, str):
            input_types = json.loads(input_types)

        # Update template cache if not bypassing it
        if not bypass_template_cache:
            self._template_cache.set(
                prompt_name, {"template": template, "input_types": input_types}
            )

        return self._format_prompt(template, inputs, input_types)

    async def _get_template_info(self, prompt_name: str) -> Optional[dict]:  # type: ignore
        """Get template info with caching."""
        cached = self._template_cache.get(prompt_name)
        if cached is not None:
            return cached

        query = f"""
        SELECT template, input_types
        FROM {self._get_table_name("prompts")}
        WHERE name = $1;
        """
        result = await self.connection_manager.fetchrow_query(
            query, [prompt_name]
        )

        if result:
            # Ensure input_types is a dictionary
            input_types = result["input_types"]
            if isinstance(input_types, str):
                input_types = json.loads(input_types)

            template_info = {
                "template": result["template"],
                "input_types": input_types,
            }
            self._template_cache.set(prompt_name, template_info)
            return template_info

        return None

    async def _update_prompt_impl(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> None:
        """Implementation of database prompt update with proper connection
        handling."""
        if not template and not input_types:
            return

        # Clear caches first
        self._template_cache.invalidate(name)
        for key in list(self._prompt_cache._cache.keys()):
            if key.startswith(f"{name}:"):
                self._prompt_cache.invalidate(key)

        # Build update query
        set_clauses = []
        params = [name]  # First parameter is always the name
        param_index = 2  # Start from 2 since $1 is name

        if template:
            set_clauses.append(f"template = ${param_index}")
            params.append(template)
            param_index += 1

        if input_types:
            set_clauses.append(f"input_types = ${param_index}")
            params.append(json.dumps(input_types))
            param_index += 1

        set_clauses.append("updated_at = CURRENT_TIMESTAMP")

        query = f"""
        UPDATE {self._get_table_name("prompts")}
        SET {", ".join(set_clauses)}
        WHERE name = $1
        RETURNING id, template, input_types;
        """

        try:
            # Execute update and get returned values
            result = await self.connection_manager.fetchrow_query(
                query, params
            )

            if not result:
                raise ValueError(f"Prompt template '{name}' not found")

            # Update in-memory state
            if name in self.prompts:
                if template:
                    self.prompts[name]["template"] = template
                if input_types:
                    self.prompts[name]["input_types"] = input_types
                self.prompts[name]["updated_at"] = datetime.now().isoformat()

        except Exception as e:
            logger.error(f"Failed to update prompt {name}: {str(e)}")
            raise

    async def create_tables(self):
        """Create the necessary tables for storing prompts."""
        query = f"""
        CREATE TABLE IF NOT EXISTS {self._get_table_name("prompts")} (
            id UUID PRIMARY KEY,
            name VARCHAR(255) NOT NULL UNIQUE,
            template TEXT NOT NULL,
            input_types JSONB NOT NULL,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
        );

        CREATE OR REPLACE FUNCTION {self.project_name}.update_updated_at_column()
        RETURNS TRIGGER AS $$
        BEGIN
            NEW.updated_at = CURRENT_TIMESTAMP;
            RETURN NEW;
        END;
        $$ language 'plpgsql';

        DROP TRIGGER IF EXISTS update_prompts_updated_at
        ON {self._get_table_name("prompts")};

        CREATE TRIGGER update_prompts_updated_at
            BEFORE UPDATE ON {self._get_table_name("prompts")}
            FOR EACH ROW
            EXECUTE FUNCTION {self.project_name}.update_updated_at_column();
        """
        await self.connection_manager.execute_query(query)
        await self._load_prompts()

    async def add_prompt(
        self,
        name: str,
        template: str,
        input_types: dict[str, str],
        preserve_existing: bool = False,
        overwrite_on_diff: bool = False,  # <-- new param
    ) -> None:
        """Add or update a prompt.

        If `preserve_existing` is True and the prompt already exists, we skip updating.

        If `overwrite_on_diff` is True and an existing prompt differs from what is provided,
        we overwrite and log a warning. Otherwise, we skip if the prompt differs.
        """
        # Check if prompt is in-memory
        existing_prompt = self.prompts.get(name)

        # If preserving existing and it already exists, skip entirely
        if preserve_existing and existing_prompt:
            logger.debug(
                f"Preserving existing prompt: {name}, skipping update."
            )
            return

        # If an existing prompt is found, check for diffs
        if existing_prompt:
            existing_template = existing_prompt["template"]
            existing_input_types = existing_prompt["input_types"]

            # If there's a difference in template or input_types, decide to overwrite or skip
            if (
                existing_template != template
                or existing_input_types != input_types
            ):
                if overwrite_on_diff:
                    logger.warning(
                        f"Overwriting existing prompt '{name}' due to detected diff."
                    )
                else:
                    logger.info(
                        f"Prompt '{name}' differs from existing but overwrite_on_diff=False. Skipping update."
                    )
                    return

        prompt_id = generate_default_prompt_id(name)

        # Ensure input_types is properly serialized
        input_types_json = (
            json.dumps(input_types)
            if isinstance(input_types, dict)
            else input_types
        )

        # Upsert logic
        query = f"""
        INSERT INTO {self._get_table_name("prompts")} (id, name, template, input_types)
        VALUES ($1, $2, $3, $4)
        ON CONFLICT (name) DO UPDATE
        SET template = EXCLUDED.template,
            input_types = EXCLUDED.input_types,
            updated_at = CURRENT_TIMESTAMP
        RETURNING id, created_at, updated_at;
        """

        result = await self.connection_manager.fetchrow_query(
            query, [prompt_id, name, template, input_types_json]
        )

        self.prompts[name] = {
            "id": result["id"],
            "template": template,
            "input_types": input_types,
            "created_at": result["created_at"],
            "updated_at": result["updated_at"],
        }

        # Update template cache
        self._template_cache.set(
            name,
            {
                "id": prompt_id,
                "template": template,
                "input_types": input_types,
            },
        )

        # Invalidate any cached formatted prompts
        for key in list(self._prompt_cache._cache.keys()):
            if key.startswith(f"{name}:"):
                self._prompt_cache.invalidate(key)

    async def get_all_prompts(self) -> dict[str, Any]:
        """Retrieve all stored prompts."""
        query = f"""
        SELECT id, name, template, input_types, created_at, updated_at, COUNT(*) OVER() AS total_entries
        FROM {self._get_table_name("prompts")};
        """
        results = await self.connection_manager.fetch_query(query)

        if not results:
            return {"results": [], "total_entries": 0}

        total_entries = results[0]["total_entries"] if results else 0

        prompts = [
            {
                "name": row["name"],
                "id": row["id"],
                "template": row["template"],
                "input_types": (
                    json.loads(row["input_types"])
                    if isinstance(row["input_types"], str)
                    else row["input_types"]
                ),
                "created_at": row["created_at"],
                "updated_at": row["updated_at"],
            }
            for row in results
        ]

        return {"results": prompts, "total_entries": total_entries}

    async def delete_prompt(self, name: str) -> None:
        """Delete a prompt template."""
        query = f"""
        DELETE FROM {self._get_table_name("prompts")}
        WHERE name = $1;
        """
        result = await self.connection_manager.execute_query(query, [name])
        if result == "DELETE 0":
            raise ValueError(f"Prompt template '{name}' not found")

        # Invalidate caches
        self._template_cache.invalidate(name)
        for key in list(self._prompt_cache._cache.keys()):
            if key.startswith(f"{name}:"):
                self._prompt_cache.invalidate(key)

    async def get_message_payload(
        self,
        system_prompt_name: Optional[str] = None,
        system_role: str = "system",
        system_inputs: dict | None = None,
        system_prompt_override: Optional[str] = None,
        task_prompt_name: Optional[str] = None,
        task_role: str = "user",
        task_inputs: Optional[dict] = None,
        task_prompt: Optional[str] = None,
    ) -> list[dict]:
        """Create a message payload from system and task prompts."""
        if system_inputs is None:
            system_inputs = {}
        if task_inputs is None:
            task_inputs = {}

        if system_prompt_override:
            system_prompt = system_prompt_override
        else:
            system_prompt = await self.get_cached_prompt(
                system_prompt_name or "system",
                system_inputs,
                prompt_override=system_prompt_override,
            )

        task_prompt = await self.get_cached_prompt(
            task_prompt_name or "rag",
            task_inputs,
            prompt_override=task_prompt,
        )

        return [
            {
                "role": system_role,
                "content": system_prompt,
            },
            {
                "role": task_role,
                "content": task_prompt,
            },
        ]
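

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes an already-initialized PostgresConnectionManager is supplied by the
# caller; how that manager is constructed is outside this file's scope. The
# names "_example_usage", "my_project", and "example_rag" are hypothetical.
async def _example_usage(connection_manager: PostgresConnectionManager) -> None:
    handler = PostgresPromptsHandler(
        project_name="my_project",  # hypothetical schema/project name
        connection_manager=connection_manager,
    )
    # Creates the prompts table (if needed) and loads prompts from the
    # database plus the YAML prompt directory.
    await handler.create_tables()

    # Register (or upsert) a prompt template with typed inputs.
    await handler.add_prompt(
        name="example_rag",
        template="Answer using the context: {context}\nQuery: {query}",
        input_types={"context": "str", "query": "str"},
        overwrite_on_diff=True,
    )

    # Formatted prompts are cached per (name, sorted inputs) pair; a second
    # call with the same inputs is served from the prompt-level cache.
    formatted = await handler.get_cached_prompt(
        "example_rag",
        inputs={"context": "some retrieved text", "query": "a question"},
    )
    print(formatted)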