瀏覽代碼

Update chat_service: add OpenAICompatible

StrayWarrior 6 月之前
父節點
當前提交
3c605963df
共有 1 個文件被更改,包括 21 次插入0 次删除
  1. 21 0
      chat_service.py

+ 21 - 0
chat_service.py

@@ -11,6 +11,7 @@ from logging_service import logger
 import cozepy
 import cozepy
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 import time
 import time
+from openai import OpenAI
 
 
 COZE_API_TOKEN = os.getenv("COZE_API_TOKEN")
 COZE_API_TOKEN = os.getenv("COZE_API_TOKEN")
 COZE_CN_BASE_URL = 'https://api.coze.cn'
 COZE_CN_BASE_URL = 'https://api.coze.cn'
@@ -28,6 +29,26 @@ class ChatServiceType(Enum):
     OPENAI_COMPATIBLE = auto()
     OPENAI_COMPATIBLE = auto()
     COZE_CHAT = auto()
     COZE_CHAT = auto()
 
 
+class OpenAICompatible:
+    @staticmethod
+    def create_client(model_name):
+        volcengine_models = [
+            VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+            VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+            VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
+            VOLCENGINE_MODEL_DEEPSEEK_V3
+        ]
+        deepseek_models = [
+            DEEPSEEK_CHAT_MODEL,
+        ]
+        if model_name in volcengine_models:
+            llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL)
+        elif model_name in deepseek_models:
+            llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL)
+        else:
+            raise Exception("Unsupported model: %s" % model_name)
+        return llm_client
+
 class CrossAccountJWTOAuthApp(JWTOAuthApp):
 class CrossAccountJWTOAuthApp(JWTOAuthApp):
     def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
     def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
         self.account_id = account_id
         self.account_id = account_id