Ver Fonte

Update chat_service: add OpenAICompatible

StrayWarrior há 3 dias atrás
pai
commit
3c605963df
1 ficheiros alterados com 21 adições e 0 exclusões
  1. 21 0
      chat_service.py

+ 21 - 0
chat_service.py

@@ -11,6 +11,7 @@ from logging_service import logger
 import cozepy
 import cozepy
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 import time
 import time
+from openai import OpenAI
 
 
 COZE_API_TOKEN = os.getenv("COZE_API_TOKEN")
 COZE_API_TOKEN = os.getenv("COZE_API_TOKEN")
 COZE_CN_BASE_URL = 'https://api.coze.cn'
 COZE_CN_BASE_URL = 'https://api.coze.cn'
@@ -28,6 +29,26 @@ class ChatServiceType(Enum):
     OPENAI_COMPATIBLE = auto()
     OPENAI_COMPATIBLE = auto()
     COZE_CHAT = auto()
     COZE_CHAT = auto()
 
 
+class OpenAICompatible:
+    @staticmethod
+    def create_client(model_name):
+        volcengine_models = [
+            VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+            VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+            VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
+            VOLCENGINE_MODEL_DEEPSEEK_V3
+        ]
+        deepseek_models = [
+            DEEPSEEK_CHAT_MODEL,
+        ]
+        if model_name in volcengine_models:
+            llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL)
+        elif model_name in deepseek_models:
+            llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL)
+        else:
+            raise Exception("Unsupported model: %s" % model_name)
+        return llm_client
+
 class CrossAccountJWTOAuthApp(JWTOAuthApp):
 class CrossAccountJWTOAuthApp(JWTOAuthApp):
     def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
     def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
         self.account_id = account_id
         self.account_id = account_id