Bläddra i källkod

Support Coze chat

StrayWarrior 4 veckor sedan
förälder
incheckning
6eec1526f6
7 ändrade filer med 178 tillägg och 63 borttagningar
  1. 50 28
      agent_service.py
  2. 65 0
      chat_service.py
  3. 50 25
      dialogue_manager.py
  4. 4 0
      message.py
  5. 2 2
      prompt_templates.py
  6. 7 7
      unit_test.py
  7. 0 1
      user_manager.py

+ 50 - 28
agent_service.py

@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
 
+import sys
 import time
 from typing import Dict, List, Tuple, Any
 import logging
@@ -10,7 +11,9 @@ from datetime import datetime, timedelta
 import apscheduler.triggers.cron
 from apscheduler.schedulers.background import BackgroundScheduler
 
+import chat_service
 import global_flags
+from chat_service import CozeChat, ChatServiceType
 from dialogue_manager import DialogueManager, DialogueState
 from user_manager import UserManager, LocalUserManager
 from openai import OpenAI
@@ -27,7 +30,8 @@ class AgentService:
         receive_backend: MessageQueueBackend,
         send_backend: MessageQueueBackend,
         human_backend: MessageQueueBackend,
-        user_manager: UserManager
+        user_manager: UserManager,
+        chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
     ):
         self.receive_queue = receive_backend
         self.send_queue = send_backend
@@ -42,7 +46,13 @@ class AgentService:
             api_key='5e275c38-44fd-415f-abcf-4b59f6377f72',
             base_url="https://ark.cn-beijing.volces.com/api/v3"
         )
-        self.model_name = "ep-20250213194558-rrmr2" # DeepSeek on Volces
+        # DeepSeek on Volces
+        self.model_name = "ep-20250213194558-rrmr2"
+        self.coze_client = CozeChat(
+            token=chat_service.COZE_API_TOKEN,
+            base_url=chat_service.COZE_CN_BASE_URL
+        )
+        self.chat_service_type = chat_service_type
 
         # 定时任务调度器
         self.scheduler = BackgroundScheduler()
@@ -67,7 +77,7 @@ class AgentService:
         while True:
             message = self.receive_queue.consume()
             if message:
-                self._process_single_message(message)
+                self.process_single_message(message)
             time.sleep(1)  # 避免CPU空转
 
     def _update_user_profile(self, user_id, user_profile, message: str):
@@ -92,7 +102,7 @@ class AgentService:
                                'date',
                                run_date=datetime.now() + timedelta(seconds=delay_sec))
 
-    def _process_single_message(self, message: Dict):
+    def process_single_message(self, message: Dict):
         user_id = message['user_id']
         message_text = message.get('text', None)
 
@@ -117,7 +127,7 @@ class AgentService:
         else:
             # 先更新用户画像再处理回复
             self._update_user_profile(user_id, user_profile, message_text)
-            self._process_llm_response(user_id, agent, message_text)
+            self._get_chat_response(user_id, agent, message_text)
 
     def _route_to_human_intervention(self, user_id: str, user_message: str, state: DialogueState):
         """路由到人工干预"""
@@ -127,20 +137,6 @@ class AgentService:
             'timestamp': datetime.now().isoformat()
         })
 
-    def _process_llm_response(self, user_id: str, agent: DialogueManager,
-                              user_message: str):
-        """处理LLM响应"""
-        messages = agent.make_llm_messages(user_message)
-        logging.debug(messages)
-        llm_response = self._call_llm_api(messages)
-
-        if response := agent.generate_response(llm_response):
-            logging.warning("user: {}, response: {}".format(user_id, response))
-            self.send_queue.produce({
-                'user_id': user_id,
-                'text': response,
-            })
-
     def _check_initiative_conversations(self):
         """定时检查主动发起对话"""
         for user_id in self.user_manager.list_all_users():
@@ -149,24 +145,49 @@ class AgentService:
 
             if should_initiate:
                 logging.warning("user: {}, initiate conversation".format(user_id))
-                self._process_llm_response(user_id, agent, None)
+                self._get_chat_response(user_id, agent, None)
             else:
                 logging.debug("user: {}, do not initiate conversation".format(user_id))
 
