|
@@ -0,0 +1,208 @@
|
|
|
+#! /usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+# vim:fenc=utf-8
|
|
|
+from calendar import prmonth
|
|
|
+
|
|
|
+import werkzeug.exceptions
|
|
|
+from flask import Flask, request, jsonify
|
|
|
+from datetime import datetime, timedelta
|
|
|
+from argparse import ArgumentParser
|
|
|
+
|
|
|
+from openai import OpenAI
|
|
|
+
|
|
|
+import chat_service
|
|
|
+import configs
|
|
|
+import logging_service
|
|
|
+import prompt_templates
|
|
|
+from dialogue_manager import DialogueManager
|
|
|
+from history_dialogue_service import HistoryDialogueService
|
|
|
+from user_manager import MySQLUserManager, MySQLUserRelationManager
|
|
|
+
|
|
|
+app = Flask('agent_api_server')
|
|
|
+
|
|
|
def wrap_response(code, msg=None, data=None):
    """Build the standard JSON response envelope used by every endpoint.

    :param code: application-level status code (200 means success)
    :param msg: human-readable message; defaults to 'success' when code is 200
    :param data: optional payload placed under the 'data' key
    :return: a Flask JSON response with keys 'code', 'msg' and optionally 'data'
    """
    resp = {
        'code': code,
        'msg': msg
    }
    if code == 200 and not msg:
        resp['msg'] = 'success'
    # Explicit None check: a falsy-but-valid payload (empty list/dict, 0)
    # must still be serialized, not silently dropped.
    if data is not None:
        resp['data'] = data
    return jsonify(resp)
|
|
|
+
|
|
|
@app.route('/api/listStaffs', methods=['GET'])
def list_staffs():
    """Return every staff record known to the relation manager."""
    return wrap_response(200, data=app.user_relation_manager.list_staffs())
|
|
|
+
|
|
|
@app.route('/api/getStaffProfile', methods=['GET'])
def get_staff_profile():
    """Look up a single staff profile by the 'staff_id' query argument."""
    profile = app.user_manager.get_staff_profile(request.args['staff_id'])
    if profile:
        return wrap_response(200, data=profile)
    return wrap_response(404, msg='staff not found')
|
|
|
+
|
|
|
@app.route('/api/getUserProfile', methods=['GET'])
def get_user_profile():
    """Return the profile of the user given by the 'user_id' query argument.

    Responds 404 when no profile exists for the id; otherwise 200 with the
    profile as the data payload.
    """
    user_id = request.args['user_id']
    profile = app.user_manager.get_user_profile(user_id)
    if not profile:
        return wrap_response(404, msg='user not found')
    # Use the shared envelope helper instead of hand-built dicts so all
    # endpoints emit the same response shape.
    return wrap_response(200, data=profile)
|
|
|
+
|
|
|
@app.route('/api/listUsers', methods=['GET'])
def list_users():
    """Search users by name and/or union id.

    At least one of 'user_name' / 'user_union_id' must be supplied; responds
    400 otherwise. On success returns the matching users under 'data'.
    """
    user_name = request.args.get('user_name', None)
    user_union_id = request.args.get('user_union_id', None)
    if not user_name and not user_union_id:
        # Shared envelope helper keeps error responses consistent file-wide.
        return wrap_response(400, msg='user_name or user_union_id is required')
    data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
    # Built directly so an empty result list is still included in the body.
    return jsonify({'code': 200, 'data': data})
|
|
|
+
|
|
|
@app.route('/api/getDialogueHistory', methods=['GET'])
def get_dialogue_history():
    """Return recent dialogue turns between a staff member and a user."""
    staff_id = request.args['staff_id']
    user_id = request.args['user_id']
    # Query-string values arrive as str; coerce so callers that pass
    # recent_minutes get the same int type as the default.
    # 1440 minutes == 24 hours.
    recent_minutes = request.args.get('recent_minutes', 1440, type=int)
    dialogue_history = app.history_dialogue_service.get_dialogue_history(
        staff_id, user_id, recent_minutes)
    # Built directly so an empty history list is still included in the body.
    return jsonify({'code': 200, 'data': dialogue_history})
|
|
|
+
|
|
|
@app.route('/api/listModels', methods=['GET'])
def list_models():
    """List the chat models exposed to the frontend prompt-testing UI."""
    supported = [
        (chat_service.DEEPSEEK_CHAT_MODEL, 'DeepSeek V3'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K, '豆包Pro 32K'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5, '豆包Pro 1.5'),
    ]
    models = [
        {
            'model_type': 'openai_compatible',
            'model_name': model_name,
            'display_name': display_name,
        }
        for model_name, display_name in supported
    ]
    return wrap_response(200, data=models)
|
|
|
+
|
|
|
@app.route('/api/listScenes', methods=['GET'])
def list_scenes():
    """List the prompt scenes available for editing/testing."""
    scene_labels = [
        ('greeting', '问候'),
        ('chitchat', '闲聊'),
        ('profile_extractor', '画像提取'),
    ]
    scenes = [
        {'scene': scene, 'display_name': label}
        for scene, label in scene_labels
    ]
    return wrap_response(200, data=scenes)
