run.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. from datetime import datetime
  2. from typing import Optional, Any, Iterable, Dict, List, Union, Literal
  3. from pydantic import Field as PDField
  4. from sqlalchemy import Column, Enum
  5. from sqlalchemy.sql.sqltypes import JSON, TEXT
  6. from sqlmodel import Field
  7. from pydantic import model_validator
  8. from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
  9. from app.models.message import MessageCreate
  10. from app.schemas.tool.authentication import Authentication
  11. # from typing_extensions import Literal, Required, TypedDict, TypeAlias
  class ChatCompletionAudioParam(BaseModel):
      """Audio output parameters for a chat completion.

      Both fields are annotated with ``Literal`` so pydantic actually enforces
      the allowed values. This class is not declared with ``table=True``, so
      the ``sa_column`` Enum definitions alone would not restrict validation.
      """

      # Output audio format: one of wav, mp3, flac, opus, pcm16.
      format: Literal["wav", "mp3", "flac", "opus", "pcm16"] = Field(
          sa_column=Column(Enum("wav", "mp3", "flac", "opus", "pcm16"), nullable=False)
      )
      # Voice the model uses to respond. Supported voices are ash, ballad,
      # coral, sage and verse; alloy, echo and shimmer are also accepted but
      # are described as less expressive and not recommended.
      voice: Literal[
          "alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"
      ] = Field(
          sa_column=Column(
              Enum(
                  "alloy",
                  "ash",
                  "ballad",
                  "coral",
                  "echo",
                  "sage",
                  "shimmer",
                  "verse",
              ),
              nullable=False,
          )
      )
  43. # class ChatCompletionModality(BaseModel):
  44. # text = str = Field(sa_column=Column(Enum("text"), nullable=False))
  45. # audio = str = Field(sa_column=Column(Enum("audio"), nullable=False))
  class RunBase(BaseModel):
      """Fields shared by the ``Run`` table model and ``RunRead``.

      Structured values (file ids, metadata, errors, tools, options, usage)
      are stored in JSON columns. ``metadata_`` is persisted in a column named
      ``metadata`` and accepts ``metadata`` as its validation alias.
      """

      instructions: Optional[str] = Field(
          sa_column=Column(TEXT), max_length=32768, default=None
      )
      model: Optional[str] = Field(default=None)
      # Allowed database values for the run status; new rows start as "queued".
      status: str = Field(
          sa_column=Column(
              Enum(
                  "cancelled",
                  "cancelling",
                  "completed",
                  "expired",
                  "failed",
                  "in_progress",
                  "queued",
                  "requires_action",
              ),
              nullable=True,
              default="queued",
          ),
          default="queued",
      )
      assistant_id: str = Field(nullable=False)
      # NOTE(review): default=None contradicts both the `str` annotation and
      # nullable=False; presumably callers always set thread_id before
      # persisting. Confirm.
      thread_id: str = Field(nullable=False, default=None)
      object: str = Field(default="thread.run", nullable=False)
      file_ids: Optional[list] = Field(sa_column=Column(JSON), default=[])
      metadata_: Optional[dict] = Field(
          schema_extra={"validation_alias": "metadata"},
          sa_column=Column("metadata", JSON),
          default=None,
      )
      last_error: Optional[dict] = Field(sa_column=Column(JSON), default=None)
      required_action: Optional[dict] = Field(sa_column=Column(JSON), default=None)
      tools: Optional[list] = Field(sa_column=Column(JSON), default=[])

      # Lifecycle timestamps.
      started_at: Optional[datetime] = Field(default=None)
      completed_at: Optional[datetime] = Field(default=None)
      cancelled_at: Optional[datetime] = Field(default=None)
      expires_at: Optional[datetime] = Field(default=None)
      failed_at: Optional[datetime] = Field(default=None)

      additional_instructions: Optional[str] = Field(
          sa_column=Column(TEXT), max_length=32768, default=None
      )
      extra_body: Optional[dict] = Field(sa_column=Column(JSON), default={})
      stream_options: Optional[dict] = Field(sa_column=Column(JSON), default=None)
      incomplete_details: Optional[str] = Field(default=None)  # incomplete details
      max_completion_tokens: Optional[int] = Field(default=None)  # max completion length
      max_prompt_tokens: Optional[int] = Field(default=None)  # max prompt length
      # Response format: either a string (default "auto") or a dict.
      response_format: Optional[Union[str, dict]] = Field(
          sa_column=Column(JSON), default="auto"
      )
      tool_choice: Optional[str] = Field(default=None)  # tool choice
      truncation_strategy: Optional[dict] = Field(
          sa_column=Column(JSON), default=None
      )  # truncation strategy
      usage: Optional[dict] = Field(
          sa_column=Column(JSON), default=None
      )  # usage statistics of the call
      temperature: Optional[float] = Field(default=None)  # sampling temperature
      top_p: Optional[float] = Field(default=None)  # nucleus-sampling top_p
      # Not yet enabled:
      #   parallel_tool_calls: bool = Field(default=False)
      #   audio: Optional[ChatCompletionAudioParam] = Field(default=None)
      #   modalities: Optional[List[Literal["text", "audio"]]] = Field(
      #       default=None, sa_column=Column(JSON)
      #   )
  class Run(RunBase, PrimaryKeyMixin, TimeStampMixin, table=True):
      """Run table model (``table=True``).

      Declares no fields of its own: every column comes from ``RunBase``,
      ``PrimaryKeyMixin`` and ``TimeStampMixin``.
      """
  class RunCreate(BaseModel):
      """Run creation model.

      A ``mode="before"`` validator inspects ``extra_body["action_authentications"]``
      and runs each entry through ``Authentication.model_validate(...).encrypt()``
      before field validation.
      """

      assistant_id: str
      status: Optional[str] = "queued"
      # Not Optional (unlike RunBase.instructions): omitted instructions
      # default to an empty string.
      instructions: str = Field(default="")
      additional_instructions: Optional[str] = None
      model: Optional[str] = None
      metadata_: Optional[dict] = Field(
          default=None, schema_extra={"validation_alias": "metadata"}
      )
      tools: Optional[list] = []
      # Top-level values may themselves be dicts whose entries are
      # Authentication objects (e.g. "action_authentications").
      extra_body: Optional[
          dict[str, Union[dict[str, Union[Authentication, Any]], Any]]
      ] = {}
      stream: Optional[bool] = False
      stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
      additional_messages: Optional[list[MessageCreate]] = Field(
          default=[], sa_column=Column(JSON)
      )  # list of messages
      max_completion_tokens: Optional[int] = None  # max completion length
      max_prompt_tokens: Optional[int] = Field(default=None)  # max prompt length
      truncation_strategy: Optional[dict] = Field(
          default=None, sa_column=Column(JSON)
      )  # truncation strategy
      response_format: Optional[Union[str, dict]] = Field(
          default="auto", sa_column=Column(JSON)
      )  # response format
      tool_choice: Optional[str] = Field(default=None)  # tool choice
      temperature: Optional[float] = Field(default=None)  # sampling temperature
      top_p: Optional[float] = Field(default=None)  # nucleus-sampling top_p
      # Not yet enabled:
      #   parallel_tool_calls: bool = Field(default=False)
      #   audio: Optional[ChatCompletionAudioParam] = Field(default=None)
      #   modalities: Optional[List[Literal["text", "audio"]]] = Field(
      #       default=None, sa_column=Column(JSON)
      #   )

      # NOTE: this method's name shadows the imported ``model_validator``
      # decorator for the rest of the class body; keep it the last decorated
      # definition in the class.
      @model_validator(mode="before")
      def model_validator(cls, data: Any):
          """Validate and encrypt ``extra_body["action_authentications"]`` entries.

          Runs on the raw input, which pydantic does not guarantee to be a dict
          (e.g. a model instance or an attribute-bearing object). Such input is
          returned untouched so regular field validation can handle it instead
          of crashing with ``AttributeError``.

          Args:
              data: Raw input passed to model validation.

          Returns:
              ``data``, unchanged.

          Raises:
              ValueError: if ``action_authentications`` is present and truthy
                  but not a dict. Pydantic reports this as a ``ValidationError``.
          """
          if not isinstance(data, dict):
              return data
          extra_body = data.get("extra_body")
          if not isinstance(extra_body, dict):
              # Missing/empty, or a wrong type that field validation will reject.
              return data
          action_authentications = extra_body.get("action_authentications")
          if not action_authentications:
              return data
          if not isinstance(action_authentications, dict):
              raise ValueError("extra_body['action_authentications'] must be a dict")
          for auth in action_authentications.values():
              # NOTE(review): the validated/encrypted Authentication instance is
              # discarded. When ``auth`` is a plain dict, model_validate builds a
              # new object, so any in-place encryption never reaches ``data``.
              # Confirm whether the encrypted value should be written back.
              Authentication.model_validate(auth).encrypt()
          return data
  class RunUpdate(BaseModel):
      """Run update model: ``tools``, ``metadata`` and ``extra_body``.

      A ``mode="before"`` validator inspects ``extra_body["action_authentications"]``
      and runs each entry through ``Authentication.model_validate(...).encrypt()``
      before field validation.
      """

      tools: Optional[list] = []
      metadata_: Optional[dict] = Field(
          default=None, schema_extra={"validation_alias": "metadata"}
      )
      # Typed like RunCreate.extra_body. A plain dict[str, Authentication]
      # would require every top-level value to be an Authentication. That
      # rejects the nested {"action_authentications": {...}} layout the
      # validator below reads.
      extra_body: Optional[
          dict[str, Union[dict[str, Union[Authentication, Any]], Any]]
      ] = {}

      # NOTE: this method's name shadows the imported ``model_validator``
      # decorator for the rest of the class body; keep it the last decorated
      # definition in the class.
      @model_validator(mode="before")
      def model_validator(cls, data: Any):
          """Validate and encrypt ``extra_body["action_authentications"]`` entries.

          Runs on the raw input, which pydantic does not guarantee to be a dict.
          Non-dict input is returned untouched so regular field validation can
          handle it instead of crashing with ``AttributeError``.

          Args:
              data: Raw input passed to model validation.

          Returns:
              ``data``, unchanged.

          Raises:
              ValueError: if ``action_authentications`` is present and truthy
                  but not a dict. Pydantic reports this as a ``ValidationError``.
          """
          if not isinstance(data, dict):
              return data
          extra_body = data.get("extra_body")
          if not isinstance(extra_body, dict):
              # Missing/empty, or a wrong type that field validation will reject.
              return data
          action_authentications = extra_body.get("action_authentications")
          if not action_authentications:
              return data
          if not isinstance(action_authentications, dict):
              raise ValueError("extra_body['action_authentications'] must be a dict")
          for auth in action_authentications.values():
              # NOTE(review): the validated/encrypted Authentication instance is
              # discarded. When ``auth`` is a plain dict, model_validate builds a
              # new object, so any in-place encryption never reaches ``data``.
              # Confirm whether the encrypted value should be written back.
              Authentication.model_validate(auth).encrypt()
          return data
  class RunRead(RunBase, TimeStampMixin, PrimaryKeyMixin):
      """Non-table run model combining ``RunBase`` with the timestamp and
      primary-key mixins.
      """

      # Redeclares RunBase.metadata_ as a plain pydantic field aliased to
      # "metadata". The alias is used for validation, and for serialization
      # when dumping with by_alias=True.
      metadata_: Optional[dict] = PDField(alias="metadata", default=None)