sub_stream_test.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import json
  2. from openai import AssistantEventHandler
  3. def test_sub_stream_with_submit_tool_outputs_stream(client):
  4. def get_current_weather(location):
  5. return f"{location}今天是雨天。 "
  6. assistant = client.beta.assistants.create(
  7. name="Assistant Demo",
  8. instructions="You are a helpful assistant. When asked a question, use tools wherever possible.",
  9. model="gpt-4o",
  10. tools=[
  11. {
  12. "type": "function",
  13. "function": {
  14. "name": "get_current_weather",
  15. "description": "当你想查询指定城市的天气时非常有用。",
  16. "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "城市或县区,比如北京市、杭州市、余杭区等。"}}, "required": ["location"]}, # 查询天气时需要提供位置,因此参数设置为location
  17. },
  18. }
  19. ],
  20. )
  21. print("=====> : %s\n", assistant)
  22. thread = client.beta.threads.create()
  23. print("=====> : %s\n", thread)
  24. message = client.beta.threads.messages.create(
  25. thread_id=thread.id,
  26. role="user",
  27. content="北京天气如何?",
  28. )
  29. print("=====> : %s\n", message)
  30. funcs = [get_current_weather]
  31. class EventHandler(AssistantEventHandler):
  32. def on_event(self, event):
  33. print(event.event)
  34. if event.event == "thread.run.requires_action":
  35. print(event)
  36. run_id = event.data.id # Retrieve the run ID from the event data
  37. self.handle_requires_action(event.data, run_id)
  38. def handle_requires_action(self, data, run_id):
  39. tool_outputs = []
  40. for tool in data.required_action.submit_tool_outputs.tool_calls:
  41. func = next(iter([func for func in funcs if func.__name__ == tool.function.name]))
  42. try:
  43. output = func(**eval(tool.function.arguments))
  44. except Exception as e:
  45. output = "Error: " + str(e)
  46. tool_outputs.append({"tool_call_id": tool.id, "output": json.dumps(output)})
  47. print(tool_outputs)
  48. # Submit all tool_outputs at the same time
  49. self.submit_tool_outputs(tool_outputs, run_id)
  50. def submit_tool_outputs(self, tool_outputs, run_id):
  51. # Use the submit_tool_outputs_stream helper
  52. with client.beta.threads.runs.submit_tool_outputs_stream(
  53. thread_id=self.current_run.thread_id,
  54. run_id=self.current_run.id,
  55. tool_outputs=tool_outputs,
  56. event_handler=EventHandler(),
  57. ) as stream:
  58. # for text in stream.text_deltas:
  59. # print(text, end="", flush=True)
  60. # print()
  61. stream.until_done()
  62. def on_text_delta(self, delta, snapshot) -> None:
  63. print("=====> text delta")
  64. print("delta : %s", delta)
  65. with client.beta.threads.runs.stream(
  66. thread_id=thread.id,
  67. assistant_id=assistant.id,
  68. event_handler=EventHandler(),
  69. ) as stream:
  70. stream.until_done()