瀏覽代碼

Update api_server: update model service

StrayWarrior 6 月之前
父節點
當前提交
640cf8e744
共有 1 個文件被更改,包括 10 次插入2 次删除
  1. 10 2
      api_server.py

+ 10 - 2
api_server.py

@@ -88,8 +88,8 @@ def list_models():
     models = [
         {
             'model_type': 'openai_compatible',
-            'model_name': chat_service.DEEPSEEK_CHAT_MODEL,
-            'display_name': 'DeepSeek V3'
+            'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
+            'display_name': 'DeepSeek V3 on 火山'
         },
         {
             'model_type': 'openai_compatible',
@@ -155,6 +155,9 @@ def run_prompt():
         chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
         chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
     ]
+    deepseek_models = [
+        chat_service.DEEPSEEK_CHAT_MODEL,
+    ]
     current_timestr = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
     system_prompt = {
         'role': 'system',
@@ -168,6 +171,11 @@ def run_prompt():
         response = llm_client.chat.completions.create(
             messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
         return wrap_response(200, data=response.choices[0].message.content)
+    elif model_name in deepseek_models:
+        llm_client = OpenAI(api_key=chat_service.DEEPSEEK_API_TOKEN, base_url=chat_service.DEEPSEEK_BASE_URL)
+        response = llm_client.chat.completions.create(
+            messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
+        return wrap_response(200, data=response.choices[0].message.content)
     else:
         return wrap_response(400, msg='model not supported')