#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
import time
import logging
import werkzeug.exceptions
from flask import Flask, request, jsonify
from argparse import ArgumentParser

from pqai_agent import configs

from pqai_agent import logging_service, chat_service, prompt_templates
from pqai_agent.agents.message_reply_agent import MessageReplyAgent
from pqai_agent.history_dialogue_service import HistoryDialogueService
from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
from pqai_agent_server.const import AgentApiConst
from pqai_agent_server.models import MySQLSessionManager
from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
from pqai_agent_server.utils import (
    run_extractor_prompt,
    run_chat_prompt,
    run_response_type_prompt,
)

# Module-level objects shared by every route handler in this file.
logger = logging_service.logger  # project logger provided by logging_service
const = AgentApiConst()  # API defaults (page sizes, default status, ...)
app = Flask(import_name='agent_api_server')

@app.route('/api/listStaffs', methods=['GET'])
def list_staffs():
    """Return the staff list reported by the user-relation manager."""
    return wrap_response(200, data=app.user_relation_manager.list_staffs())


@app.route('/api/getStaffProfile', methods=['GET'])
def get_staff_profile():
    """Return the profile of the staff member named by the ``staff_id`` query arg.

    A missing ``staff_id`` raises werkzeug's BadRequestKeyError (a BadRequest
    subclass); an empty/falsy profile yields a 404 response.
    """
    profile = app.user_manager.get_staff_profile(request.args['staff_id'])
    if profile:
        return wrap_response(200, data=profile)
    return wrap_response(404, msg='staff not found')


@app.route('/api/getUserProfile', methods=['GET'])
def get_user_profile():
    """Return the profile of the user named by the ``user_id`` query arg.

    Responds with a JSON body carrying ``code`` 200 plus the profile, or
    ``code`` 404 when the user manager returns nothing.
    """
    profile = app.user_manager.get_user_profile(request.args['user_id'])
    if profile:
        return jsonify({'code': 200, 'msg': 'success', 'data': profile})
    return jsonify({'code': 404, 'msg': 'user not found'})


@app.route('/api/listUsers', methods=['GET'])
def list_users():
    """Search users by ``user_name`` and/or ``user_union_id`` query args.

    At least one filter must be non-empty, otherwise a ``code`` 400 body is
    returned.
    """
    filters = {
        'user_name': request.args.get('user_name', None),
        'user_union_id': request.args.get('user_union_id', None),
    }
    if not any(filters.values()):
        return jsonify({'code': 400, 'msg': 'user_name or user_union_id is required'})
    return jsonify({'code': 200, 'data': app.user_manager.list_users(**filters)})


@app.route('/api/getDialogueHistory', methods=['GET'])
def get_dialogue_history():
    """Return the recent dialogue between a staff member and a user.

    Query args:
        staff_id: required (missing -> werkzeug BadRequestKeyError).
        user_id: required (missing -> werkzeug BadRequestKeyError).
        recent_minutes: optional integer look-back window, default 1440.

    A non-integer ``recent_minutes`` yields a ``code`` 400 body instead of an
    unhandled ValueError (HTTP 500).
    """
    staff_id = request.args['staff_id']
    user_id = request.args['user_id']
    try:
        recent_minutes = int(request.args.get('recent_minutes', 1440))
    except ValueError:
        return jsonify({'code': 400, 'msg': 'recent_minutes must be an integer'})
    dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
    return jsonify({'code': 200, 'data': dialogue_history})


@app.route('/api/listModels', methods=['GET'])
def list_models():
    """Return the selectable chat models, all tagged ``openai_compatible``."""
    catalogue = (
        (chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3, 'DeepSeek V3 on 火山'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K, '豆包Pro 32K'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5, '豆包Pro 1.5'),
        (chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH, 'DeepSeek V3联网 on 火山'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO, '豆包1.5视觉理解Pro'),
    )
    models = [
        {'model_type': 'openai_compatible', 'model_name': name, 'display_name': label}
        for name, label in catalogue
    ]
    return wrap_response(200, data=models)


