浏览代码

Update agent_service: create multimodal clients

StrayWarrior 6 月之前
父节点
当前提交
4052c73b08
共有 3 个文件被更改,包括 22 次插入10 次删除
  1. 16 10
      agent_service.py
  2. 3 0
      configs/dev.yaml
  3. 3 0
      configs/prod.yaml

+ 16 - 10
agent_service.py

@@ -50,12 +50,12 @@ class AgentService:
         self.response_type_detector = ResponseTypeDetector()
         self.agent_registry: Dict[str, DialogueManager] = {}
 
-        self.llm_client = OpenAI(
-            api_key=chat_service.VOLCENGINE_API_TOKEN,
-            base_url=chat_service.VOLCENGINE_BASE_URL
-        )
-        # DeepSeek on Volces
-        self.model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
+        chat_config = configs.get()['chat_api']['openai_compatible']
+        self.text_model_name = chat_config['text_model']
+        self.multimodal_model_name = chat_config['multimodal_model']
+        self.text_model_client = chat_service.OpenAICompatible.create_client(self.text_model_name)
+        self.multimodal_model_client = chat_service.OpenAICompatible.create_client(self.multimodal_model_name)
+
         coze_config = configs.get()['chat_api']['coze']
         coze_oauth_app = CozeChat.get_oauth_app(
             coze_config['oauth_client_id'], coze_config['private_key_path'], str(coze_config['public_key_id']),
@@ -223,10 +223,16 @@ class AgentService:
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
             return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
         if self.chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
-            chat_completion = self.llm_client.chat.completions.create(
-                messages=chat_config['messages'],
-                model=self.model_name,
-            )
+            if chat_config['use_multimodal_model']:
+                chat_completion = self.multimodal_model_client.chat.completions.create(
+                    messages=chat_config['messages'],
+                    model=self.multimodal_model_name,
+                )
+            else:
+                chat_completion = self.text_model_client.chat.completions.create(
+                    messages=chat_config['messages'],
+                    model=self.text_model_client,
+                )
             response = chat_completion.choices[0].message.content
         elif self.chat_service_type == ChatServiceType.COZE_CHAT:
             bot_user_id = 'qywx_{}'.format(chat_config['user_id'])

+ 3 - 0
configs/dev.yaml

@@ -44,6 +44,9 @@ chat_api:
     public_key_id: xafitzyxY0OBCFJFzmhBxauo8LKe2pe2YjlTNYfEsns
     private_key_path: oauth/coze_privkey.pem
     account_id: 649175100044793
+  openai_compatible:
+    text_model: ep-20250414202859-6nkz5
+    multimodal_model: ep-20250421193334-nz5wd
 
 debug_flags:
   disable_llm_api_call: True

+ 3 - 0
configs/prod.yaml

@@ -39,6 +39,9 @@ chat_api:
     public_key_id: xafitzyxY0OBCFJFzmhBxauo8LKe2pe2YjlTNYfEsns
     private_key_path: oauth/coze_privkey.pem
     account_id: 649175100044793
+  openai_compatible:
+    text_model: ep-20250414202859-6nkz5
+    multimodal_model: ep-20250421193334-nz5wd
 
 agent_behavior:
   message_aggregation_sec: 20