|
@@ -2,8 +2,7 @@ from fastapi import APIRouter, Depends, Request
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
from sqlmodel import select
|
|
|
from starlette.responses import StreamingResponse
|
|
|
-
|
|
|
-from app.api.deps import get_async_session
|
|
|
+from app.api.deps import get_token_id, get_async_session
|
|
|
from app.core.runner import pub_handler
|
|
|
from app.libs.paginate import cursor_page, CommonPage
|
|
|
from app.models.run import RunCreate, RunRead, RunUpdate, Run
|
|
@@ -48,6 +47,7 @@ async def create_run(
|
|
|
session: AsyncSession = Depends(get_async_session),
|
|
|
thread_id: str,
|
|
|
body: RunCreate = ...,
|
|
|
+ token_id=Depends(get_token_id),
|
|
|
request: Request,
|
|
|
):
|
|
|
"""
|
|
@@ -65,7 +65,7 @@ async def create_run(
|
|
|
event_handler.pub_run_queued(db_run)
|
|
|
print("22222233333333333344444444444444444555555555555555556")
|
|
|
# print(run_task)
|
|
|
- run_task.apply_async(args=(db_run.id, body.stream))
|
|
|
+ run_task.apply_async(args=(db_run.id, token_id, body.stream))
|
|
|
print("22222222222222222222222222222222")
|
|
|
print(body.stream)
|
|
|
# db_run.file_ids = json.loads(db_run.file_ids)
|
|
@@ -188,6 +188,7 @@ async def submit_tool_outputs_to_run(
|
|
|
thread_id: str,
|
|
|
run_id: str = ...,
|
|
|
body: SubmitToolOutputsRunRequest = ...,
|
|
|
+ token_id=Depends(get_token_id),
|
|
|
request: Request,
|
|
|
) -> RunRead:
|
|
|
"""
|
|
@@ -200,7 +201,7 @@ async def submit_tool_outputs_to_run(
|
|
|
)
|
|
|
# Resume async task
|
|
|
if db_run.status == "queued":
|
|
|
- run_task.apply_async(args=(db_run.id, body.stream))
|
|
|
+ run_task.apply_async(args=(db_run.id, token_id, body.stream))
|
|
|
|
|
|
if body.stream:
|
|
|
return pub_handler.sub_stream(db_run.id, request)
|