|
@@ -5,7 +5,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.orm import Session
|
|
from sqlmodel import select, desc, update
|
|
from sqlmodel import select, desc, update
|
|
|
|
|
|
-from app.exceptions.exception import BadRequestError, ResourceNotFoundError, ValidateFailedError
|
|
|
|
|
|
+from app.exceptions.exception import (
|
|
|
|
+ BadRequestError,
|
|
|
|
+ ResourceNotFoundError,
|
|
|
|
+ ValidateFailedError,
|
|
|
|
+)
|
|
from app.models import RunStep
|
|
from app.models import RunStep
|
|
from app.models.run import Run, RunRead, RunCreate, RunUpdate
|
|
from app.models.run import Run, RunRead, RunCreate, RunUpdate
|
|
from app.schemas.runs import SubmitToolOutputsRunRequest
|
|
from app.schemas.runs import SubmitToolOutputsRunRequest
|
|
@@ -16,6 +20,7 @@ from app.services.thread.thread import ThreadService
|
|
from app.utils import revise_tool_names
|
|
from app.utils import revise_tool_names
|
|
import json
|
|
import json
|
|
|
|
|
|
|
|
+
|
|
class RunService:
|
|
class RunService:
|
|
@staticmethod
|
|
@staticmethod
|
|
async def create_run(
|
|
async def create_run(
|
|
@@ -26,7 +31,9 @@ class RunService:
|
|
) -> RunRead:
|
|
) -> RunRead:
|
|
revise_tool_names(body.tools)
|
|
revise_tool_names(body.tools)
|
|
# get assistant
|
|
# get assistant
|
|
- db_asst = await AssistantService.get_assistant(session=session, assistant_id=body.assistant_id)
|
|
|
|
|
|
+ db_asst = await AssistantService.get_assistant(
|
|
|
|
+ session=session, assistant_id=body.assistant_id
|
|
|
|
+ )
|
|
if not body.model and db_asst.model:
|
|
if not body.model and db_asst.model:
|
|
body.model = db_asst.model
|
|
body.model = db_asst.model
|
|
if not body.instructions and db_asst.instructions:
|
|
if not body.instructions and db_asst.instructions:
|
|
@@ -43,7 +50,12 @@ class RunService:
|
|
file_ids = []
|
|
file_ids = []
|
|
asst_file_ids = db_asst.file_ids
|
|
asst_file_ids = db_asst.file_ids
|
|
if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
|
|
if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
|
|
- asst_file_ids = db_asst.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
|
|
|
|
|
|
+ ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
|
|
|
|
+ asst_file_ids = (
|
|
|
|
+ db_asst.tool_resources.get("file_search")
|
|
|
|
+ .get("vector_stores")[0]
|
|
|
|
+ .get("file_ids")
|
|
|
|
+ )
|
|
if asst_file_ids:
|
|
if asst_file_ids:
|
|
file_ids += asst_file_ids
|
|
file_ids += asst_file_ids
|
|
|
|
|
|
@@ -54,16 +66,25 @@ class RunService:
|
|
file_search_tool = {"type": "file_search"}
|
|
file_search_tool = {"type": "file_search"}
|
|
if file_search_tool not in body.tools:
|
|
if file_search_tool not in body.tools:
|
|
body.tools.append(file_search_tool)
|
|
body.tools.append(file_search_tool)
|
|
- thread_file_ids = db_thread.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
|
|
|
|
|
|
+ thread_file_ids = (
|
|
|
|
+ db_thread.tool_resources.get("file_search")
|
|
|
|
+ .get("vector_stores")[0]
|
|
|
|
+ .get("file_ids")
|
|
|
|
+ )
|
|
if thread_file_ids:
|
|
if thread_file_ids:
|
|
file_ids += thread_file_ids
|
|
file_ids += thread_file_ids
|
|
|
|
|
|
# create run
|
|
# create run
|
|
- db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids})
|
|
|
|
- print("11111111111111111111111111111111111111111111111111111111111888888888888888888888888888888888")
|
|
|
|
- #print(db_run)
|
|
|
|
- #db_run.file_ids = json.dumps(db_run.file_ids)
|
|
|
|
- db_run.file_ids = json.dumps(db_run.file_ids)
|
|
|
|
|
|
+ db_run = Run.model_validate(
|
|
|
|
+ body.model_dump(by_alias=True),
|
|
|
|
+ update={"thread_id": thread_id, "file_ids": file_ids},
|
|
|
|
+ )
|
|
|
|
+ print(
|
|
|
|
+ "11111111111111111111111111111111111111111111111111111111111888888888888888888888888888888888"
|
|
|
|
+ )
|
|
|
|
+ # print(db_run)
|
|
|
|
+ # db_run.file_ids = json.dumps(db_run.file_ids)
|
|
|
|
+ # db_run.file_ids = json.dumps(db_run.file_ids)
|
|
session.add(db_run)
|
|
session.add(db_run)
|
|
test_run = db_run
|
|
test_run = db_run
|
|
run_id = db_run.id
|
|
run_id = db_run.id
|
|
@@ -78,7 +99,7 @@ class RunService:
|
|
)
|
|
)
|
|
await session.commit()
|
|
await session.commit()
|
|
await session.refresh(db_run)
|
|
await session.refresh(db_run)
|
|
- #db_run.file_ids = list(file_ids)
|
|
|
|
|
|
+ # db_run.file_ids = list(file_ids)
|
|
print(db_run)
|
|
print(db_run)
|
|
return db_run
|
|
return db_run
|
|
|
|
|
|
@@ -109,22 +130,34 @@ class RunService:
|
|
) -> RunRead:
|
|
) -> RunRead:
|
|
revise_tool_names(body.tools)
|
|
revise_tool_names(body.tools)
|
|
# get assistant
|
|
# get assistant
|
|
- db_asst = await AssistantService.get_assistant(session=session, assistant_id=body.assistant_id)
|
|
|
|
|
|
+ db_asst = await AssistantService.get_assistant(
|
|
|
|
+ session=session, assistant_id=body.assistant_id
|
|
|
|
+ )
|
|
file_ids = []
|
|
file_ids = []
|
|
asst_file_ids = db_asst.file_ids
|
|
asst_file_ids = db_asst.file_ids
|
|
if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
|
|
if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
|
|
- asst_file_ids = db_asst.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
|
|
|
|
|
|
+ asst_file_ids = (
|
|
|
|
+ db_asst.tool_resources.get("file_search")
|
|
|
|
+ .get("vector_stores")[0]
|
|
|
|
+ .get("file_ids")
|
|
|
|
+ )
|
|
if asst_file_ids:
|
|
if asst_file_ids:
|
|
file_ids += asst_file_ids
|
|
file_ids += asst_file_ids
|
|
|
|
|
|
# create thread
|
|
# create thread
|
|
thread_id = None
|
|
thread_id = None
|
|
if body.thread is not None:
|
|
if body.thread is not None:
|
|
- db_thread = await ThreadService.create_thread(session=session, body=body.thread)
|
|
|
|
|
|
+ db_thread = await ThreadService.create_thread(
|
|
|
|
+ session=session, body=body.thread
|
|
|
|
+ )
|
|
thread_id = db_thread.id
|
|
thread_id = db_thread.id
|
|
thread_file_ids = []
|
|
thread_file_ids = []
|
|
if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
|
|
if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
|
|
- thread_file_ids = db_thread.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
|
|
|
|
|
|
+ thread_file_ids = (
|
|
|
|
+ db_thread.tool_resources.get("file_search")
|
|
|
|
+ .get("vector_stores")[0]
|
|
|
|
+ .get("file_ids")
|
|
|
|
+ )
|
|
if thread_file_ids:
|
|
if thread_file_ids:
|
|
file_ids += thread_file_ids
|
|
file_ids += thread_file_ids
|
|
if body.model is None and db_asst.model is not None:
|
|
if body.model is None and db_asst.model is not None:
|
|
@@ -135,7 +168,10 @@ class RunService:
|
|
body.tools = db_asst.tools
|
|
body.tools = db_asst.tools
|
|
|
|
|
|
# create run
|
|
# create run
|
|
- db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids})
|
|
|
|
|
|
+ db_run = Run.model_validate(
|
|
|
|
+ body.model_dump(by_alias=True),
|
|
|
|
+ update={"thread_id": thread_id, "file_ids": file_ids},
|
|
|
|
+ )
|
|
session.add(db_run)
|
|
session.add(db_run)
|
|
await session.commit()
|
|
await session.commit()
|
|
await session.refresh(db_run)
|
|
await session.refresh(db_run)
|
|
@@ -167,13 +203,22 @@ class RunService:
|
|
*, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
|
|
*, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
|
|
) -> RunRead:
|
|
) -> RunRead:
|
|
# get run
|
|
# get run
|
|
- db_run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
|
|
|
|
|
|
+ db_run = await RunService.get_run(
|
|
|
|
+ session=session, run_id=run_id, thread_id=thread_id
|
|
|
|
+ )
|
|
# get run_step
|
|
# get run_step
|
|
- db_run_step = await RunService.get_in_progress_run_step(run_id=run_id, session=session)
|
|
|
|
|
|
+ db_run_step = await RunService.get_in_progress_run_step(
|
|
|
|
+ run_id=run_id, session=session
|
|
|
|
+ )
|
|
if db_run.status != "requires_action":
|
|
if db_run.status != "requires_action":
|
|
- raise BadRequestError(message=f'Run status is "${db_run.status}", cannot submit tool outputs')
|
|
|
|
|
|
+ raise BadRequestError(
|
|
|
|
+ message=f'Run status is "${db_run.status}", cannot submit tool outputs'
|
|
|
|
+ )
|
|
# For now, this is always submit_tool_outputs.
|
|
# For now, this is always submit_tool_outputs.
|
|
- if not db_run.required_action or db_run.required_action["type"] != "submit_tool_outputs":
|
|
|
|
|
|
+ if (
|
|
|
|
+ not db_run.required_action
|
|
|
|
+ or db_run.required_action["type"] != "submit_tool_outputs"
|
|
|
|
+ ):
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=500,
|
|
status_code=500,
|
|
detail=f'Run status is "${db_run.status}", but "run.required_action.type" is not '
|
|
detail=f'Run status is "${db_run.status}", but "run.required_action.type" is not '
|
|
@@ -185,7 +230,9 @@ class RunService:
|
|
raise HTTPException(status_code=500, detail="Invalid tool call")
|
|
raise HTTPException(status_code=500, detail="Invalid tool call")
|
|
|
|
|
|
for tool_output in body.tool_outputs:
|
|
for tool_output in body.tool_outputs:
|
|
- tool_call = next((t for t in tool_calls if t["id"] == tool_output.tool_call_id), None)
|
|
|
|
|
|
+ tool_call = next(
|
|
|
|
+ (t for t in tool_calls if t["id"] == tool_output.tool_call_id), None
|
|
|
|
+ )
|
|
if not tool_call:
|
|
if not tool_call:
|
|
raise HTTPException(status_code=500, detail="Invalid tool call")
|
|
raise HTTPException(status_code=500, detail="Invalid tool call")
|
|
if tool_call["type"] != "function":
|
|
if tool_call["type"] != "function":
|
|
@@ -193,39 +240,67 @@ class RunService:
|
|
tool_call["function"]["output"] = tool_output.output
|
|
tool_call["function"]["output"] = tool_output.output
|
|
|
|
|
|
# update
|
|
# update
|
|
- step_completed = not list(filter(lambda tool_call: "output" not in tool_call[tool_call["type"]], tool_calls))
|
|
|
|
|
|
+ step_completed = not list(
|
|
|
|
+ filter(
|
|
|
|
+ lambda tool_call: "output" not in tool_call[tool_call["type"]],
|
|
|
|
+ tool_calls,
|
|
|
|
+ )
|
|
|
|
+ )
|
|
if step_completed:
|
|
if step_completed:
|
|
stmt = (
|
|
stmt = (
|
|
update(RunStep)
|
|
update(RunStep)
|
|
.where(RunStep.id == db_run_step.id)
|
|
.where(RunStep.id == db_run_step.id)
|
|
- .values({"status": "completed", "step_details": {"type": "tool_calls", "tool_calls": tool_calls}})
|
|
|
|
|
|
+ .values(
|
|
|
|
+ {
|
|
|
|
+ "status": "completed",
|
|
|
|
+ "step_details": {
|
|
|
|
+ "type": "tool_calls",
|
|
|
|
+ "tool_calls": tool_calls,
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+ )
|
|
)
|
|
)
|
|
else:
|
|
else:
|
|
stmt = (
|
|
stmt = (
|
|
update(RunStep)
|
|
update(RunStep)
|
|
.where(RunStep.id == db_run_step.id)
|
|
.where(RunStep.id == db_run_step.id)
|
|
- .values({"step_details": {"type": "tool_calls", "tool_calls": tool_calls}})
|
|
|
|
|
|
+ .values(
|
|
|
|
+ {"step_details": {"type": "tool_calls", "tool_calls": tool_calls}}
|
|
|
|
+ )
|
|
)
|
|
)
|
|
await session.execute(stmt)
|
|
await session.execute(stmt)
|
|
|
|
|
|
tool_call_ids = [tool_output.tool_call_id for tool_output in body.tool_outputs]
|
|
tool_call_ids = [tool_output.tool_call_id for tool_output in body.tool_outputs]
|
|
- required_action_tool_calls = db_run.required_action["submit_tool_outputs"]["tool_calls"]
|
|
|
|
|
|
+ required_action_tool_calls = db_run.required_action["submit_tool_outputs"][
|
|
|
|
+ "tool_calls"
|
|
|
|
+ ]
|
|
required_action_tool_calls = list(
|
|
required_action_tool_calls = list(
|
|
- filter(lambda tool_call: tool_call["id"] not in tool_call_ids, required_action_tool_calls)
|
|
|
|
|
|
+ filter(
|
|
|
|
+ lambda tool_call: tool_call["id"] not in tool_call_ids,
|
|
|
|
+ required_action_tool_calls,
|
|
|
|
+ )
|
|
)
|
|
)
|
|
|
|
|
|
required_action = {**db_run.required_action}
|
|
required_action = {**db_run.required_action}
|
|
if required_action_tool_calls:
|
|
if required_action_tool_calls:
|
|
- required_action["submit_tool_outputs"]["tool_calls"] = required_action_tool_calls
|
|
|
|
|
|
+ required_action["submit_tool_outputs"][
|
|
|
|
+ "tool_calls"
|
|
|
|
+ ] = required_action_tool_calls
|
|
else:
|
|
else:
|
|
required_action = {}
|
|
required_action = {}
|
|
|
|
|
|
if not required_action:
|
|
if not required_action:
|
|
stmt = (
|
|
stmt = (
|
|
- update(Run).where(Run.id == db_run.id).values({"required_action": required_action, "status": "queued"})
|
|
|
|
|
|
+ update(Run)
|
|
|
|
+ .where(Run.id == db_run.id)
|
|
|
|
+ .values({"required_action": required_action, "status": "queued"})
|
|
)
|
|
)
|
|
else:
|
|
else:
|
|
- stmt = update(Run).where(Run.id == db_run.id).values({"required_action": required_action})
|
|
|
|
|
|
+ stmt = (
|
|
|
|
+ update(Run)
|
|
|
|
+ .where(Run.id == db_run.id)
|
|
|
|
+ .values({"required_action": required_action})
|
|
|
|
+ )
|
|
|
|
|
|
await session.execute(stmt)
|
|
await session.execute(stmt)
|
|
await session.commit()
|
|
await session.commit()
|
|
@@ -257,7 +332,6 @@ class RunService:
|
|
run = result.scalars().one_or_none()
|
|
run = result.scalars().one_or_none()
|
|
if not run:
|
|
if not run:
|
|
raise ResourceNotFoundError(f"run {run_id} not found")
|
|
raise ResourceNotFoundError(f"run {run_id} not found")
|
|
-
|
|
|
|
return run
|
|
return run
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
@@ -274,10 +348,16 @@ class RunService:
|
|
return run
|
|
return run
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
- async def get_run_step(*, thread_id, run_id, step_id, session: AsyncSession) -> RunStep:
|
|
|
|
|
|
+ async def get_run_step(
|
|
|
|
+ *, thread_id, run_id, step_id, session: AsyncSession
|
|
|
|
+ ) -> RunStep:
|
|
statement = (
|
|
statement = (
|
|
select(RunStep)
|
|
select(RunStep)
|
|
- .where(RunStep.thread_id == thread_id, RunStep.run_id == run_id, RunStep.id == step_id)
|
|
|
|
|
|
+ .where(
|
|
|
|
+ RunStep.thread_id == thread_id,
|
|
|
|
+ RunStep.run_id == run_id,
|
|
|
|
+ RunStep.id == step_id,
|
|
|
|
+ )
|
|
.order_by(desc(RunStep.created_at))
|
|
.order_by(desc(RunStep.created_at))
|
|
)
|
|
)
|
|
result = await session.execute(statement)
|
|
result = await session.execute(statement)
|
|
@@ -290,7 +370,9 @@ class RunService:
|
|
def to_queued(*, session: Session, run_id) -> Run:
|
|
def to_queued(*, session: Session, run_id) -> Run:
|
|
run = RunService.get_run_sync(run_id=run_id, session=session)
|
|
run = RunService.get_run_sync(run_id=run_id, session=session)
|
|
RunService.check_cancel_and_expire_status(run=run, session=session)
|
|
RunService.check_cancel_and_expire_status(run=run, session=session)
|
|
- RunService.check_status_in(run=run, status_list=["requires_action", "in_progress", "queued"])
|
|
|
|
|
|
+ RunService.check_status_in(
|
|
|
|
+ run=run, status_list=["requires_action", "in_progress", "queued"]
|
|
|
|
+ )
|
|
|
|
|
|
if run.status != "queued":
|
|
if run.status != "queued":
|
|
run.status = "queued"
|
|
run.status = "queued"
|
|
@@ -320,7 +402,9 @@ class RunService:
|
|
def to_requires_action(*, session: Session, run_id, required_action) -> Run:
|
|
def to_requires_action(*, session: Session, run_id, required_action) -> Run:
|
|
run = RunService.get_run_sync(run_id=run_id, session=session)
|
|
run = RunService.get_run_sync(run_id=run_id, session=session)
|
|
RunService.check_cancel_and_expire_status(run=run, session=session)
|
|
RunService.check_cancel_and_expire_status(run=run, session=session)
|
|
- RunService.check_status_in(run=run, status_list=["in_progress", "requires_action"])
|
|
|
|
|
|
+ RunService.check_status_in(
|
|
|
|
+ run=run, status_list=["in_progress", "requires_action"]
|
|
|
|
+ )
|
|
|
|
|
|
if run.status != "requires_action":
|
|
if run.status != "requires_action":
|
|
run.status = "requires_action"
|
|
run.status = "requires_action"
|