@app.route('/api/listScenes', methods=['GET'])
def list_scenes():
    """Return the prompt-debugging scenes with their display names."""
    scene_labels = (
        ('greeting', '问候'),
        ('chitchat', '闲聊'),
        ('profile_extractor', '画像提取'),
        ('response_type_detector', '回复模态判断'),
        ('custom_debugging', '自定义调试场景'),
    )
    scenes = [{'scene': key, 'display_name': label} for key, label in scene_labels]
    return wrap_response(200, data=scenes)


@app.route('/api/getBasePrompt', methods=['GET'])
def get_base_prompt():
    """Return the default model and base prompt template for a scene.

    The ``scene`` query arg is required; an unknown scene yields 404.
    """
    scene = request.args['scene']
    # scene -> (default model name, base prompt template)
    defaults = {
        'greeting': (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
                     prompt_templates.GENERAL_GREETING_PROMPT),
        'chitchat': (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
                     prompt_templates.CHITCHAT_PROMPT_COZE),
        'profile_extractor': (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
                              prompt_templates.USER_PROFILE_EXTRACT_PROMPT),
        'response_type_detector': (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
                                   prompt_templates.RESPONSE_TYPE_DETECT_PROMPT),
        'custom_debugging': (chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH, ''),
    }
    entry = defaults.get(scene)
    if entry is None:
        return wrap_response(404, msg='scene not found')
    model_name, content = entry
    return wrap_response(200, data={'model_name': model_name, 'content': content})


@app.route('/api/runPrompt', methods=['POST'])
def run_prompt():
    """Run a debugging prompt for the requested scene and return the result.

    Expects a JSON object body containing at least ``scene``; the whole body
    is forwarded to the matching runner:

    - ``profile_extractor``: ``run_extractor_prompt`` result returned as-is.
    - ``response_type_detector``: ``run_response_type_prompt``; the first
      choice's message content is returned.
    - anything else: ``run_chat_prompt``; the first choice's message content
      is returned.

    Returns 400 when the body is not an object or lacks ``scene``. Malformed
    JSON (werkzeug BadRequest) is re-raised so the module's BadRequest
    handler answers 400. Any other failure is logged with its traceback and
    answered with 500.
    """
    try:
        req_data = request.json
        logger.debug(req_data)
        if not isinstance(req_data, dict) or 'scene' not in req_data:
            return wrap_response(400, msg='scene is required')
        scene = req_data['scene']
        if scene == 'profile_extractor':
            response = run_extractor_prompt(req_data)
            return wrap_response(200, data=response)
        if scene == 'response_type_detector':
            response = run_response_type_prompt(req_data)
        else:
            response = run_chat_prompt(req_data)
        return wrap_response(200, data=response.choices[0].message.content)
    except werkzeug.exceptions.BadRequest:
        # Client error (e.g. unparsable JSON): let the registered handler reply 400.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(e) dropped.
        logger.exception(e)
        return wrap_response(500, msg='Error: {}'.format(e))

@app.route('/api/formatForPrompt', methods=['POST'])
def format_data_for_prompt():
    """Format structured data into prompt text.

    Expects a JSON object body with ``content`` and ``format_type``:

    - ``staff_profile``: ``content`` must be a dict; uses format_agent_profile.
    - ``user_profile``: ``content`` must be a dict; uses format_user_profile.
    - ``dialogue``: ``content`` must be a list; uses compose_dialogue.

    Returns 400 for a missing field, a wrong ``content`` type or an unknown
    ``format_type``. Malformed JSON (werkzeug BadRequest) is re-raised to the
    module's BadRequest handler. Other failures are logged with traceback and
    answered with 500.
    """
    try:
        req_data = request.json
        if (not isinstance(req_data, dict)
                or 'content' not in req_data
                or 'format_type' not in req_data):
            return wrap_response(400, msg='content and format_type are required')
        content = req_data['content']
        format_type = req_data['format_type']
        if format_type == 'staff_profile':
            if not isinstance(content, dict):
                return wrap_response(400, msg='staff_profile should be a dict')
            response = format_agent_profile(content)
        elif format_type == 'user_profile':
            if not isinstance(content, dict):
                return wrap_response(400, msg='user_profile should be a dict')
            response = format_user_profile(content)
        elif format_type == 'dialogue':
            if not isinstance(content, list):
                return wrap_response(400, msg='dialogue should be a list')
            from pqai_agent_server.utils.prompt_util import compose_dialogue
            response = compose_dialogue(content)
        else:
            return wrap_response(400, msg='Invalid format_type')
        return wrap_response(200, data=response)
    except werkzeug.exceptions.BadRequest:
        # Client error (e.g. unparsable JSON): let the registered handler reply 400.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(e) dropped.
        logger.exception(e)
        return wrap_response(500, msg='Error: {}'.format(e))


