# llm_backend.py
  1. import logging
  2. from typing import List
  3. from openai import OpenAI, Stream
  4. from openai.types.chat import ChatCompletionChunk, ChatCompletion
  5. class LLMBackend:
  6. """
  7. openai chat 接口封装
  8. """
  9. def __init__(self, base_url: str, api_key) -> None:
  10. self.base_url = base_url + "/" if base_url else None
  11. self.api_key = api_key
  12. self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
  13. def run(
  14. self,
  15. messages: List,
  16. model: str,
  17. tools: List = None,
  18. tool_choice="auto",
  19. stream=False,
  20. stream_options=None,
  21. extra_body=None,
  22. temperature=None,
  23. top_p=None,
  24. response_format=None,
  25. parallel_tool_calls=False,
  26. audio=None,
  27. modalities=None,
  28. ) -> ChatCompletion | Stream[ChatCompletionChunk]:
  29. chat_params = {
  30. "messages": messages,
  31. "model": model,
  32. "stream": stream,
  33. "parallel_tool_calls": False,
  34. }
  35. if extra_body:
  36. model_params = extra_body.get("model_params")
  37. if model_params:
  38. if "n" in model_params:
  39. raise ValueError("n is not allowed in model_params")
  40. chat_params.update(model_params)
  41. stream_options_params = extra_body.get("stream_options")
  42. if stream_options_params:
  43. chat_params["stream_options"] = {
  44. "include_usage": bool(stream_options_params["include_usage"])
  45. }
  46. print("stream_optionsstream_optionsstream_optionsstream_optionsstream_options")
  47. print(stream_options)
  48. if stream_options:
  49. print(isinstance(stream_options, dict))
  50. if isinstance(stream_options, dict):
  51. if "include_usage" in stream_options:
  52. chat_params["stream_options"] = {
  53. "include_usage": bool(stream_options["include_usage"])
  54. }
  55. if parallel_tool_calls:
  56. chat_params["parallel_tool_calls"] = parallel_tool_calls
  57. if audio:
  58. chat_params["audio"] = audio
  59. if modalities:
  60. chat_params["modalities"] = modalities
  61. if temperature:
  62. chat_params["temperature"] = temperature
  63. if top_p:
  64. chat_params["top_p"] = top_p
  65. if tools:
  66. chat_params["tools"] = tools
  67. chat_params["tool_choice"] = tool_choice if tool_choice else "auto"
  68. if (
  69. isinstance(response_format, dict)
  70. and response_format.get("type") == "json_object"
  71. ):
  72. chat_params["response_format"] = {"type": "json_object"}
  73. for message in chat_params["messages"]:
  74. if "content" not in message:
  75. message["content"] = ""
  76. chat_params["timeout"] = 300
  77. logging.info("chat_params: %s", chat_params)
  78. response = self.client.chat.completions.create(**chat_params)
  79. logging.info("chat_response: %s", response)
  80. return response