#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
import json
import logging
from argparse import ArgumentParser
from calendar import prmonth  # NOTE(review): appears unused in this file — candidate for removal
from datetime import datetime, timedelta

import werkzeug.exceptions
from flask import Flask, request, jsonify
from openai import OpenAI

import chat_service
import configs
import logging_service
import prompt_templates
from dialogue_manager import DialogueManager
from history_dialogue_service import HistoryDialogueService
from user_manager import MySQLUserManager, MySQLUserRelationManager
from user_profile_extractor import UserProfileExtractor
# Flask application instance; manager/service attributes (user_manager,
# user_relation_manager, history_dialogue_service) are attached in __main__.
app = Flask('agent_api_server')
# Shared module-level logger configured by logging_service.
logger = logging_service.logger
def wrap_response(code, msg=None, data=None):
    """Build the standard JSON response envelope.

    Args:
        code: application-level status code (200, 400, 404, 500, ...).
        msg: human-readable message; defaults to 'success' for code 200.
        data: optional payload placed under the 'data' key.

    Returns:
        A Flask JSON response with keys 'code', 'msg' and, when *data*
        was supplied, 'data'.
    """
    resp = {
        'code': code,
        'msg': msg
    }
    if code == 200 and not msg:
        resp['msg'] = 'success'
    # Use an identity check so falsy-but-valid payloads (empty list, 0)
    # are still returned instead of being silently dropped.
    if data is not None:
        resp['data'] = data
    return jsonify(resp)
@app.route('/api/listStaffs', methods=['GET'])
def list_staffs():
    """Return the full staff list from the relation manager."""
    return wrap_response(200, data=app.user_relation_manager.list_staffs())
@app.route('/api/getStaffProfile', methods=['GET'])
def get_staff_profile():
    """Look up a single staff profile by the required ``staff_id`` query arg."""
    profile = app.user_manager.get_staff_profile(request.args['staff_id'])
    if profile:
        return wrap_response(200, data=profile)
    return wrap_response(404, msg='staff not found')
@app.route('/api/getUserProfile', methods=['GET'])
def get_user_profile():
    """Look up a single user profile by the required ``user_id`` query arg.

    Uses the shared wrap_response envelope for consistency with the other
    endpoints; the emitted JSON is identical to the previous hand-built dicts
    ({'code': 404, 'msg': 'user not found'} / {'code': 200, 'msg': 'success',
    'data': profile}).
    """
    user_id = request.args['user_id']
    profile = app.user_manager.get_user_profile(user_id)
    if not profile:
        return wrap_response(404, msg='user not found')
    return wrap_response(200, data=profile)
@app.route('/api/listUsers', methods=['GET'])
def list_users():
    """Search users by name and/or union id; at least one filter is required."""
    name = request.args.get('user_name', None)
    union_id = request.args.get('user_union_id', None)
    # Truthiness check on purpose: an empty string filter is also rejected.
    if not (name or union_id):
        return jsonify({'code': 400, 'msg': 'user_name or user_union_id is required'})
    matched = app.user_manager.list_users(user_name=name, user_union_id=union_id)
    return jsonify({'code': 200, 'data': matched})
@app.route('/api/getDialogueHistory', methods=['GET'])
def get_dialogue_history():
    """Fetch recent dialogue between a staff member and a user.

    ``recent_minutes`` defaults to 1440 (24 hours).
    """
    staff_id = request.args['staff_id']
    user_id = request.args['user_id']
    window_minutes = int(request.args.get('recent_minutes', 1440))
    history = app.history_dialogue_service.get_dialogue_history(
        staff_id, user_id, window_minutes)
    return jsonify({'code': 200, 'data': history})
