|
@@ -1,6 +1,7 @@
|
|
|
#! /usr/bin/env python
|
|
|
# -*- coding: utf-8 -*-
|
|
|
# vim:fenc=utf-8
|
|
|
+import json
|
|
|
import time
|
|
|
import logging
|
|
|
import werkzeug.exceptions
|
|
@@ -10,10 +11,13 @@ from argparse import ArgumentParser
|
|
|
from pqai_agent import configs
|
|
|
|
|
|
from pqai_agent import logging_service, chat_service, prompt_templates
|
|
|
+from pqai_agent.agents.message_reply_agent import MessageReplyAgent
|
|
|
+from pqai_agent.configs import apollo_config
|
|
|
from pqai_agent.history_dialogue_service import HistoryDialogueService
|
|
|
from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
|
|
|
+from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
|
|
|
from pqai_agent_server.const import AgentApiConst
|
|
|
-from pqai_agent_server.models import MySQLSessionManager
|
|
|
+from pqai_agent_server.models import MySQLSessionManager, MySQLStaffManager
|
|
|
from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
|
|
|
from pqai_agent_server.utils import (
|
|
|
run_extractor_prompt,
|
|
@@ -34,11 +38,55 @@ def list_staffs():
|
|
|
@app.route('/api/getStaffProfile', methods=['GET'])
|
|
|
def get_staff_profile():
|
|
|
staff_id = request.args['staff_id']
|
|
|
- profile = app.user_manager.get_staff_profile(staff_id)
|
|
|
- if not profile:
|
|
|
+ if not staff_id:
|
|
|
+ return wrap_response(400, msg='staff_id is required')
|
|
|
+ agent_profile = app.staff_manager.get_staff_profile(staff_id)
|
|
|
+ if not agent_profile:
|
|
|
return wrap_response(404, msg='staff not found')
|
|
|
else:
|
|
|
- return wrap_response(200, data=profile)
|
|
|
+ field_map_list = apollo_config.get_json_value("field_map_list", [])
|
|
|
+ field_map = {
|
|
|
+ item["field_name"]: item["display_name"] for item in field_map_list
|
|
|
+ }
|
|
|
+ profile_info = [
|
|
|
+ {
|
|
|
+ "field_name": key,
|
|
|
+ "display_name": field_map[key],
|
|
|
+ "field_value": value,
|
|
|
+ }
|
|
|
+ for key, value in agent_profile.items()
|
|
|
+ if agent_profile.get(key)
|
|
|
+ ]
|
|
|
+ return wrap_response(200, data=profile_info)
|
|
|
+
|
|
|
+
|
|
|
+@app.route('/api/saveStaffProfile', methods=['POST'])
|
|
|
+def save_staff_profile():
|
|
|
+ staff_id = request.json.get('staff_id')
|
|
|
+ staff_profile = request.json.get('staff_profile')
|
|
|
+ if not staff_id:
|
|
|
+ return wrap_response(400, msg='staff id is required')
|
|
|
+
|
|
|
+ if not staff_profile:
|
|
|
+ return wrap_response(400, msg='profile is required')
|
|
|
+ else:
|
|
|
+ try:
|
|
|
+ profile_info_list = json.loads(staff_profile)
|
|
|
+ profile_dict = {item['field_name']: item['field_value'] for item in profile_info_list}
|
|
|
+ affected_rows = app.staff_manager.save_staff_profile(staff_id, profile_dict)
|
|
|
+
|
|
|
+ if not affected_rows:
|
|
|
+ return wrap_response(500, msg='save staff profile failed')
|
|
|
+ else:
|
|
|
+ return wrap_response(200, msg='save staff profile success')
|
|
|
+
|
|
|
+ except json.decoder.JSONDecodeError:
|
|
|
+ return wrap_response(400, msg='profile is not a valid json')
|
|
|
+
|
|
|
+
|
|
|
+@app.route('/api/getProfileFields', methods=['GET'])
|
|
|
+def get_profile_fields():
|
|
|
+ return wrap_response(200, data=apollo_config.get_json_value("field_map_list", []))
|
|
|
|
|
|
|
|
|
@app.route('/api/getUserProfile', methods=['GET'])
|
|
@@ -171,6 +219,32 @@ def run_prompt():
|
|
|
logger.error(e)
|
|
|
return wrap_response(500, msg='Error: {}'.format(e))
|
|
|
|
|
|
+@app.route('/api/formatForPrompt', methods=['POST'])
|
|
|
+def format_data_for_prompt():
|
|
|
+ try:
|
|
|
+ req_data = request.json
|
|
|
+ content = req_data['content']
|
|
|
+ format_type = req_data['format_type']
|
|
|
+ if format_type == 'staff_profile':
|
|
|
+ if not isinstance(content, dict):
|
|
|
+ return wrap_response(400, msg='staff_profile should be a dict')
|
|
|
+ response = format_agent_profile(content)
|
|
|
+ elif format_type == 'user_profile':
|
|
|
+ if not isinstance(content, dict):
|
|
|
+ return wrap_response(400, msg='user_profile should be a dict')
|
|
|
+ response = format_user_profile(content)
|
|
|
+ elif format_type == 'dialogue':
|
|
|
+ if not isinstance(content, list):
|
|
|
+ return wrap_response(400, msg='dialogue should be a list')
|
|
|
+ from pqai_agent_server.utils.prompt_util import format_dialogue_history
|
|
|
+ response = format_dialogue_history(content)
|
|
|
+ else:
|
|
|
+ return wrap_response(400, msg='Invalid format_type')
|
|
|
+ return wrap_response(200, data=response)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(e)
|
|
|
+ return wrap_response(500, msg='Error: {}'.format(e))
|
|
|
+
|
|
|
|
|
|
@app.route("/api/healthCheck", methods=["GET"])
|
|
|
def health_check():
|
|
@@ -208,17 +282,15 @@ def get_staff_session_list():
|
|
|
if not staff_id:
|
|
|
return wrap_response(404, msg="staff_id is required")
|
|
|
|
|
|
+ page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
|
|
|
+ page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
|
|
|
+
|
|
|
# check params
|
|
|
- page_size = request.args.get("page_size")
|
|
|
- if page_size:
|
|
|
- page_size = int(page_size)
|
|
|
- else:
|
|
|
- page_size = const.DEFAULT_PAGE_SIZE
|
|
|
- page_id = request.args.get("page_id")
|
|
|
- if page_id:
|
|
|
+ try:
|
|
|
page_id = int(page_id)
|
|
|
- else:
|
|
|
- page_id = const.DEFAULT_PAGE_ID
|
|
|
+ page_size = int(page_size)
|
|
|
+ except Exception as e:
|
|
|
+ return wrap_response(404, msg="Invalid parameter: {}".format(e))
|
|
|
|
|
|
staff_session_list = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
|
|
|
if not staff_session_list:
|
|
@@ -229,17 +301,15 @@ def get_staff_session_list():
|
|
|
|
|
|
@app.route("/api/getStaffList", methods=["GET"])
|
|
|
def get_staff_list():
|
|
|
+ page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
|
|
|
+ page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
|
|
|
+
|
|
|
# check params
|
|
|
- page_size = request.args.get("page_size")
|
|
|
- page_id = request.args.get("page_id")
|
|
|
- if page_id:
|
|
|
+ try:
|
|
|
page_id = int(page_id)
|
|
|
- else:
|
|
|
- page_id = const.DEFAULT_PAGE_ID
|
|
|
- if page_size:
|
|
|
page_size = int(page_size)
|
|
|
- else:
|
|
|
- page_size = const.DEFAULT_PAGE_SIZE
|
|
|
+ except Exception as e:
|
|
|
+ return wrap_response(404, msg="Invalid parameter: {}".format(e))
|
|
|
|
|
|
staff_list = app.user_manager.get_staff_list(page_id, page_size)
|
|
|
if not staff_list:
|
|
@@ -258,8 +328,8 @@ def get_conversation_list():
|
|
|
if not staff_id or not user_id:
|
|
|
return wrap_response(404, msg="staff_id and user_id are required")
|
|
|
|
|
|
- page_id = request.args.get("page_id")
|
|
|
- response = app.session_manager.get_conversation_list(staff_id, user_id, page_id, const.DEFAULT_CONVERSATION_SIZE)
|
|
|
+ page = request.args.get("page_id")
|
|
|
+ response = app.session_manager.get_conversation_list(staff_id, user_id, page, const.DEFAULT_CONVERSATION_SIZE)
|
|
|
return wrap_response(200, data=response)
|
|
|
|
|
|
|
|
@@ -279,8 +349,6 @@ def quit_human_interventions_status():
|
|
|
user_id = req_data["user_id"]
|
|
|
if not user_id or not staff_id:
|
|
|
return wrap_response(404, msg="user_id and staff_id are required")
|
|
|
- # dev
|
|
|
- staff_id = 1688854492669990
|
|
|
response = quit_human_intervention_status(user_id, staff_id)
|
|
|
|
|
|
return wrap_response(200, data=response)
|
|
@@ -324,6 +392,15 @@ if __name__ == '__main__':
|
|
|
)
|
|
|
app.session_manager = session_manager
|
|
|
|
|
|
+ # init staff manager
|
|
|
+ staff_manager = MySQLStaffManager(
|
|
|
+ db_config=user_db_config['mysql'],
|
|
|
+ staff_table=staff_db_config['table'],
|
|
|
+ user_table=user_db_config['table']
|
|
|
+ )
|
|
|
+ app.staff_manager = staff_manager
|
|
|
+
|
|
|
+ # init wecom manager
|
|
|
wecom_db_config = config['storage']['user_relation']
|
|
|
user_relation_manager = MySQLUserRelationManager(
|
|
|
user_db_config['mysql'], wecom_db_config['mysql'],
|