-    def _call_llm_api(self, messages: List[Dict]) -> str:
+    def _get_chat_response(self, user_id: str, agent: DialogueManager,
+                           user_message: str):
+        """处理LLM响应"""
+        chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
+        logging.debug(chat_config)
+        # FIXME(zhoutian): 这里的抽象不够好,DialogueManager和AgentService有耦合
+        chat_response = self._call_chat_api(chat_config)
+
+        if response := agent.generate_response(chat_response):
+            logging.warning("user: {}, response: {}".format(user_id, response))
+            self.send_queue.produce({
+                'user_id': user_id,
+                'type': MessageType.TEXT,
+                'text': response,
+            })
+
+    def _call_chat_api(self, chat_config: Dict) -> str:
         if global_flags.DISABLE_LLM_API_CALL:
             return 'LLM模拟回复'
-        chat_completion = self.llm_client.chat.completions.create(
-            messages=messages,
-            model=self.model_name,
-        )
-        response = chat_completion.choices[0].message.content
+        if self.chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
+            chat_completion = self.llm_client.chat.completions.create(
+                messages=chat_config['messages'],
+                model=self.model_name,
+            )
+            response = chat_completion.choices[0].message.content
+        elif self.chat_service_type == ChatServiceType.COZE_CHAT:
+            bot_user_id = 'dev_user'
+            response = self.coze_client.create(
+                chat_config['bot_id'], bot_user_id, chat_config['messages'],
+                chat_config['custom_variables']
+            )
+        else:
+            raise Exception('Unsupported chat service type: {}'.format(self.chat_service_type))
         return response
 
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.DEBUG)
     console_handler = logging.StreamHandler()
-    console_handler.setLevel(logging.WARNING)
+    console_handler.setLevel(logging.INFO)
     formatter = ColoredFormatter(
         '%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
     )
@@ -192,7 +213,8 @@ if __name__ == "__main__":
         receive_backend=receive_queue,
         send_backend=send_queue,
         human_backend=human_queue,
-        user_manager=user_manager
+        user_manager=user_manager,
+        chat_service_type=ChatServiceType.COZE_CHAT
     )
 
     process_thread = threading.Thread(target=service.process_messages)

+ 65 - 0
chat_service.py

@@ -0,0 +1,65 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+
+import os
+from typing import List, Dict
+from enum import Enum, auto
+import logging
+from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageContentType, ChatEventType, MessageType
+
+COZE_API_TOKEN = os.getenv("COZE_API_TOKEN")
+COZE_CN_BASE_URL = 'https://api.coze.cn'
+
+class ChatServiceType(Enum):
+    OPENAI_COMPATIBLE = auto
+    COZE_CHAT = auto()
+
+class CozeChat:
+    def __init__(self, token, base_url: str):
+        self.coze = Coze(auth=TokenAuth(token), base_url=base_url)
+
+    def create(self, bot_id: str, user_id: str, messages: List, custom_variables: Dict):
+        response = self.coze.chat.create_and_poll(
+            bot_id=bot_id, user_id=user_id, additional_messages=messages,
+            custom_variables=custom_variables)
+        logging.debug("coze response: {}".format(response.messages))
+        for message in response.messages:
+            if message.type == MessageType.ANSWER:
+                return message.content
+        return None
+
+if __name__ == '__main__':
+    # Init the Coze client through the access_token.
+    coze = Coze(auth=TokenAuth(token=COZE_API_TOKEN), base_url=COZE_CN_BASE_URL)
+
+    # Create a bot instance in Coze, copy the last number from the web link as the bot's ID.
+    bot_id = "7479005417885417487"
+    # The user id identifies the identity of a user. Developers can use a custom business ID
+    # or a random string.
+    user_id = "dev_user"
+
+    chat = coze.chat.create_and_poll(
+        bot_id=bot_id,
+        user_id=user_id,
+        additional_messages=[Message.build_user_question_text("北京今天天气怎么样")],
+        custom_variables={
+            'agent_name': '芳华',
+            'agent_age': '25',
+            'agent_region': '北京',
+            'name': '李明',
+            'preferred_nickname': '李叔',
+            'age': '70',
+            'last_interaction_interval': '12',
+            'current_time_period': '上午',
+            'if_first_interaction': 'False',
+            'if_active_greeting': 'False'
+        }
+    )
+
+    for message in chat.messages:
+        print(message, flush=True)
+
+    if chat.chat.status == ChatStatus.COMPLETED:
+        print("token usage:", chat.chat.usage.token_count)