@app.route('/api/listModels', methods=['GET'])
def list_models():
    """Enumerate the chat models selectable from the console."""
    catalog = [
        (chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3, 'DeepSeek V3 on 火山'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K, '豆包Pro 32K'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5, '豆包Pro 1.5'),
    ]
    models = [
        {
            'model_type': 'openai_compatible',
            'model_name': model_name,
            'display_name': display_name
        }
        for model_name, display_name in catalog
    ]
    return wrap_response(200, data=models)
@app.route('/api/listScenes', methods=['GET'])
def list_scenes():
    """List the prompt scenes available for debugging."""
    scene_catalog = [
        ('greeting', '问候'),
        ('chitchat', '闲聊'),
        ('profile_extractor', '画像提取'),
    ]
    scenes = [{'scene': key, 'display_name': label} for key, label in scene_catalog]
    return wrap_response(200, data=scenes)
@app.route('/api/getBasePrompt', methods=['GET'])
def get_base_prompt():
    """Return the default prompt template and model name for a scene."""
    # scene -> (model_name, prompt template)
    scene_config = {
        'greeting': (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
                     prompt_templates.GENERAL_GREETING_PROMPT),
        'chitchat': (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
                     prompt_templates.CHITCHAT_PROMPT_COZE),
        'profile_extractor': (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
                              prompt_templates.USER_PROFILE_EXTRACT_PROMPT),
    }
    scene = request.args['scene']
    if scene not in scene_config:
        return wrap_response(404, msg='scene not found')
    model_name, content = scene_config[scene]
    return wrap_response(200, data={
        'model_name': model_name,
        'content': content
    })
def run_openai_chat(messages, model_name, **kwargs):
    """Dispatch a chat completion request to the provider hosting *model_name*.

    Args:
        messages: OpenAI-style list of message dicts.
        model_name: a Volcengine-hosted or DeepSeek-hosted model name.
        **kwargs: extra completion arguments (tools, temperature, ...) forwarded
            to ``chat.completions.create``.

    Returns:
        The raw OpenAI SDK chat-completion response.

    Raises:
        ValueError: if *model_name* belongs to neither provider.
    """
    volcengine_models = [
        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
        chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
    ]
    deepseek_models = [
        chat_service.DEEPSEEK_CHAT_MODEL,
    ]
    if model_name in volcengine_models:
        llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN,
                            base_url=chat_service.VOLCENGINE_BASE_URL)
    elif model_name in deepseek_models:
        llm_client = OpenAI(api_key=chat_service.DEEPSEEK_API_TOKEN,
                            base_url=chat_service.DEEPSEEK_BASE_URL)
        # BUGFIX: this branch previously hard-coded the sampling params and
        # discarded **kwargs entirely, silently dropping `tools`/`temperature`
        # when the extractor ran against a DeepSeek model. Keep the old values
        # as defaults but let callers override or extend them.
        params = {'temperature': 1, 'top_p': 0.7, 'max_tokens': 1024}
        params.update(kwargs)
        kwargs = params
    else:
        raise ValueError('model not supported: {}'.format(model_name))
    return llm_client.chat.completions.create(
        messages=messages, model=model_name, **kwargs)