|
|
|
+
|
|
|
@app.route('/api/getBasePrompt', methods=['GET'])
def get_base_prompt():
    """Return the base prompt template and default model for a scene."""
    scene = request.args['scene']
    # Single registry: scene -> (prompt template, default model).
    scene_registry = {
        'greeting': (prompt_templates.GENERAL_GREETING_PROMPT,
                     chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K),
        'chitchat': (prompt_templates.CHITCHAT_PROMPT_COZE,
                     chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K),
        'profile_extractor': (prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
                              chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5),
    }
    if scene not in scene_registry:
        return wrap_response(404, msg='scene not found')
    content, model_name = scene_registry[scene]
    return wrap_response(200, data={'model_name': model_name, 'content': content})
|
|
|
+
|
|
|
@app.route('/api/runPrompt', methods=['POST'])
def run_prompt():
    """Render a prompt template with profile context and run it against an LLM.

    Expects a JSON body with keys: scene, prompt, staff_profile, user_profile,
    dialogue_history, model_name, current_timestamp (unix seconds).
    Returns 200 with the completion text, or 400 for an unsupported model.
    A missing key raises KeyError (-> HTTP 500), matching prior behavior.
    """
    req_data = request.json
    # 'scene' is required in the payload but otherwise unused here; reading it
    # keeps the implicit presence check.
    scene = req_data['scene']
    prompt = req_data['prompt']
    staff_profile = req_data['staff_profile']
    user_profile = req_data['user_profile']
    dialogue_history = req_data['dialogue_history']
    model_name = req_data['model_name']
    current_timestamp = req_data['current_timestamp']

    # Template context: on key clashes the user profile overrides the staff
    # profile, as in the original {**staff, **user} merge order.
    prompt_context = {**staff_profile, **user_profile}
    # NOTE(review): fromtimestamp() uses the server's local timezone — confirm
    # that callers send timestamps with that expectation.
    current_dt = datetime.fromtimestamp(current_timestamp)
    current_hour = current_dt.hour
    prompt_context['last_interaction_interval'] = 0
    prompt_context['current_time_period'] = DialogueManager.get_time_context(current_hour)
    prompt_context['current_hour'] = current_hour
    prompt_context['if_first_interaction'] = not dialogue_history

    messages = [{
        'role': 'system',
        'content': prompt.format(**prompt_context)
    }]
    current_timestr = current_dt.strftime('%Y-%m-%d %H:%M:%S')
    messages.extend(DialogueManager.compose_chat_messages_openai_compatible(
        dialogue_history, current_timestr))

    volcengine_models = (
        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
        chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
    )
    if model_name not in volcengine_models:
        # NOTE(review): /api/listModels also advertises DEEPSEEK_CHAT_MODEL,
        # which is rejected here — confirm whether it should be supported.
        return wrap_response(400, msg='model not supported')

    llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN,
                        base_url=chat_service.VOLCENGINE_BASE_URL)
    response = llm_client.chat.completions.create(
        messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
    return wrap_response(200, data=response.choices[0].message.content)
|
|
|
+
|
|
|
@app.errorhandler(werkzeug.exceptions.BadRequest)
def handle_bad_request(e):
    """Convert werkzeug BadRequest errors into the standard JSON envelope."""
    return wrap_response(400, msg=f'Bad Request: {e.description}')
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
    # CLI flags: --prod disables Flask debug mode; host/port default to a
    # local-only bind on 8083.
    parser = ArgumentParser()
    parser.add_argument('--prod', action='store_true')
    parser.add_argument('--host', default='127.0.0.1')
    parser.add_argument('--port', type=int, default=8083)
    args = parser.parse_args()

    config = configs.get()
    logging_service.setup_root_logger(logfile_name='agent_api_server.log')

    # Wire the MySQL-backed user/staff profile manager onto the Flask app
    # object so request handlers can reach it via app.user_manager.
    user_db_config = config['storage']['user']
    staff_db_config = config['storage']['staff']
    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
    app.user_manager = user_manager

    # The relation manager joins the agent's user DB with the WeCom tables
    # (staff / relation / user) — presumably enterprise WeChat data; the
    # exact schema lives in the config, not visible here.
    wecom_db_config = config['storage']['user_relation']
    user_relation_manager = MySQLUserRelationManager(
        user_db_config['mysql'], wecom_db_config['mysql'],
        config['storage']['staff']['table'],
        user_db_config['table'],
        wecom_db_config['table']['staff'],
        wecom_db_config['table']['relation'],
        wecom_db_config['table']['user']
    )
    app.user_relation_manager = user_relation_manager
    # Dialogue-history lookups are delegated to a remote HTTP service.
    app.history_dialogue_service = HistoryDialogueService(
        config['storage']['history_dialogue']['api_base_url']
    )

    app.run(debug=not args.prod, host=args.host, port=args.port)
|