@app.route("/api/healthCheck", methods=["GET"])
def health_check():
    """Answer with a fixed 200 "OK" so callers can check the server is up."""
    status_code, message = 200, "OK"
    return wrap_response(status_code, msg=message)


@app.route("/api/getStaffSessionSummary", methods=["GET"])
def get_staff_session_summary():
    """Return a paginated session summary for one staff member.

    Query args:
        staff_id: required.
        status: int, default ``const.DEFAULT_STAFF_STATUS``.
        page_id: int, default ``const.DEFAULT_PAGE_ID``.
        page_size: int, default ``const.DEFAULT_PAGE_SIZE``.

    A missing ``staff_id``, a non-integer numeric arg, or an empty summary
    all produce a 404 response.
    """
    staff_id = request.args.get("staff_id")
    if not staff_id:
        return wrap_response(404, msg="staff_id is required")

    # Query-string values arrive as str; convert to int (defaults included).
    try:
        page_id = int(request.args.get("page_id", const.DEFAULT_PAGE_ID))
        page_size = int(request.args.get("page_size", const.DEFAULT_PAGE_SIZE))
        status = int(request.args.get("status", const.DEFAULT_STAFF_STATUS))
    except (TypeError, ValueError) as e:
        return wrap_response(404, msg="Invalid parameter: {}".format(e))

    staff_session_summary = app.session_manager.get_staff_sessions_summary(
        staff_id, page_id, page_size, status
    )

    if not staff_session_summary:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_session_summary)


@app.route("/api/getStaffSessionList", methods=["GET"])
def get_staff_session_list():
    """Return a paginated list of sessions for one staff member.

    Query args:
        staff_id: required.
        page_id: int, default ``const.DEFAULT_PAGE_ID``.
        page_size: int, default ``const.DEFAULT_PAGE_SIZE``.

    A missing ``staff_id``, non-integer paging args, or an empty result all
    produce a 404 response.
    """
    staff_id = request.args.get("staff_id")
    if not staff_id:
        return wrap_response(404, msg="staff_id is required")

    # Query-string values arrive as str while the defaults come from const.
    # Converting both keeps the session manager's argument types consistent.
    try:
        page_size = int(request.args.get("page_size", const.DEFAULT_PAGE_SIZE))
        page_id = int(request.args.get("page_id", const.DEFAULT_PAGE_ID))
    except (TypeError, ValueError) as e:
        return wrap_response(404, msg="Invalid parameter: {}".format(e))

    staff_session_list = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
    if not staff_session_list:
        return wrap_response(404, msg="staff not found")

    return wrap_response(200, data=staff_session_list)


@app.route("/api/getStaffList", methods=["GET"])
def get_staff_list():
    """Return a paginated staff list.

    Query args:
        page_id: int, default ``const.DEFAULT_PAGE_ID``.
        page_size: int, default ``const.DEFAULT_PAGE_SIZE``.

    Non-integer paging args or an empty result produce a 404 response.
    """
    # Query-string values arrive as str while the defaults come from const.
    # Converting both keeps the user manager's argument types consistent.
    try:
        page_size = int(request.args.get("page_size", const.DEFAULT_PAGE_SIZE))
        page_id = int(request.args.get("page_id", const.DEFAULT_PAGE_ID))
    except (TypeError, ValueError) as e:
        return wrap_response(404, msg="Invalid parameter: {}".format(e))

    staff_list = app.user_manager.get_staff_list(page_id, page_size)
    if not staff_list:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_list)


