Browse Source

add current_state, msg_type
and fix a bug which occurs Type Error because of int + str

luojunhui 1 month ago
parent
commit
ddaec40304

+ 42 - 2
pqai_agent_server/api_server.py

@@ -1,6 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
+import json
 import time
 import logging
 import werkzeug.exceptions
@@ -11,11 +12,12 @@ from pqai_agent import configs
 
 from pqai_agent import logging_service, chat_service, prompt_templates
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
+from pqai_agent.configs import apollo_config
 from pqai_agent.history_dialogue_service import HistoryDialogueService
 from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
 from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
 from pqai_agent_server.const import AgentApiConst
-from pqai_agent_server.models import MySQLSessionManager
+from pqai_agent_server.models import MySQLSessionManager, MySQLStaffManager
 from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
 from pqai_agent_server.utils import (
     run_extractor_prompt,
@@ -36,13 +38,41 @@ def list_staffs():
 @app.route('/api/getStaffProfile', methods=['GET'])
 def get_staff_profile():
     staff_id = request.args['staff_id']
-    profile = app.user_manager.get_staff_profile(staff_id)
+    if not staff_id:
+        return wrap_response(400, msg='staff_id is required')
+    profile = app.staff_manager.get_staff_profile(staff_id)
     if not profile:
         return wrap_response(404, msg='staff not found')
     else:
         return wrap_response(200, data=profile)
 
 
+@app.route('/api/saveStaffProfile', methods=['POST'])
+def save_staff_profile():
+    staff_id = request.json.get('staff_id')
+    profile = request.json.get('profile')
+    if not staff_id:
+        return wrap_response(400, msg='staff id is required')
+
+    if not profile:
+        return wrap_response(400, msg='profile is required')
+    else:
+        try:
+            profile_json = json.loads(profile)
+            affected_rows = app.staff_manager.save_staff_profile(staff_id, profile_json)
+            if not affected_rows:
+                return wrap_response(500, msg='save staff profile failed')
+            else:
+                return wrap_response(200, msg='save staff profile success')
+        except json.decoder.JSONDecodeError:
+            return wrap_response(400, msg='profile is not a valid json')
+
+
+@app.route('/api/getProfileFields', methods=['GET'])
+def get_profile_fields():
+    return wrap_response(200, data=app.staff_manager.list_profile_fields())
+
+
 @app.route('/api/getUserProfile', methods=['GET'])
 def get_user_profile():
     user_id = request.args['user_id']
@@ -346,6 +376,16 @@ if __name__ == '__main__':
     )
     app.session_manager = session_manager
 
+    # init staff manager
+    staff_manager = MySQLStaffManager(
+        db_config=user_db_config['mysql'],
+        staff_table=staff_db_config['table'],
+        user_table=user_db_config['table'],
+        config=apollo_config
+    )
+    app.staff_manager = staff_manager
+
+    # init wecom manager
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(
         user_db_config['mysql'], wecom_db_config['mysql'],

+ 2 - 1
pqai_agent_server/models/__init__.py

@@ -1 +1,2 @@
-from .mysql_session_manager import MySQLSessionManager
+from .mysql_session_manager import MySQLSessionManager
+from .mysql_staff_manager import MySQLStaffManager

+ 105 - 0
pqai_agent_server/models/mysql_staff_manager.py

@@ -0,0 +1,105 @@
+import abc
+import json
+import pymysql.cursors
+
+from typing import Dict, List
+from pqai_agent.database import MySQLManager
+
+
+class StaffManager(abc.ABC):
+
+    @abc.abstractmethod
+    def list_all_staffs(self, page_id: int, page_size: int) -> List[Dict]:
+        pass
+
+    @abc.abstractmethod
+    def get_staff_profile(self, staff_id) -> Dict:
+        pass
+
+    @abc.abstractmethod
+    def save_staff_profile(self, staff_id: str, staff_profile: Dict):
+        pass
+
+
+class MySQLStaffManager(StaffManager):
+
+    def __init__(self, db_config, staff_table, user_table, config):
+        self.db = MySQLManager(db_config)
+        self.staff_table = staff_table
+        self.user_table = user_table
+        self.config = config
+
+    def save_staff_profile(self, staff_id: str, staff_profile: Dict):
+        update_query = f"""
+            update {self.staff_table} set agent_profile = %s where third_party_user_id = %s;
+        """
+        affected_rows = self.db.execute(
+            update_query, (json.dumps(staff_profile), staff_id)
+        )
+        return affected_rows
+
+    def get_staff_profile(self, staff_id) -> Dict:
+        profile_obj = {"staff_id": staff_id}
+        sql = f"""select agent_profile from {self.staff_table} where third_party_user_id = %s;"""
+        response = self.db.select(
+            sql=sql,
+            cursor_type=pymysql.cursors.DictCursor,
+            args=(staff_id,),
+        )
+        if not response:
+            profile_obj["data"] = []
+            return profile_obj
+
+        agent_profile = response[0]["agent_profile"]
+        if not agent_profile:
+            profile_obj["data"] = []
+        else:
+            field_map_list = self.list_profile_fields()
+            field_map = {
+                item["field_name"]: item["display_name"] for item in field_map_list
+            }
+            agent_profile = json.loads(agent_profile)
+            profile_obj["data"] = [
+                {
+                    "field_name": key,
+                    "display_name": field_map[key],
+                    "field_value": value,
+                }
+                for key, value in agent_profile.items()
+                if agent_profile.get(key)
+            ]
+        return profile_obj
+
+    def list_all_staffs(self, page_id: int, page_size: int) -> Dict:
+        """
+        :param page_size:
+        :param page_id:
+        :return:
+        """
+        sql = f"""
+            select t1.third_party_user_id as staff_id, t1.name as staff_name, t2.iconurl as avatar
+            from {self.staff_table} t1 left join {self.user_table} t2
+            on t1.third_party_user_id = t2.third_party_user_id
+            limit %s offset %s;
+        """
+        staff_list = self.db.select(
+            sql=sql,
+            cursor_type=pymysql.cursors.DictCursor,
+            args=(page_size + 1, page_size * (page_id - 1)),
+        )
+        if len(staff_list) > page_size:
+            has_next_page = True
+            next_page_id = page_id + 1
+            staff_list = staff_list[:page_size]
+        else:
+            has_next_page = False
+            next_page_id = None
+        return {
+            "has_next_page": has_next_page,
+            "next_page": next_page_id,
+            "data": staff_list,
+        }
+
+    def list_profile_fields(self) -> List[Dict]:
+        response = self.config.get_json_value("field_map_list", [])
+        return response