api_server.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import logging
  5. from calendar import prmonth
  6. import werkzeug.exceptions
  7. from flask import Flask, request, jsonify
  8. from datetime import datetime, timedelta
  9. from argparse import ArgumentParser
  10. from openai import OpenAI
  11. import chat_service
  12. import configs
  13. import json
  14. import logging_service
  15. import prompt_templates
  16. from dialogue_manager import DialogueManager
  17. from history_dialogue_service import HistoryDialogueService
  18. from response_type_detector import ResponseTypeDetector
  19. from user_manager import MySQLUserManager, MySQLUserRelationManager
  20. from user_profile_extractor import UserProfileExtractor
# Flask application object; manager/service instances are attached to it
# at startup (see the __main__ block) so route handlers can reach them.
app = Flask('agent_api_server')
# Shared logger configured by logging_service.
logger = logging_service.logger
  23. def wrap_response(code, msg=None, data=None):
  24. resp = {
  25. 'code': code,
  26. 'msg': msg
  27. }
  28. if code == 200 and not msg:
  29. resp['msg'] = 'success'
  30. if data:
  31. resp['data'] = data
  32. return jsonify(resp)
  33. @app.route('/api/listStaffs', methods=['GET'])
  34. def list_staffs():
  35. staff_data = app.user_relation_manager.list_staffs()
  36. return wrap_response(200, data=staff_data)
  37. @app.route('/api/getStaffProfile', methods=['GET'])
  38. def get_staff_profile():
  39. staff_id = request.args['staff_id']
  40. profile = app.user_manager.get_staff_profile(staff_id)
  41. if not profile:
  42. return wrap_response(404, msg='staff not found')
  43. else:
  44. return wrap_response(200, data=profile)
  45. @app.route('/api/getUserProfile', methods=['GET'])
  46. def get_user_profile():
  47. user_id = request.args['user_id']
  48. profile = app.user_manager.get_user_profile(user_id)
  49. if not profile:
  50. resp = {
  51. 'code': 404,
  52. 'msg': 'user not found'
  53. }
  54. else:
  55. resp = {
  56. 'code': 200,
  57. 'msg': 'success',
  58. 'data': profile
  59. }
  60. return jsonify(resp)
  61. @app.route('/api/listUsers', methods=['GET'])
  62. def list_users():
  63. user_name = request.args.get('user_name', None)
  64. user_union_id = request.args.get('user_union_id', None)
  65. if not user_name and not user_union_id:
  66. resp = {
  67. 'code': 400,
  68. 'msg': 'user_name or user_union_id is required'
  69. }
  70. return jsonify(resp)
  71. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  72. return jsonify({'code': 200, 'data': data})
  73. @app.route('/api/getDialogueHistory', methods=['GET'])
  74. def get_dialogue_history():
  75. staff_id = request.args['staff_id']
  76. user_id = request.args['user_id']
  77. recent_minutes = int(request.args.get('recent_minutes', 1440))
  78. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  79. return jsonify({'code': 200, 'data': dialogue_history})
  80. @app.route('/api/listModels', methods=['GET'])
  81. def list_models():
  82. models = [
  83. {
  84. 'model_type': 'openai_compatible',
  85. 'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  86. 'display_name': 'DeepSeek V3 on 火山'
  87. },
  88. {
  89. 'model_type': 'openai_compatible',
  90. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  91. 'display_name': '豆包Pro 32K'
  92. },
  93. {
  94. 'model_type': 'openai_compatible',
  95. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  96. 'display_name': '豆包Pro 1.5'
  97. },
  98. ]
  99. return wrap_response(200, data=models)
  100. @app.route('/api/listScenes', methods=['GET'])
  101. def list_scenes():
  102. scenes = [
  103. {'scene': 'greeting', 'display_name': '问候'},
  104. {'scene': 'chitchat', 'display_name': '闲聊'},
  105. {'scene': 'profile_extractor', 'display_name': '画像提取'},
  106. {'scene': 'response_type_detector', 'display_name': '回复模态判断'}
  107. ]
  108. return wrap_response(200, data=scenes)
  109. @app.route('/api/getBasePrompt', methods=['GET'])
  110. def get_base_prompt():
  111. scene = request.args['scene']
  112. prompt_map = {
  113. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  114. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  115. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
  116. 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT
  117. }
  118. model_map = {
  119. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  120. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  121. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  122. 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
  123. }
  124. if scene not in prompt_map:
  125. return wrap_response(404, msg='scene not found')
  126. data = {
  127. 'model_name': model_map[scene],
  128. 'content': prompt_map[scene]
  129. }
  130. return wrap_response(200, data=data)
  131. def run_openai_chat(messages, model_name, **kwargs):
  132. volcengine_models = [
  133. chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  134. chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  135. chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
  136. ]
  137. deepseek_models = [
  138. chat_service.DEEPSEEK_CHAT_MODEL,
  139. ]
  140. if model_name in volcengine_models:
  141. llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN, base_url=chat_service.VOLCENGINE_BASE_URL)
  142. response = llm_client.chat.completions.create(
  143. messages=messages, model=model_name, **kwargs)
  144. return response
  145. elif model_name in deepseek_models:
  146. llm_client = OpenAI(api_key=chat_service.DEEPSEEK_API_TOKEN, base_url=chat_service.DEEPSEEK_BASE_URL)
  147. response = llm_client.chat.completions.create(
  148. messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
  149. return response
  150. else:
  151. raise Exception('model not supported')
  152. def run_extractor_prompt(req_data):
  153. prompt = req_data['prompt']
  154. user_profile = req_data['user_profile']
  155. staff_profile = req_data['staff_profile']
  156. dialogue_history = req_data['dialogue_history']
  157. model_name = req_data['model_name']
  158. prompt_context = {**staff_profile,
  159. **user_profile,
  160. 'dialogue_history': UserProfileExtractor.compose_dialogue(dialogue_history)}
  161. prompt = prompt.format(**prompt_context)
  162. messages = [
  163. {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
  164. {"role": "user", "content": prompt}
  165. ]
  166. tools = [UserProfileExtractor.get_extraction_function()]
  167. response = run_openai_chat(messages, model_name, tools=tools, temperature=0)
  168. tool_calls = response.choices[0].message.tool_calls
  169. if tool_calls:
  170. function_call = tool_calls[0]
  171. if function_call.function.name == 'update_user_profile':
  172. profile_info = json.loads(function_call.function.arguments)
  173. return {k: v for k, v in profile_info.items() if v}
  174. else:
  175. logger.error("llm does not return update_user_profile")
  176. return {}
  177. else:
  178. return {}
  179. def run_chat_prompt(req_data):
  180. prompt = req_data['prompt']
  181. staff_profile = req_data['staff_profile']
  182. user_profile = req_data['user_profile']
  183. dialogue_history = req_data['dialogue_history']
  184. model_name = req_data['model_name']
  185. current_timestamp = req_data['current_timestamp'] / 1000
  186. prompt_context = {**staff_profile, **user_profile}
  187. current_hour = datetime.fromtimestamp(current_timestamp).hour
  188. prompt_context['last_interaction_interval'] = 0
  189. prompt_context['current_time_period'] = DialogueManager.get_time_context(current_hour)
  190. prompt_context['current_hour'] = current_hour
  191. prompt_context['if_first_interaction'] = False if dialogue_history else True
  192. last_message = dialogue_history[-1] if dialogue_history else {'role': 'assistant'}
  193. prompt_context['if_active_greeting'] = False if last_message['role'] == 'user' else True
  194. current_time_str = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
  195. system_prompt = {
  196. 'role': 'system',
  197. 'content': prompt.format(**prompt_context)
  198. }
  199. messages = [system_prompt]
  200. messages.extend(DialogueManager.compose_chat_messages_openai_compatible(dialogue_history, current_time_str))
  201. return run_openai_chat(messages, model_name, temperature=1, top_p=0.7, max_tokens=1024)
  202. def run_response_type_prompt(req_data):
  203. prompt = req_data['prompt']
  204. dialogue_history = req_data['dialogue_history']
  205. model_name = req_data['model_name']
  206. composed_dialogue = ResponseTypeDetector.compose_dialogue(dialogue_history[:-1])
  207. next_message = DialogueManager.format_dialogue_content(dialogue_history[-1])
  208. prompt = prompt.format(
  209. dialogue_history=composed_dialogue,
  210. message=next_message
  211. )
  212. messages = [
  213. {'role': 'system', 'content': '你是一个专业的智能助手'},
  214. {'role': 'user', 'content': prompt}
  215. ]
  216. return run_openai_chat(messages, model_name,temperature=0.2, max_tokens=128)
  217. @app.route('/api/runPrompt', methods=['POST'])
  218. def run_prompt():
  219. try:
  220. req_data = request.json
  221. logger.debug(req_data)
  222. scene = req_data['scene']
  223. if scene == 'profile_extractor':
  224. response = run_extractor_prompt(req_data)
  225. return wrap_response(200, data=response)
  226. elif scene == 'response_type_detector':
  227. response = run_response_type_prompt(req_data)
  228. return wrap_response(200, data=response.choices[0].message.content)
  229. else:
  230. response = run_chat_prompt(req_data)
  231. return wrap_response(200, data=response.choices[0].message.content)
  232. except Exception as e:
  233. logger.error(e)
  234. return wrap_response(500, msg='Error: {}'.format(e))
  235. @app.errorhandler(werkzeug.exceptions.BadRequest)
  236. def handle_bad_request(e):
  237. logger.error(e)
  238. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
# Entry point: parse CLI flags, configure logging, wire up the storage
# managers/services onto the Flask app, then start the server.
if __name__ == '__main__':
    parser = ArgumentParser()
    # --prod disables Flask debug mode (see app.run below).
    parser.add_argument('--prod', action='store_true')
    parser.add_argument('--host', default='127.0.0.1')
    parser.add_argument('--port', type=int, default=8083)
    parser.add_argument('--log-level', default='INFO')
    args = parser.parse_args()
    config = configs.get()
    # Map the textual level name (e.g. 'INFO') to its numeric logging level.
    logging_level = logging.getLevelName(args.log_level)
    logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
    user_db_config = config['storage']['user']
    staff_db_config = config['storage']['staff']
    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
    # Handlers access these via the app object (e.g. app.user_manager).
    app.user_manager = user_manager
    wecom_db_config = config['storage']['user_relation']
    user_relation_manager = MySQLUserRelationManager(
        user_db_config['mysql'], wecom_db_config['mysql'],
        config['storage']['staff']['table'],
        user_db_config['table'],
        wecom_db_config['table']['staff'],
        wecom_db_config['table']['relation'],
        wecom_db_config['table']['user']
    )
    app.user_relation_manager = user_relation_manager
    app.history_dialogue_service = HistoryDialogueService(
        config['storage']['history_dialogue']['api_base_url']
    )
    # Debug mode only outside production.
    app.run(debug=not args.prod, host=args.host, port=args.port)