|
@@ -30,7 +30,9 @@ def pub_event(channel: str, data: dict) -> None:
|
|
|
redis_client.expire(channel, 10 * 60)
|
|
|
|
|
|
|
|
|
-def read_event(channel: str, x_index: str = None) -> Tuple[Optional[str], Optional[dict]]:
|
|
|
+def read_event(
|
|
|
+ channel: str, x_index: str = None
|
|
|
+) -> Tuple[Optional[str], Optional[dict]]:
|
|
|
"""
|
|
|
Read events from the channel, starting from the next index of x_index
|
|
|
:param channel: channel name
|
|
@@ -100,8 +102,12 @@ def _data_adjust(obj):
|
|
|
for key, value in data.items():
|
|
|
if isinstance(value, datetime):
|
|
|
data[key] = value.timestamp()
|
|
|
- data['parallel_tool_calls'] = True
|
|
|
- data["file_ids"] = json.loads(data['file_ids'])
|
|
|
+ print(
|
|
|
+ "--------------------------------====================================11221212212121212121"
|
|
|
+ )
|
|
|
+ print(data)
|
|
|
+ data["parallel_tool_calls"] = True
|
|
|
+ data["file_ids"] = json.loads(data["file_ids"]) if data["file_ids"] else []
|
|
|
return data
|
|
|
|
|
|
|
|
@@ -118,7 +124,12 @@ def _data_adjust_message_delta(step_details):
|
|
|
return step_details
|
|
|
|
|
|
|
|
|
-def sub_stream(run_id, request: Request, prefix_events: List[dict] = [], suffix_events: List[dict] = []):
|
|
|
+def sub_stream(
|
|
|
+ run_id,
|
|
|
+ request: Request,
|
|
|
+ prefix_events: List[dict] = [],
|
|
|
+ suffix_events: List[dict] = [],
|
|
|
+):
|
|
|
"""
|
|
|
Subscription chat response stream
|
|
|
"""
|
|
@@ -167,30 +178,54 @@ class StreamEventHandler:
|
|
|
pub_event(self._channel, {"event": event.event, "data": event.data.json()})
|
|
|
|
|
|
def pub_run_created(self, run):
|
|
|
- data=_data_adjust(run)
|
|
|
+ data = _data_adjust(run)
|
|
|
print(data)
|
|
|
self.pub_event(events.ThreadRunCreated(data=data, event="thread.run.created"))
|
|
|
|
|
|
def pub_run_queued(self, run):
|
|
|
- self.pub_event(events.ThreadRunQueued(data=_data_adjust(run), event="thread.run.queued"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunQueued(data=_data_adjust(run), event="thread.run.queued")
|
|
|
+ )
|
|
|
|
|
|
def pub_run_in_progress(self, run):
|
|
|
- self.pub_event(events.ThreadRunInProgress(data=_data_adjust(run), event="thread.run.in_progress"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunInProgress(
|
|
|
+ data=_data_adjust(run), event="thread.run.in_progress"
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
def pub_run_completed(self, run):
|
|
|
- self.pub_event(events.ThreadRunCompleted(data=_data_adjust(run), event="thread.run.completed"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunCompleted(
|
|
|
+ data=_data_adjust(run), event="thread.run.completed"
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
def pub_run_requires_action(self, run):
|
|
|
- self.pub_event(events.ThreadRunRequiresAction(data=_data_adjust(run), event="thread.run.requires_action"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunRequiresAction(
|
|
|
+ data=_data_adjust(run), event="thread.run.requires_action"
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
def pub_run_failed(self, run):
|
|
|
- self.pub_event(events.ThreadRunFailed(data=_data_adjust(run), event="thread.run.failed"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunFailed(data=_data_adjust(run), event="thread.run.failed")
|
|
|
+ )
|
|
|
|
|
|
def pub_run_step_created(self, step):
|
|
|
- self.pub_event(events.ThreadRunStepCreated(data=_data_adjust(step), event="thread.run.step.created"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunStepCreated(
|
|
|
+ data=_data_adjust(step), event="thread.run.step.created"
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
def pub_run_step_in_progress(self, step):
|
|
|
- self.pub_event(events.ThreadRunStepInProgress(data=_data_adjust(step), event="thread.run.step.in_progress"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunStepInProgress(
|
|
|
+ data=_data_adjust(step), event="thread.run.step.in_progress"
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
def pub_run_step_delta(self, step_id, step_details):
|
|
|
self.pub_event(
|
|
@@ -205,17 +240,31 @@ class StreamEventHandler:
|
|
|
)
|
|
|
|
|
|
def pub_run_step_completed(self, step):
|
|
|
- self.pub_event(events.ThreadRunStepCompleted(data=_data_adjust(step), event="thread.run.step.completed"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunStepCompleted(
|
|
|
+ data=_data_adjust(step), event="thread.run.step.completed"
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
def pub_run_step_failed(self, step):
|
|
|
- self.pub_event(events.ThreadRunStepFailed(data=_data_adjust(step), event="thread.run.step.failed"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadRunStepFailed(
|
|
|
+ data=_data_adjust(step), event="thread.run.step.failed"
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
def pub_message_created(self, message):
|
|
|
- self.pub_event(events.ThreadMessageCreated(data=_data_adjust_message(message), event="thread.message.created"))
|
|
|
+ self.pub_event(
|
|
|
+ events.ThreadMessageCreated(
|
|
|
+ data=_data_adjust_message(message), event="thread.message.created"
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
def pub_message_in_progress(self, message):
|
|
|
self.pub_event(
|
|
|
- events.ThreadMessageInProgress(data=_data_adjust_message(message), event="thread.message.in_progress")
|
|
|
+ events.ThreadMessageInProgress(
|
|
|
+ data=_data_adjust_message(message), event="thread.message.in_progress"
|
|
|
+ )
|
|
|
)
|
|
|
|
|
|
def pub_message_usage(self, chunk):
|
|
@@ -230,15 +279,19 @@ class StreamEventHandler:
|
|
|
"role": "assistant",
|
|
|
"status": "in_progress",
|
|
|
"thread_id": "",
|
|
|
- "metadata": {"usage": chunk.usage.json()}
|
|
|
+ "metadata": {"usage": chunk.usage.json()},
|
|
|
}
|
|
|
self.pub_event(
|
|
|
- events.ThreadMessageInProgress(data=data, event="thread.message.in_progress")
|
|
|
+ events.ThreadMessageInProgress(
|
|
|
+ data=data, event="thread.message.in_progress"
|
|
|
+ )
|
|
|
)
|
|
|
|
|
|
def pub_message_completed(self, message):
|
|
|
self.pub_event(
|
|
|
- events.ThreadMessageCompleted(data=_data_adjust_message(message), event="thread.message.completed")
|
|
|
+ events.ThreadMessageCompleted(
|
|
|
+ data=_data_adjust_message(message), event="thread.message.completed"
|
|
|
+ )
|
|
|
)
|
|
|
|
|
|
def pub_message_delta(self, message_id, index, content, role):
|
|
@@ -249,7 +302,12 @@ class StreamEventHandler:
|
|
|
events.ThreadMessageDelta(
|
|
|
data=events.MessageDeltaEvent(
|
|
|
id=message_id,
|
|
|
- delta={"content": [{"index": index, "type": "text", "text": {"value": content}}], "role": role},
|
|
|
+ delta={
|
|
|
+ "content": [
|
|
|
+ {"index": index, "type": "text", "text": {"value": content}}
|
|
|
+ ],
|
|
|
+ "role": role,
|
|
|
+ },
|
|
|
object="thread.message.delta",
|
|
|
),
|
|
|
event="thread.message.delta",
|