def run_extractor_prompt(req_data):
    """Run the profile-extraction prompt and return the non-empty extracted fields.

    Returns an empty dict when the model makes no tool call or calls an
    unexpected function.
    """
    model_name = req_data['model_name']
    # User-profile keys override staff-profile keys on conflict, as before.
    context = dict(req_data['staff_profile'])
    context.update(req_data['user_profile'])
    context['dialogue_history'] = UserProfileExtractor.compose_dialogue(
        req_data['dialogue_history'])
    filled_prompt = req_data['prompt'].format(**context)
    messages = [
        {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
        {"role": "user", "content": filled_prompt}
    ]
    tools = [UserProfileExtractor.get_extraction_function()]
    response = run_openai_chat(messages, model_name, tools=tools, temperature=0)
    tool_calls = response.choices[0].message.tool_calls
    if not tool_calls:
        return {}
    first_call = tool_calls[0]
    if first_call.function.name != 'update_user_profile':
        logger.error("llm does not return update_user_profile")
        return {}
    extracted = json.loads(first_call.function.arguments)
    return {field: value for field, value in extracted.items() if value}
def run_chat_prompt(req_data):
    """Fill the chat prompt with profile/time context and run the chat model."""
    model_name = req_data['model_name']
    dialogue_history = req_data['dialogue_history']
    # Client sends epoch milliseconds; convert to seconds for datetime.
    now = datetime.fromtimestamp(req_data['current_timestamp'] / 1000)
    context = {**req_data['staff_profile'], **req_data['user_profile']}
    context['last_interaction_interval'] = 0
    context['current_time_period'] = DialogueManager.get_time_context(now.hour)
    context['current_hour'] = now.hour
    context['if_first_interaction'] = not dialogue_history
    last_role = dialogue_history[-1]['role'] if dialogue_history else 'assistant'
    # Active greeting means the agent speaks without a pending user message.
    context['if_active_greeting'] = last_role != 'user'
    messages = [{
        'role': 'system',
        'content': req_data['prompt'].format(**context)
    }]
    messages.extend(DialogueManager.compose_chat_messages_openai_compatible(
        dialogue_history, now.strftime('%Y-%m-%d %H:%M:%S')))
    return run_openai_chat(messages, model_name,
                           temperature=1, top_p=0.7, max_tokens=1024)
@app.route('/api/runPrompt', methods=['POST'])
def run_prompt():
    """Debug endpoint: run the posted prompt against the chosen scene.

    Body must contain 'scene' plus the fields the scene runner expects
    (prompt, model_name, staff_profile, user_profile, dialogue_history, ...).
    """
    try:
        req_data = request.json
        logger.debug(req_data)
        scene = req_data['scene']
        if scene == 'profile_extractor':
            # Extractor returns a plain dict of extracted profile fields.
            return wrap_response(200, data=run_extractor_prompt(req_data))
        response = run_chat_prompt(req_data)
        return wrap_response(200, data=response.choices[0].message.content)
    except Exception as e:
        # logger.exception preserves the traceback, unlike logger.error(e).
        logger.exception('runPrompt failed')
        return wrap_response(500, msg='Error: {}'.format(e))
@app.errorhandler(werkzeug.exceptions.BadRequest)
def handle_bad_request(e):
    """Convert werkzeug BadRequest (e.g. a missing required query arg)
    into the standard JSON envelope instead of an HTML error page."""
    message = 'Bad Request: {}'.format(e.description)
    logger.error(e)
    return wrap_response(400, msg=message)
if __name__ == '__main__':
    parser = ArgumentParser()
    # --prod disables Flask debug mode (see app.run below).
    parser.add_argument('--prod', action='store_true')
    parser.add_argument('--host', default='127.0.0.1')
    parser.add_argument('--port', type=int, default=8083)
    parser.add_argument('--log-level', default='INFO')
    args = parser.parse_args()
    config = configs.get()
    # getLevelName maps a level name such as 'INFO' to its numeric value.
    logging_level = logging.getLevelName(args.log_level)
    logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
    user_db_config = config['storage']['user']
    staff_db_config = config['storage']['staff']
    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
    # Managers/services are attached to the Flask app object so route
    # handlers can reach them via the module-level `app`.
    app.user_manager = user_manager
    wecom_db_config = config['storage']['user_relation']
    user_relation_manager = MySQLUserRelationManager(
        user_db_config['mysql'], wecom_db_config['mysql'],
        config['storage']['staff']['table'],
        user_db_config['table'],
        wecom_db_config['table']['staff'],
        wecom_db_config['table']['relation'],
        wecom_db_config['table']['user']
    )
    app.user_relation_manager = user_relation_manager
    app.history_dialogue_service = HistoryDialogueService(
        config['storage']['history_dialogue']['api_base_url']
    )
    app.run(debug=not args.prod, host=args.host, port=args.port)