+ 50 - 25
dialogue_manager.py

@@ -8,6 +8,9 @@ from datetime import datetime
 import time
 import logging
 
+import cozepy
+
+from chat_service import ChatServiceType
 from message import MessageType
 # from vector_memory_manager import VectorMemoryManager
 from structured_memory_manager import StructuredMemoryManager
@@ -279,9 +282,11 @@ class DialogueManager:
 
         return llm_response
 
-    def _get_hours_since_last_interaction(self):
+    def _get_hours_since_last_interaction(self, precision: int = -1):
         time_diff = (time.time() * 1000) - self.last_interaction_time
         hours_passed = time_diff / 1000 / 3600
+        if precision >= 0:
+            return round(hours_passed, precision)
         return hours_passed
 
     def should_initiate_conversation(self) -> bool:
@@ -329,11 +334,10 @@ class DialogueManager:
             "current_state": self.current_state.name,
             "previous_state": self.previous_state.name if self.previous_state else None,
             "current_time_period": time_context.description,
-            "dialogue_history": self.dialogue_history[-10:],
-            "user_message": user_message,
-            "last_interaction_interval": self._get_hours_since_last_interaction(),
+            # "dialogue_history": self.dialogue_history[-10:],
+            "last_interaction_interval": self._get_hours_since_last_interaction(2),
             "if_first_interaction": False,
-            "if_active_greeting": True if user_message else False
+            "if_active_greeting": False if user_message else True
         }
 
         # 获取长期记忆
@@ -352,37 +356,58 @@ class DialogueManager:
         }
         return state_to_prompt_map[state]
 
-    def _create_system_message(self):
-        prompt_context = self.get_prompt_context(None)
+    def _select_coze_bot(self, state):
+        state_to_bot_map = {
+            DialogueState.GREETING: '7479005417885417487',
+            DialogueState.CHITCHAT: '7479005417885417487'
+        }
+        return state_to_bot_map[state]
+
+    def _create_system_message(self, prompt_context):
         prompt_template = self._select_prompt(self.current_state)
         prompt = prompt_template.format(**prompt_context['user_profile'], **prompt_context)
         return {'role': 'system', 'content': prompt}
 
-    def make_llm_messages(self, user_message: Optional[str] = None) -> List[Dict[str, str]]:
+    def build_chat_configuration(
+            self,
+            user_message: Optional[str] = None,
+            chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
+    ) -> Dict:
         """
         参数:
-            dialogue_manager: 对话管理器实例
             user_message: 当前用户消息,如果是主动交互则为None
         返回:
             消息列表
         """
-        messages = []
-
-        # 添加系统消息
-        system_message = self._create_system_message()
-        messages.append(system_message)
-
-        # 添加历史对话
         dialogue_history = self.dialogue_history[-10:] \
             if len(self.dialogue_history) > 10 \
             else self.dialogue_history
-
-        for entry in dialogue_history:
-            role = entry['role']
-            messages.append({
-                "role": role,
-                "content": entry["content"]
-            })
-
-        return messages
+        messages = []
+        config = {}
+
+        prompt_context = self.get_prompt_context(user_message)
+        if chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
+            system_message = self._create_system_message(prompt_context)
+            messages.append(system_message)
+            for entry in dialogue_history:
+                role = entry['role']
+                messages.append({
+                    "role": role,
+                    "content": entry["content"]
+                })
+        elif chat_service_type == ChatServiceType.COZE_CHAT:
+            for entry in dialogue_history:
+                role = entry['role']
+                if role == 'user':
+                    messages.append(cozepy.Message.build_user_question_text(entry["content"]))
+                elif role == 'assistant':
+                    messages.append(cozepy.Message.build_assistant_answer(entry['content']))
+            custom_variables = {}
+            for k, v in prompt_context.items():
+                custom_variables[k] = str(v)
+            config['custom_variables'] = custom_variables
+            config['bot_id'] = self._select_coze_bot(self.current_state)
+        config['messages'] = messages
+
+        return config
 

