浏览代码

Update chat_service: add GPT

StrayWarrior 1 周之前
父节点
当前提交
f9c6b1293d
共有 1 个文件被更改,包括 13 次插入3 次删除
  1. 13 3
      chat_service.py

+ 13 - 3
chat_service.py

@@ -26,6 +26,10 @@ DEEPSEEK_BASE_URL = 'https://api.deepseek.com/'
 DEEPSEEK_CHAT_MODEL = 'deepseek-chat'
 VOLCENGINE_BOT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3/bots"
 VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH = "bot-20250427173459-9h2xp"
+OPENAI_API_TOKEN = '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'
+OPENAI_BASE_URL = 'https://api.openai.com/v1'
+OPENAI_MODEL_GPT_4o = 'gpt-4o'
+OPENAI_MODEL_GPT_4o_mini = 'gpt-4o-mini'
 
 class ChatServiceType(Enum):
     OPENAI_COMPATIBLE = auto()
@@ -33,7 +37,7 @@ class ChatServiceType(Enum):
 
 class OpenAICompatible:
     @staticmethod
-    def create_client(model_name):
+    def create_client(model_name, **kwargs):
         volcengine_models = [
             VOLCENGINE_MODEL_DOUBAO_PRO_32K,
             VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
@@ -43,10 +47,16 @@ class OpenAICompatible:
         deepseek_models = [
             DEEPSEEK_CHAT_MODEL,
         ]
+        openai_models = [
+            OPENAI_MODEL_GPT_4o_mini,
+            OPENAI_MODEL_GPT_4o
+        ]
         if model_name in volcengine_models:
-            llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL)
+            llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL, **kwargs)
         elif model_name in deepseek_models:
-            llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL)
+            llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL, **kwargs)
+        elif model_name in openai_models:
+            llm_client = OpenAI(api_key=OPENAI_API_TOKEN, base_url=OPENAI_BASE_URL, **kwargs)
         else:
             raise Exception("Unsupported model: %s" % model_name)
         return llm_client