llm_backend.py 3.3 KB

import logging
from typing import List

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletionChunk, ChatCompletion


class LLMBackend:
    """Thin wrapper around the OpenAI chat completions endpoint."""

    def __init__(self, base_url: str, api_key: str) -> None:
        # Log the endpoint only; api_key is a secret and must not be logged.
        logging.debug("LLMBackend base_url: %s", base_url)
        # Ensure a trailing slash so relative API paths resolve correctly.
        self.base_url = base_url + "/" if base_url else None
        self.api_key = api_key
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
    def run(
        self,
        messages: List,
        model: str,
        tools: List = None,
        tool_choice="auto",
        stream=False,
        stream_options=None,
        extra_body=None,
        temperature=None,
        top_p=None,
        response_format=None,
        parallel_tool_calls=True,
        audio=None,
        modalities=None,
    ) -> ChatCompletion | Stream[ChatCompletionChunk]:
        # Reasoning models (the o1/o3/gpt-5 families) reject sampling
        # parameters, so drop them before building the request.
        if any(model.startswith(prefix) for prefix in ["o1", "o3", "gpt-5"]):
            temperature = None
            top_p = None
        chat_params = {
            "messages": messages,
            "model": model,
            "stream": stream,
            "max_tokens": 100000,
            # "presence_penalty": 0,
            # "frequency_penalty": 0,
            # "parallel_tool_calls": parallel_tool_calls,
        }
        if extra_body:
            model_params = extra_body.get("model_params")
            if model_params:
                if "n" in model_params:
                    raise ValueError("n is not allowed in model_params")
                chat_params.update(model_params)
            stream_options_params = extra_body.get("stream_options")
            if stream_options_params:
                chat_params["stream_options"] = {
                    "include_usage": bool(stream_options_params["include_usage"])
                }
        # An explicit stream_options argument overrides the extra_body variant.
        if isinstance(stream_options, dict) and "include_usage" in stream_options:
            chat_params["stream_options"] = {
                "include_usage": bool(stream_options["include_usage"])
            }
        if audio:
            chat_params["audio"] = audio
        if modalities:
            chat_params["modalities"] = modalities
        # Compare against None so explicit 0 / 0.0 values are still forwarded.
        if temperature is not None:
            chat_params["temperature"] = temperature
        if top_p is not None:
            chat_params["top_p"] = top_p
        if tools:
            chat_params["tools"] = tools
            chat_params["parallel_tool_calls"] = parallel_tool_calls
            chat_params["tool_choice"] = tool_choice if tool_choice else "auto"
        if (
            isinstance(response_format, dict)
            and response_format.get("type") == "json_object"
        ):
            chat_params["response_format"] = {"type": "json_object"}
        # The chat API requires a content field on every message; backfill it.
        for message in chat_params["messages"]:
            if "content" not in message:
                message["content"] = ""
        chat_params["timeout"] = 300
        logging.info("chat_params: %s", chat_params)
        response = self.client.chat.completions.create(**chat_params)
        logging.info("chat_response: %s", response)
        return response
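

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Assumes an
# OpenAI-compatible endpoint; the base URL, API key, and model name below
# are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    backend = LLMBackend(
        base_url="https://api.example.com/v1", api_key="sk-placeholder"
    )

    # Plain (non-streaming) call; "model_params" in extra_body is merged
    # into the request verbatim, except "n", which run() rejects.
    completion = backend.run(
        messages=[{"role": "user", "content": "Say hello."}],
        model="gpt-4o",
        extra_body={"model_params": {"temperature": 0.2}},
    )
    print(completion.choices[0].message.content)

    # Streaming call; include_usage asks the server to append a final chunk
    # that carries token usage and has an empty choices list.
    for chunk in backend.run(
        messages=[{"role": "user", "content": "Count to three."}],
        model="gpt-4o",
        stream=True,
        stream_options={"include_usage": True},
    ):
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)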