Browse Source

Add api_server

StrayWarrior 1 month ago
parent
commit
f74b9aada6
1 changed files with 208 additions and 0 deletions
  1. 208 0
      api_server.py

+ 208 - 0
api_server.py

@@ -0,0 +1,208 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+from calendar import prmonth
+
+import werkzeug.exceptions
+from flask import Flask, request, jsonify
+from datetime import datetime, timedelta
+from argparse import ArgumentParser
+
+from openai import OpenAI
+
+import chat_service
+import configs
+import logging_service
+import prompt_templates
+from dialogue_manager import DialogueManager
+from history_dialogue_service import HistoryDialogueService
+from user_manager import MySQLUserManager, MySQLUserRelationManager
+
+app = Flask('agent_api_server')
+
+def wrap_response(code, msg=None, data=None):
+    resp = {
+        'code': code,
+        'msg': msg
+    }
+    if code == 200 and not msg:
+        resp['msg'] = 'success'
+    if data:
+        resp['data'] = data
+    return jsonify(resp)
+
+@app.route('/api/listStaffs', methods=['GET'])
+def list_staffs():
+    staff_data = app.user_relation_manager.list_staffs()
+    return wrap_response(200, data=staff_data)
+
+@app.route('/api/getStaffProfile', methods=['GET'])
+def get_staff_profile():
+    staff_id = request.args['staff_id']
+    profile = app.user_manager.get_staff_profile(staff_id)
+    if not profile:
+        return wrap_response(404, msg='staff not found')
+    else:
+        return wrap_response(200, data=profile)
+
+@app.route('/api/getUserProfile', methods=['GET'])
+def get_user_profile():
+    user_id = request.args['user_id']
+    profile = app.user_manager.get_user_profile(user_id)
+    if not profile:
+        resp = {
+            'code': 404,
+            'msg': 'user not found'
+        }
+    else:
+        resp = {
+            'code': 200,
+            'msg': 'success',
+            'data': profile
+        }
+    return jsonify(resp)
+
+@app.route('/api/listUsers', methods=['GET'])
+def list_users():
+    user_name = request.args.get('user_name', None)
+    user_union_id = request.args.get('user_union_id', None)
+    if not user_name and not user_union_id:
+        resp = {
+            'code': 400,
+            'msg': 'user_name or user_union_id is required'
+        }
+        return jsonify(resp)
+    data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
+    return jsonify({'code': 200, 'data': data})
+
+@app.route('/api/getDialogueHistory', methods=['GET'])
+def get_dialogue_history():
+    staff_id = request.args['staff_id']
+    user_id = request.args['user_id']
+    recent_minutes = request.args.get('recent_minutes', 1440)
+    dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
+    return jsonify({'code': 200, 'data': dialogue_history})
+
+@app.route('/api/listModels', methods=['GET'])
+def list_models():
+    models = [
+        {
+            'model_type': 'openai_compatible',
+            'model_name': chat_service.DEEPSEEK_CHAT_MODEL,
+            'display_name': 'DeepSeek V3'
+        },
+        {
+            'model_type': 'openai_compatible',
+            'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+            'display_name': '豆包Pro 32K'
+        },
+        {
+            'model_type': 'openai_compatible',
+            'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+            'display_name': '豆包Pro 1.5'
+        },
+    ]
+    return wrap_response(200, data=models)
+
+@app.route('/api/listScenes', methods=['GET'])
+def list_scenes():
+    scenes = [
+        {'scene': 'greeting', 'display_name': '问候'},
+        {'scene': 'chitchat', 'display_name': '闲聊'},
+        {'scene': 'profile_extractor', 'display_name': '画像提取'}
+    ]
+    return wrap_response(200, data=scenes)
+
+@app.route('/api/getBasePrompt', methods=['GET'])
+def get_base_prompt():
+    scene = request.args['scene']
+    prompt_map = {
+        'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
+        'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
+        'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT
+    }
+    model_map = {
+        'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
+    }
+    if scene not in prompt_map:
+        return wrap_response(404, msg='scene not found')
+    data = {
+        'model_name': model_map[scene],
+        'content': prompt_map[scene]
+    }
+    return wrap_response(200, data=data)
+
+@app.route('/api/runPrompt', methods=['POST'])
+def run_prompt():
+    req_data = request.json
+    scene = req_data['scene']
+    prompt = req_data['prompt']
+    staff_profile = req_data['staff_profile']
+    user_profile = req_data['user_profile']
+    dialogue_history = req_data['dialogue_history']
+    model_name = req_data['model_name']
+    current_timestamp = req_data['current_timestamp']
+    prompt_context = {**staff_profile, **user_profile}
+    current_hour = datetime.fromtimestamp(current_timestamp).hour
+    prompt_context['last_interaction_interval'] = 0
+    prompt_context['current_time_period'] = DialogueManager.get_time_context(current_hour)
+    prompt_context['current_hour'] = current_hour
+    prompt_context['if_first_interaction'] = False if dialogue_history else True
+    volcengine_models = [
+        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+        chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
+    ]
+    current_timestr = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
+    system_prompt = {
+        'role': 'system',
+        'content': prompt.format(**prompt_context)
+    }
+    messages = []
+    messages.append(system_prompt)
+    messages.extend(DialogueManager.compose_chat_messages_openai_compatible(dialogue_history, current_timestr))
+    if model_name in volcengine_models:
+        llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN, base_url=chat_service.VOLCENGINE_BASE_URL)
+        response = llm_client.chat.completions.create(
+            messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
+        return wrap_response(200, data=response.choices[0].message.content)
+    else:
+        return wrap_response(400, msg='model not supported')
+
+@app.errorhandler(werkzeug.exceptions.BadRequest)
+def handle_bad_request(e):
+    return wrap_response(400, msg='Bad Request: {}'.format(e.description))
+
+
+if __name__ == '__main__':
+    parser = ArgumentParser()
+    parser.add_argument('--prod', action='store_true')
+    parser.add_argument('--host', default='127.0.0.1')
+    parser.add_argument('--port', type=int, default=8083)
+    args = parser.parse_args()
+
+    config = configs.get()
+    logging_service.setup_root_logger(logfile_name='agent_api_server.log')
+
+    user_db_config = config['storage']['user']
+    staff_db_config = config['storage']['staff']
+    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+    app.user_manager = user_manager
+
+    wecom_db_config = config['storage']['user_relation']
+    user_relation_manager = MySQLUserRelationManager(
+        user_db_config['mysql'], wecom_db_config['mysql'],
+        config['storage']['staff']['table'],
+        user_db_config['table'],
+        wecom_db_config['table']['staff'],
+        wecom_db_config['table']['relation'],
+        wecom_db_config['table']['user']
+    )
+    app.user_relation_manager = user_relation_manager
+    app.history_dialogue_service = HistoryDialogueService(
+        config['storage']['history_dialogue']['api_base_url']
+    )
+
+    app.run(debug=not args.prod, host=args.host, port=args.port)