@app.route("/api/getConversationList", methods=["GET"])
def get_conversation_list():
    """
    Return the private-chat conversation list between a staff member and a user.

    Both ``staff_id`` and ``user_id`` query args are required (404 otherwise).
    ``page`` is forwarded unconverted (str or None). Presumably the session
    manager interprets it, possibly as a cursor; confirm.
    :return: wrap_response with the session manager's result as ``data``
    """
    staff_id, user_id = request.args.get("staff_id"), request.args.get("user_id")
    if not (staff_id and user_id):
        return wrap_response(404, msg="staff_id and user_id are required")

    conversations = app.session_manager.get_conversation_list(
        staff_id, user_id, request.args.get("page"), const.DEFAULT_CONVERSATION_SIZE
    )
    return wrap_response(200, data=conversations)


@app.route("/api/sendMessage", methods=["POST"])
def send_message():
    """Placeholder endpoint: not implemented yet, always answers 200."""
    placeholder_msg = "暂不实现功能"
    return wrap_response(200, msg=placeholder_msg)


@app.route("/api/quitHumanInterventionStatus", methods=["POST"])
def quit_human_interventions_status():
    """
    Exit the human-intervention state for a staff/user pair.

    Expects a JSON object body with ``staff_id`` and ``user_id``. If either is
    missing or empty, or the body is not an object, the endpoint returns its
    404 "required" response. It no longer raises KeyError/TypeError (HTTP 500).
    :return: wrap_response with the helper's result as ``data``
    """
    req_data = request.json
    if not isinstance(req_data, dict):
        req_data = {}
    staff_id = req_data.get("staff_id")
    user_id = req_data.get("user_id")
    if not user_id or not staff_id:
        return wrap_response(404, msg="user_id and staff_id are required")
    # NOTE(review): called as (user_id, staff_id), the reverse of the field
    # order above. Presumably this matches the helper's signature; confirm.
    response = quit_human_intervention_status(user_id, staff_id)

    return wrap_response(200, data=response)


@app.errorhandler(werkzeug.exceptions.BadRequest)
def handle_bad_request(e):
    """Turn werkzeug BadRequest errors into a 400 ``wrap_response`` body.

    This covers, for example, missing ``request.args[...]`` keys, which raise
    BadRequestKeyError.
    """
    logger.error(e)
    message = f'Bad Request: {e.description}'
    return wrap_response(400, msg=message)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--prod', action='store_true')
    parser.add_argument('--host', default='127.0.0.1')
    parser.add_argument('--port', type=int, default=8083)
    parser.add_argument('--log-level', default='INFO')
    args = parser.parse_args()

    # logging.getLevelName maps only registered (upper-case) level names to
    # their numeric value. Anything else, e.g. "info", comes back as the
    # string "Level info", so normalise case and reject unknown names here.
    logging_level = logging.getLevelName(args.log_level.upper())
    if not isinstance(logging_level, int):
        parser.error('unknown --log-level: {}'.format(args.log_level))

    config = configs.get()
    logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')

    # set db config
    user_db_config = config['storage']['user']
    staff_db_config = config['storage']['staff']
    agent_state_db_config = config['storage']['agent_state']
    chat_history_db_config = config['storage']['chat_history']

    # init user manager
    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
    app.user_manager = user_manager

    # init session manager
    session_manager = MySQLSessionManager(
        db_config=user_db_config['mysql'],
        staff_table=staff_db_config['table'],
        user_table=user_db_config['table'],
        agent_state_table=agent_state_db_config['table'],
        chat_history_table=chat_history_db_config['table']
    )
    app.session_manager = session_manager

    # init user relation manager (spans the user DB and the user_relation DB)
    wecom_db_config = config['storage']['user_relation']
    user_relation_manager = MySQLUserRelationManager(
        user_db_config['mysql'], wecom_db_config['mysql'],
        staff_db_config['table'],
        user_db_config['table'],
        wecom_db_config['table']['staff'],
        wecom_db_config['table']['relation'],
        wecom_db_config['table']['user']
    )
    app.user_relation_manager = user_relation_manager
    app.history_dialogue_service = HistoryDialogueService(
        config['storage']['history_dialogue']['api_base_url']
    )

    # Without --prod this runs in Flask debug mode, which enables the Werkzeug
    # interactive debugger (it allows executing arbitrary code). Never expose
    # it beyond localhost.
    app.run(debug=not args.prod, host=args.host, port=args.port)