+ 4 - 0
message.py

@@ -15,3 +15,7 @@ class MessageType(Enum):
 
     ACTIVE_TRIGGER = auto()
     AGGREGATION_TRIGGER = auto()
+
+class MessageChannel(Enum):
+    CORP_WECHAT = auto()
+    MINI_PROGRAM = auto()

+ 2 - 2
prompt_templates.py

@@ -47,7 +47,7 @@ GENERAL_GREETING_PROMPT = """
 - 用药信息:{medications}
 - 兴趣爱好:{interests}
 对话上下文信息:
-- 上次交互距当前小时: {last_interaction_interval:.2f}
+- 上次交互距当前小时: {last_interaction_interval}
 - 当前时间段: {current_time_period}
 - 是否首次交互: {if_first_interaction}
 - 是否为主动问候: {if_active_greeting}
@@ -92,7 +92,7 @@ CHITCHAT_PROMPT = """
 - 用药信息:{medications}
 - 兴趣爱好:{interests}
 对话上下文信息:
-- 上次交互距当前小时: {last_interaction_interval:.2f}
+- 上次交互距当前小时: {last_interaction_interval}
 - 当前时间段: {current_time_period}
 
 指导原则:

+ 7 - 7
unit_test.py

@@ -42,7 +42,7 @@ def test_env():
     service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
 
     # 替换LLM调用为模拟响应
-    service._call_llm_api = Mock(return_value="模拟响应")
+    service._call_chat_api = Mock(return_value="模拟响应")
 
     return service, queues
 
@@ -63,7 +63,7 @@ def test_normal_conversation_flow(test_env):
     # 处理消息
     message = service.receive_queue.consume()
     if message:
-        service._process_single_message(message)
+        service.process_single_message(message)
 
     # 验证响应消息
     sent_msg = queues.send_queue.consume()
@@ -96,7 +96,7 @@ def test_aggregated_conversation_flow(test_env):
     # 处理消息
     message = service.receive_queue.consume()
     if message:
-        service._process_single_message(message)
+        service.process_single_message(message)
 
     # 验证第一次响应消息
     sent_msg = queues.send_queue.consume()
@@ -104,13 +104,13 @@ def test_aggregated_conversation_flow(test_env):
 
     message = service.receive_queue.consume()
     if message:
-        service._process_single_message(message)
+        service.process_single_message(message)
     # 验证第二次响应消息
     sent_msg = queues.send_queue.consume()
     assert sent_msg is None
 
     # 模拟定时器产生空消息触发响应
-    service._process_single_message({
+    service.process_single_message({
         "user_id": "user_id_0",
         "type": MessageType.AGGREGATION_TRIGGER,
         "timestamp": ts_begin + 2 * 1000
@@ -138,7 +138,7 @@ def test_human_intervention_trigger(test_env):
     # 处理消息
     message = service.receive_queue.consume()
     if message:
-        service._process_single_message(message)
+        service.process_single_message(message)
 
     # 验证人工队列消息
     human_msg = queues.human_queue.consume()
@@ -150,7 +150,7 @@ def test_initiative_conversation(test_env):
     """测试主动发起对话"""
     service, queues = test_env
     service._get_agent_instance("user_id_0").message_aggregation_sec = 0
-    service._call_llm_api = Mock(return_value="主动发起")
+    service._call_chat_api = Mock(return_value="主动发起模拟消息")
 
     # 设置Agent需要主动发起对话
     agent = service._get_agent_instance("user_id_0")

+ 0 - 1
user_manager.py

@@ -64,5 +64,4 @@ class LocalUserManager(UserManager):
             for file in files:
                 if file.endswith('.json'):
                     json_files.append(os.path.splitext(file)[0])
-        print(json_files)
         return json_files