# run.py
  1. from datetime import datetime
  2. from typing import Optional, Any, Union
  3. from pydantic import Field as PDField
  4. from sqlalchemy import Column, Enum
  5. from sqlalchemy.sql.sqltypes import JSON, TEXT
  6. from sqlmodel import Field
  7. from pydantic import model_validator
  8. from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
  9. from app.models.message import MessageCreate
  10. from app.schemas.tool.authentication import Authentication
class RunBase(BaseModel):
    """Fields shared by the ``Run`` table model and the ``RunRead`` schema.

    Fields declared with ``sa_column`` are mapped to an explicitly declared
    SQLAlchemy column (TEXT, JSON, or the ENUM used for ``status``).
    """

    instructions: Optional[str] = Field(
        default=None, max_length=32768, sa_column=Column(TEXT)
    )
    model: Optional[str] = Field(default=None)
    # Run lifecycle state; the column only accepts the ENUM values listed below.
    status: str = Field(
        default="queued",
        sa_column=Column(
            Enum(
                "cancelled",
                "cancelling",
                "completed",
                "expired",
                "failed",
                "in_progress",
                "queued",
                "requires_action",
            ),
            default="queued",
            nullable=True,
        ),
    )
    assistant_id: str = Field(nullable=False)
    # NOTE(review): annotated `str` with a `None` default while declared NOT NULL;
    # presumably callers always set it before persisting — confirm.
    thread_id: str = Field(default=None, nullable=False)
    object: str = Field(nullable=False, default="thread.run")
    file_ids: Optional[list] = Field(default=[], sa_column=Column(JSON))
    # Accepted from input under the key "metadata" and stored in the "metadata"
    # column; the trailing underscore presumably avoids SQLAlchemy's reserved
    # declarative `metadata` attribute — confirm.
    metadata_: Optional[dict] = Field(
        default=None,
        sa_column=Column("metadata", JSON),
        schema_extra={"validation_alias": "metadata"},
    )
    last_error: Optional[dict] = Field(default=None, sa_column=Column(JSON))
    required_action: Optional[dict] = Field(default=None, sa_column=Column(JSON))
    tools: Optional[list] = Field(default=[], sa_column=Column(JSON))
    started_at: Optional[datetime] = Field(default=None)
    completed_at: Optional[datetime] = Field(default=None)
    cancelled_at: Optional[datetime] = Field(default=None)
    expires_at: Optional[datetime] = Field(default=None)
    failed_at: Optional[datetime] = Field(default=None)
    additional_instructions: Optional[str] = Field(
        default=None, max_length=32768, sa_column=Column(TEXT)
    )
    extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
    stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
    incomplete_details: Optional[str] = Field(default=None)  # incomplete details
    max_completion_tokens: Optional[int] = Field(default=None)  # max completion length
    max_prompt_tokens: Optional[int] = Field(default=None)  # max prompt length
    response_format: Optional[Union[str, dict]] = Field(
        default="auto", sa_column=Column(JSON)
    )  # response format
    tool_choice: Optional[str] = Field(default=None)  # tool choice
    truncation_strategy: Optional[dict] = Field(
        default=None, sa_column=Column(JSON)
    )  # truncation strategy
    usage: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # usage of the call
    temperature: Optional[float] = Field(default=None)  # temperature
    top_p: Optional[float] = Field(default=None)  # top_p
  68. class Run(RunBase, PrimaryKeyMixin, TimeStampMixin, table=True):
  69. pass
  70. class RunCreate(BaseModel):
  71. assistant_id: str
  72. status: Optional[str] = "queued"
  73. instructions: Optional[str] = None
  74. additional_instructions: Optional[str] = None
  75. model: Optional[str] = None
  76. metadata_: Optional[dict] = Field(
  77. default=None, schema_extra={"validation_alias": "metadata"}
  78. )
  79. tools: Optional[list] = []
  80. extra_body: Optional[
  81. dict[str, Union[dict[str, Union[Authentication, Any]], Any]]
  82. ] = {}
  83. stream: Optional[bool] = False
  84. stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
  85. additional_messages: Optional[list[MessageCreate]] = Field(
  86. default=[], sa_column=Column(JSON)
  87. ) # 消息列表
  88. max_completion_tokens: Optional[int] = None # 最大完成长度
  89. max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
  90. truncation_strategy: Optional[dict] = Field(
  91. default=None, sa_column=Column(JSON)
  92. ) # 截断策略
  93. response_format: Optional[Union[str, dict]] = Field(
  94. default="auto", sa_column=Column(JSON)
  95. ) # 响应格式
  96. tool_choice: Optional[str] = Field(default=None) # 工具选择
  97. temperature: Optional[float] = Field(default=None) # 温度
  98. top_p: Optional[float] = Field(default=None) # top_p
  99. @model_validator(mode="before")
  100. def model_validator(cls, data: Any):
  101. extra_body = data.get("extra_body")
  102. if extra_body:
  103. action_authentications = extra_body.get("action_authentications")
  104. if action_authentications:
  105. res = action_authentications.values()
  106. [Authentication.model_validate(i).encrypt() for i in res]
  107. return data
  108. class RunUpdate(BaseModel):
  109. tools: Optional[list] = []
  110. metadata_: Optional[dict] = Field(
  111. default=None, schema_extra={"validation_alias": "metadata"}
  112. )
  113. extra_body: Optional[dict[str, Authentication]] = {}
  114. @model_validator(mode="before")
  115. def model_validator(cls, data: Any):
  116. extra_body = data.get("extra_body")
  117. if extra_body:
  118. action_authentications = extra_body.get("action_authentications")
  119. if action_authentications:
  120. res = action_authentications.values()
  121. [Authentication.model_validate(i).encrypt() for i in res]
  122. return data
  123. class RunRead(RunBase, TimeStampMixin, PrimaryKeyMixin):
  124. metadata_: Optional[dict] = PDField(default=None, alias="metadata")