Explorar el Código

Update chat_service: add OpenAICompatible

StrayWarrior hace 3 días
padre
commit
3c605963df
Se han modificado 1 ficheros con 21 adiciones y 0 borrados
  1. 21 0
      chat_service.py

+ 21 - 0
chat_service.py

@@ -11,6 +11,7 @@ from logging_service import logger
 import cozepy
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 import time
+from openai import OpenAI
 
 COZE_API_TOKEN = os.getenv("COZE_API_TOKEN")
 COZE_CN_BASE_URL = 'https://api.coze.cn'
@@ -28,6 +29,26 @@ class ChatServiceType(Enum):
     OPENAI_COMPATIBLE = auto()
     COZE_CHAT = auto()
 
+class OpenAICompatible:
+    @staticmethod
+    def create_client(model_name):
+        volcengine_models = [
+            VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+            VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+            VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
+            VOLCENGINE_MODEL_DEEPSEEK_V3
+        ]
+        deepseek_models = [
+            DEEPSEEK_CHAT_MODEL,
+        ]
+        if model_name in volcengine_models:
+            llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL)
+        elif model_name in deepseek_models:
+            llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL)
+        else:
+            raise Exception("Unsupported model: %s" % model_name)
+        return llm_client
+
 class CrossAccountJWTOAuthApp(JWTOAuthApp):
     def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
         self.account_id = account_id