prompt.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. """Abstraction for a prompt that can be formatted with inputs."""
  2. import logging
  3. from datetime import datetime
  4. from typing import Any
  5. from uuid import UUID, uuid4
  6. from pydantic import BaseModel, Field
  # Module-scoped logger. Naming it after the module (instead of using the
  # root logger) lets applications tune this module's level/handlers
  # independently; records still propagate to the root handlers.
  logger = logging.getLogger(__name__)
  class Prompt(BaseModel):
      """A prompt template that can be formatted with typed inputs.

      Attributes:
          id: Unique identifier; generated with ``uuid4`` when not supplied.
          name: Name of this prompt.
          template: A ``str.format``-style template, e.g. ``"Hello {user}"``.
          input_types: Maps template variable names to the name of the type
              each value must have (``"str"``, ``"int"``, ``"float"``,
              ``"bool"``, ``"list"``, ``"dict"``). Unrecognized type names
              fall back to ``str``.
          created_at: Set from ``datetime.utcnow`` (naive) when not supplied.
          updated_at: Set from ``datetime.utcnow`` (naive) when not supplied.
      """

      id: UUID = Field(default_factory=uuid4)
      name: str
      template: str
      input_types: dict[str, str]
      # NOTE(review): datetime.utcnow is deprecated since Python 3.12 and
      # returns naive datetimes. Switching to datetime.now(timezone.utc) would
      # make these timezone-aware; confirm consumers before changing.
      created_at: datetime = Field(default_factory=datetime.utcnow)
      updated_at: datetime = Field(default_factory=datetime.utcnow)

      def format_prompt(self, inputs: dict[str, Any]) -> str:
          """Validate ``inputs`` against ``input_types`` and render the template.

          Args:
              inputs: Values for the template variables. Keys not declared in
                  ``input_types`` are passed to ``str.format`` unchecked.

          Returns:
              The template with its placeholders substituted.

          Raises:
              ValueError: If a variable declared in ``input_types`` is missing.
              TypeError: If a declared variable has a value of the wrong type.
              KeyError: If the template references a variable that is neither
                  declared nor present in ``inputs``.
          """
          # NOTE(review): str.format supports attribute/index access inside
          # placeholders (e.g. "{x.__class__}"), so templates must be trusted.
          self._validate_inputs(inputs)
          return self.template.format(**inputs)

      def _validate_inputs(self, inputs: dict[str, Any]) -> None:
          """Check that every declared input is present and correctly typed.

          Args:
              inputs: Candidate values keyed by template variable name.

          Raises:
              ValueError: If a variable declared in ``input_types`` is absent.
              TypeError: If a value is not an instance of its declared type.
          """
          for var, expected_type_name in self.input_types.items():
              # Report a missing input before resolving its type name.
              if var not in inputs:
                  raise ValueError(f"Missing input: {var}")
              expected_type = self._convert_type(expected_type_name)
              value = inputs[var]
              # isinstance honours subclassing, so e.g. bool values satisfy
              # an "int" declaration.
              if not isinstance(value, expected_type):
                  raise TypeError(
                      f"Input '{var}' must be of type {expected_type.__name__}, "
                      f"got {type(value).__name__} instead."
                  )

      def _convert_type(self, type_name: str) -> type:
          """Resolve a type name from ``input_types`` to a Python type.

          Args:
              type_name: One of ``"str"``, ``"int"``, ``"float"``, ``"bool"``,
                  ``"list"`` or ``"dict"``.

          Returns:
              The matching builtin type, or ``str`` for unrecognized names.
          """
          type_mapping: dict[str, type] = {
              "str": str,
              "int": int,
              "float": float,
              "bool": bool,
              "list": list,
              "dict": dict,
          }
          expected_type = type_mapping.get(type_name)
          if expected_type is None:
              # Keep the str fallback for backward compatibility, but make the
              # otherwise-silent substitution visible.
              logger.warning("Unknown input type %r; falling back to str", type_name)
              return str
          return expected_type