jack 3 veckor sedan
förälder
incheckning
0601c57acb
3 ändrade filer med 14 tillägg och 8 borttagningar
  1. 5 4
      app/api/v1/runs.py
  2. 4 1
      app/core/runner/thread_runner.py
  3. 5 3
      app/tasks/run_task.py

+ 5 - 4
app/api/v1/runs.py

@@ -2,8 +2,7 @@ from fastapi import APIRouter, Depends, Request
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlmodel import select
 from starlette.responses import StreamingResponse
-
-from app.api.deps import get_async_session
+from app.api.deps import get_token_id, get_async_session
 from app.core.runner import pub_handler
 from app.libs.paginate import cursor_page, CommonPage
 from app.models.run import RunCreate, RunRead, RunUpdate, Run
@@ -48,6 +47,7 @@ async def create_run(
     session: AsyncSession = Depends(get_async_session),
     thread_id: str,
     body: RunCreate = ...,
+    token_id=Depends(get_token_id),
     request: Request,
 ):
     """
@@ -65,7 +65,7 @@ async def create_run(
     event_handler.pub_run_queued(db_run)
     print("22222233333333333344444444444444444555555555555555556")
     # print(run_task)
-    run_task.apply_async(args=(db_run.id, body.stream))
+    run_task.apply_async(args=(db_run.id, token_id, body.stream))
     print("22222222222222222222222222222222")
     print(body.stream)
     # db_run.file_ids = json.loads(db_run.file_ids)
@@ -188,6 +188,7 @@ async def submit_tool_outputs_to_run(
     thread_id: str,
     run_id: str = ...,
     body: SubmitToolOutputsRunRequest = ...,
+    token_id=Depends(get_token_id),
     request: Request,
 ) -> RunRead:
     """
@@ -200,7 +201,7 @@ async def submit_tool_outputs_to_run(
     )
     # Resume async task
     if db_run.status == "queued":
-        run_task.apply_async(args=(db_run.id, body.stream))
+        run_task.apply_async(args=(db_run.id, token_id, body.stream))
 
     if body.stream:
         return pub_handler.sub_stream(db_run.id, request)

+ 4 - 1
app/core/runner/thread_runner.py

@@ -46,8 +46,11 @@ class ThreadRunner:
         tool_settings.TOOL_WORKER_NUM, "tool_worker_"
     )
 
-    def __init__(self, run_id: str, session: Session, stream: bool = False):
+    def __init__(
+        self, run_id: str, token_id: str, session: Session, stream: bool = False
+    ):
         self.run_id = run_id
+        self.token_id = token_id
         self.session = session
         self.stream = stream
         self.max_step = llm_settings.LLM_MAX_STEP

+ 5 - 3
app/tasks/run_task.py

@@ -9,10 +9,12 @@ from app.services.run.run import RunService
 
 
 @celery_app.task(bind=True, autoretry_for=())
-def run_task(self, run_id: str, stream: bool = False):
-    logging.info(f"[run_task] [{run_id}] running at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+def run_task(self, run_id: str, token_id: str, stream: bool = False):
+    logging.info(
+        f"[run_task] [{run_id}]  [token_id] [{token_id}] running at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
+    )
     try:
-        ThreadRunner(run_id, session, stream).run()
+        ThreadRunner(run_id, token_id, session, stream).run()
     except Exception as e:
         print("aawwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")
         logging.exception(e)