# api_server.py — Flask API server exposing admin/debug endpoints for the PQAI agent.
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import logging
  5. import werkzeug.exceptions
  6. from flask import Flask, request, jsonify
  7. from datetime import datetime
  8. from argparse import ArgumentParser
  9. from openai import OpenAI
  10. from pqai_agent.message import MessageType
  11. from pqai_agent import configs
  12. import json
  13. from pqai_agent import logging_service, chat_service, prompt_templates
  14. from pqai_agent.dialogue_manager import DialogueManager
  15. from pqai_agent.history_dialogue_service import HistoryDialogueService
  16. from pqai_agent.response_type_detector import ResponseTypeDetector
  17. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  18. from pqai_agent.user_profile_extractor import UserProfileExtractor
  19. app = Flask('agent_api_server')
  20. logger = logging_service.logger
  21. def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
  22. messages = []
  23. for entry in dialogue_history:
  24. role = entry['role']
  25. msg_type = entry.get('type', MessageType.TEXT)
  26. fmt_time = DialogueManager.format_timestamp(entry['timestamp'])
  27. if msg_type in (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF):
  28. if multimodal:
  29. messages.append({
  30. "role": role,
  31. "content": [
  32. {"type": "image_url", "image_url": {"url": entry["content"]}}
  33. ]
  34. })
  35. else:
  36. logger.warning("Image in non-multimodal mode")
  37. messages.append({
  38. "role": role,
  39. "content": "[图片]"
  40. })
  41. else:
  42. messages.append({
  43. "role": role,
  44. "content": f'{entry["content"]}'
  45. })
  46. return messages
  47. def wrap_response(code, msg=None, data=None):
  48. resp = {
  49. 'code': code,
  50. 'msg': msg
  51. }
  52. if code == 200 and not msg:
  53. resp['msg'] = 'success'
  54. if data:
  55. resp['data'] = data
  56. return jsonify(resp)
  57. @app.route('/api/listStaffs', methods=['GET'])
  58. def list_staffs():
  59. staff_data = app.user_relation_manager.list_staffs()
  60. return wrap_response(200, data=staff_data)
  61. @app.route('/api/getStaffProfile', methods=['GET'])
  62. def get_staff_profile():
  63. staff_id = request.args['staff_id']
  64. profile = app.user_manager.get_staff_profile(staff_id)
  65. if not profile:
  66. return wrap_response(404, msg='staff not found')
  67. else:
  68. return wrap_response(200, data=profile)
  69. @app.route('/api/getUserProfile', methods=['GET'])
  70. def get_user_profile():
  71. user_id = request.args['user_id']
  72. profile = app.user_manager.get_user_profile(user_id)
  73. if not profile:
  74. resp = {
  75. 'code': 404,
  76. 'msg': 'user not found'
  77. }
  78. else:
  79. resp = {
  80. 'code': 200,
  81. 'msg': 'success',
  82. 'data': profile
  83. }
  84. return jsonify(resp)
  85. @app.route('/api/listUsers', methods=['GET'])
  86. def list_users():
  87. user_name = request.args.get('user_name', None)
  88. user_union_id = request.args.get('user_union_id', None)
  89. if not user_name and not user_union_id:
  90. resp = {
  91. 'code': 400,
  92. 'msg': 'user_name or user_union_id is required'
  93. }
  94. return jsonify(resp)
  95. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  96. return jsonify({'code': 200, 'data': data})
  97. @app.route('/api/getDialogueHistory', methods=['GET'])
  98. def get_dialogue_history():
  99. staff_id = request.args['staff_id']
  100. user_id = request.args['user_id']
  101. recent_minutes = int(request.args.get('recent_minutes', 1440))
  102. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  103. return jsonify({'code': 200, 'data': dialogue_history})
  104. @app.route('/api/listModels', methods=['GET'])
  105. def list_models():
  106. models = [
  107. {
  108. 'model_type': 'openai_compatible',
  109. 'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  110. 'display_name': 'DeepSeek V3 on 火山'
  111. },
  112. {
  113. 'model_type': 'openai_compatible',
  114. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  115. 'display_name': '豆包Pro 32K'
  116. },
  117. {
  118. 'model_type': 'openai_compatible',
  119. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  120. 'display_name': '豆包Pro 1.5'
  121. },
  122. {
  123. 'model_type': 'openai_compatible',
  124. 'model_name': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
  125. 'display_name': 'DeepSeek V3联网 on 火山'
  126. },
  127. {
  128. 'model_type': 'openai_compatible',
  129. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  130. 'display_name': '豆包1.5视觉理解Pro'
  131. },
  132. ]
  133. return wrap_response(200, data=models)
  134. @app.route('/api/listScenes', methods=['GET'])
  135. def list_scenes():
  136. scenes = [
  137. {'scene': 'greeting', 'display_name': '问候'},
  138. {'scene': 'chitchat', 'display_name': '闲聊'},
  139. {'scene': 'profile_extractor', 'display_name': '画像提取'},
  140. {'scene': 'response_type_detector', 'display_name': '回复模态判断'},
  141. {'scene': 'custom_debugging', 'display_name': '自定义调试场景'}
  142. ]
  143. return wrap_response(200, data=scenes)
  144. @app.route('/api/getBasePrompt', methods=['GET'])
  145. def get_base_prompt():
  146. scene = request.args['scene']
  147. prompt_map = {
  148. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  149. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  150. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
  151. 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
  152. 'custom_debugging': '',
  153. }
  154. model_map = {
  155. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  156. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  157. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  158. 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  159. 'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
  160. }
  161. if scene not in prompt_map:
  162. return wrap_response(404, msg='scene not found')
  163. data = {
  164. 'model_name': model_map[scene],
  165. 'content': prompt_map[scene]
  166. }
  167. return wrap_response(200, data=data)
  168. def run_openai_chat(messages, model_name, **kwargs):
  169. volcengine_models = [
  170. chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  171. chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  172. chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  173. chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
  174. ]
  175. deepseek_models = [
  176. chat_service.DEEPSEEK_CHAT_MODEL,
  177. ]
  178. volcengine_bots = [
  179. chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
  180. ]
  181. if model_name in volcengine_models:
  182. llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN, base_url=chat_service.VOLCENGINE_BASE_URL)
  183. elif model_name in volcengine_bots:
  184. llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN, base_url=chat_service.VOLCENGINE_BOT_BASE_URL)
  185. elif model_name in deepseek_models:
  186. llm_client = OpenAI(api_key=chat_service.DEEPSEEK_API_TOKEN, base_url=chat_service.DEEPSEEK_BASE_URL)
  187. else:
  188. raise Exception('model not supported')
  189. response = llm_client.chat.completions.create(
  190. messages=messages, model=model_name, **kwargs)
  191. logger.debug(response)
  192. return response
  193. def run_extractor_prompt(req_data):
  194. prompt = req_data['prompt']
  195. user_profile = req_data['user_profile']
  196. staff_profile = req_data['staff_profile']
  197. dialogue_history = req_data['dialogue_history']
  198. model_name = req_data['model_name']
  199. prompt_context = {**staff_profile,
  200. **user_profile,
  201. 'dialogue_history': UserProfileExtractor.compose_dialogue(dialogue_history)}
  202. prompt = prompt.format(**prompt_context)
  203. messages = [
  204. {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
  205. {"role": "user", "content": prompt}
  206. ]
  207. tools = [UserProfileExtractor.get_extraction_function()]
  208. response = run_openai_chat(messages, model_name, tools=tools, temperature=0)
  209. tool_calls = response.choices[0].message.tool_calls
  210. if tool_calls:
  211. function_call = tool_calls[0]
  212. if function_call.function.name == 'update_user_profile':
  213. profile_info = json.loads(function_call.function.arguments)
  214. return {k: v for k, v in profile_info.items() if v}
  215. else:
  216. logger.error("llm does not return update_user_profile")
  217. return {}
  218. else:
  219. return {}
  220. def run_chat_prompt(req_data):
  221. prompt = req_data['prompt']
  222. staff_profile = req_data.get('staff_profile', {})
  223. user_profile = req_data.get('user_profile', {})
  224. dialogue_history = req_data.get('dialogue_history', [])
  225. model_name = req_data['model_name']
  226. current_timestamp = req_data['current_timestamp'] / 1000
  227. prompt_context = {**staff_profile, **user_profile}
  228. current_hour = datetime.fromtimestamp(current_timestamp).hour
  229. prompt_context['last_interaction_interval'] = 0
  230. prompt_context['current_time_period'] = DialogueManager.get_time_context(current_hour)
  231. prompt_context['current_hour'] = current_hour
  232. prompt_context['if_first_interaction'] = False if dialogue_history else True
  233. last_message = dialogue_history[-1] if dialogue_history else {'role': 'assistant'}
  234. prompt_context['if_active_greeting'] = False if last_message['role'] == 'user' else True
  235. current_time_str = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
  236. system_prompt = {
  237. 'role': 'system',
  238. 'content': prompt.format(**prompt_context)
  239. }
  240. messages = [system_prompt]
  241. if req_data['scene'] == 'custom_debugging':
  242. messages.extend(compose_openai_chat_messages_no_time(dialogue_history))
  243. if '头像' in system_prompt['content']:
  244. messages.append({
  245. "role": 'user',
  246. "content": [
  247. {"type": "image_url", "image_url": {"url": user_profile['avatar']}}
  248. ]
  249. })
  250. else:
  251. messages.extend(DialogueManager.compose_chat_messages_openai_compatible(dialogue_history, current_time_str))
  252. return run_openai_chat(messages, model_name, temperature=1, top_p=0.7, max_tokens=1024)
  253. def run_response_type_prompt(req_data):
  254. prompt = req_data['prompt']
  255. dialogue_history = req_data['dialogue_history']
  256. model_name = req_data['model_name']
  257. composed_dialogue = ResponseTypeDetector.compose_dialogue(dialogue_history[:-1])
  258. next_message = DialogueManager.format_dialogue_content(dialogue_history[-1])
  259. prompt = prompt.format(
  260. dialogue_history=composed_dialogue,
  261. message=next_message
  262. )
  263. messages = [
  264. {'role': 'system', 'content': '你是一个专业的智能助手'},
  265. {'role': 'user', 'content': prompt}
  266. ]
  267. return run_openai_chat(messages, model_name,temperature=0.2, max_tokens=128)
  268. @app.route('/api/runPrompt', methods=['POST'])
  269. def run_prompt():
  270. try:
  271. req_data = request.json
  272. logger.debug(req_data)
  273. scene = req_data['scene']
  274. if scene == 'profile_extractor':
  275. response = run_extractor_prompt(req_data)
  276. return wrap_response(200, data=response)
  277. elif scene == 'response_type_detector':
  278. response = run_response_type_prompt(req_data)
  279. return wrap_response(200, data=response.choices[0].message.content)
  280. else:
  281. response = run_chat_prompt(req_data)
  282. return wrap_response(200, data=response.choices[0].message.content)
  283. except Exception as e:
  284. logger.error(e)
  285. return wrap_response(500, msg='Error: {}'.format(e))
  286. @app.errorhandler(werkzeug.exceptions.BadRequest)
  287. def handle_bad_request(e):
  288. logger.error(e)
  289. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
  290. if __name__ == '__main__':
  291. parser = ArgumentParser()
  292. parser.add_argument('--prod', action='store_true')
  293. parser.add_argument('--host', default='127.0.0.1')
  294. parser.add_argument('--port', type=int, default=8083)
  295. parser.add_argument('--log-level', default='INFO')
  296. args = parser.parse_args()
  297. config = configs.get()
  298. logging_level = logging.getLevelName(args.log_level)
  299. logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
  300. user_db_config = config['storage']['user']
  301. staff_db_config = config['storage']['staff']
  302. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  303. app.user_manager = user_manager
  304. wecom_db_config = config['storage']['user_relation']
  305. user_relation_manager = MySQLUserRelationManager(
  306. user_db_config['mysql'], wecom_db_config['mysql'],
  307. config['storage']['staff']['table'],
  308. user_db_config['table'],
  309. wecom_db_config['table']['staff'],
  310. wecom_db_config['table']['relation'],
  311. wecom_db_config['table']['user']
  312. )
  313. app.user_relation_manager = user_relation_manager
  314. app.history_dialogue_service = HistoryDialogueService(
  315. config['storage']['history_dialogue']['api_base_url']
  316. )
  317. app.run(debug=not args.prod, host=args.host, port=args.port)