#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# api_server.py
import logging
from argparse import ArgumentParser
from calendar import prmonth
from datetime import datetime, timedelta

import werkzeug.exceptions
from flask import Flask, request, jsonify
from openai import OpenAI

import chat_service
import configs
import logging_service
import prompt_templates
from dialogue_manager import DialogueManager
from history_dialogue_service import HistoryDialogueService
from user_manager import MySQLUserManager, MySQLUserRelationManager
  17. app = Flask('agent_api_server')
  18. def wrap_response(code, msg=None, data=None):
  19. resp = {
  20. 'code': code,
  21. 'msg': msg
  22. }
  23. if code == 200 and not msg:
  24. resp['msg'] = 'success'
  25. if data:
  26. resp['data'] = data
  27. return jsonify(resp)
  28. @app.route('/api/listStaffs', methods=['GET'])
  29. def list_staffs():
  30. staff_data = app.user_relation_manager.list_staffs()
  31. return wrap_response(200, data=staff_data)
  32. @app.route('/api/getStaffProfile', methods=['GET'])
  33. def get_staff_profile():
  34. staff_id = request.args['staff_id']
  35. profile = app.user_manager.get_staff_profile(staff_id)
  36. if not profile:
  37. return wrap_response(404, msg='staff not found')
  38. else:
  39. return wrap_response(200, data=profile)
  40. @app.route('/api/getUserProfile', methods=['GET'])
  41. def get_user_profile():
  42. user_id = request.args['user_id']
  43. profile = app.user_manager.get_user_profile(user_id)
  44. if not profile:
  45. resp = {
  46. 'code': 404,
  47. 'msg': 'user not found'
  48. }
  49. else:
  50. resp = {
  51. 'code': 200,
  52. 'msg': 'success',
  53. 'data': profile
  54. }
  55. return jsonify(resp)
  56. @app.route('/api/listUsers', methods=['GET'])
  57. def list_users():
  58. user_name = request.args.get('user_name', None)
  59. user_union_id = request.args.get('user_union_id', None)
  60. if not user_name and not user_union_id:
  61. resp = {
  62. 'code': 400,
  63. 'msg': 'user_name or user_union_id is required'
  64. }
  65. return jsonify(resp)
  66. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  67. return jsonify({'code': 200, 'data': data})
  68. @app.route('/api/getDialogueHistory', methods=['GET'])
  69. def get_dialogue_history():
  70. staff_id = request.args['staff_id']
  71. user_id = request.args['user_id']
  72. recent_minutes = int(request.args.get('recent_minutes', 1440))
  73. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  74. return jsonify({'code': 200, 'data': dialogue_history})
  75. @app.route('/api/listModels', methods=['GET'])
  76. def list_models():
  77. models = [
  78. {
  79. 'model_type': 'openai_compatible',
  80. 'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  81. 'display_name': 'DeepSeek V3 on 火山'
  82. },
  83. {
  84. 'model_type': 'openai_compatible',
  85. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  86. 'display_name': '豆包Pro 32K'
  87. },
  88. {
  89. 'model_type': 'openai_compatible',
  90. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  91. 'display_name': '豆包Pro 1.5'
  92. },
  93. ]
  94. return wrap_response(200, data=models)
  95. @app.route('/api/listScenes', methods=['GET'])
  96. def list_scenes():
  97. scenes = [
  98. {'scene': 'greeting', 'display_name': '问候'},
  99. {'scene': 'chitchat', 'display_name': '闲聊'},
  100. {'scene': 'profile_extractor', 'display_name': '画像提取'}
  101. ]
  102. return wrap_response(200, data=scenes)
  103. @app.route('/api/getBasePrompt', methods=['GET'])
  104. def get_base_prompt():
  105. scene = request.args['scene']
  106. prompt_map = {
  107. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  108. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  109. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT
  110. }
  111. model_map = {
  112. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  113. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  114. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
  115. }
  116. if scene not in prompt_map:
  117. return wrap_response(404, msg='scene not found')
  118. data = {
  119. 'model_name': model_map[scene],
  120. 'content': prompt_map[scene]
  121. }
  122. return wrap_response(200, data=data)
  123. def get_llm_response(model_name, messages):
  124. pass
  125. def run_chat_prompt():
  126. pass
  127. def run_extractor_prompt():
  128. pass
  129. @app.route('/api/runPrompt', methods=['POST'])
  130. def run_prompt():
  131. try:
  132. req_data = request.json
  133. scene = req_data['scene']
  134. prompt = req_data['prompt']
  135. staff_profile = req_data['staff_profile']
  136. user_profile = req_data['user_profile']
  137. dialogue_history = req_data['dialogue_history']
  138. model_name = req_data['model_name']
  139. current_timestamp = req_data['current_timestamp'] / 1000
  140. prompt_context = {**staff_profile, **user_profile}
  141. current_hour = datetime.fromtimestamp(current_timestamp).hour
  142. prompt_context['last_interaction_interval'] = 0
  143. prompt_context['current_time_period'] = DialogueManager.get_time_context(current_hour)
  144. prompt_context['current_hour'] = current_hour
  145. prompt_context['if_first_interaction'] = False if dialogue_history else True
  146. volcengine_models = [
  147. chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  148. chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  149. chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
  150. ]
  151. deepseek_models = [
  152. chat_service.DEEPSEEK_CHAT_MODEL,
  153. ]
  154. current_time_str = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
  155. system_prompt = {
  156. 'role': 'system',
  157. 'content': prompt.format(**prompt_context)
  158. }
  159. messages = []
  160. messages.append(system_prompt)
  161. messages.extend(DialogueManager.compose_chat_messages_openai_compatible(dialogue_history, current_time_str))
  162. if model_name in volcengine_models:
  163. llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN, base_url=chat_service.VOLCENGINE_BASE_URL)
  164. response = llm_client.chat.completions.create(
  165. messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
  166. return wrap_response(200, data=response.choices[0].message.content)
  167. elif model_name in deepseek_models:
  168. llm_client = OpenAI(api_key=chat_service.DEEPSEEK_API_TOKEN, base_url=chat_service.DEEPSEEK_BASE_URL)
  169. response = llm_client.chat.completions.create(
  170. messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
  171. return wrap_response(200, data=response.choices[0].message.content)
  172. else:
  173. return wrap_response(400, msg='model not supported')
  174. except Exception as e:
  175. return wrap_response(500, msg='Error: {}'.format(e))
  176. @app.errorhandler(werkzeug.exceptions.BadRequest)
  177. def handle_bad_request(e):
  178. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
  179. if __name__ == '__main__':
  180. parser = ArgumentParser()
  181. parser.add_argument('--prod', action='store_true')
  182. parser.add_argument('--host', default='127.0.0.1')
  183. parser.add_argument('--port', type=int, default=8083)
  184. args = parser.parse_args()
  185. config = configs.get()
  186. logging_service.setup_root_logger(logfile_name='agent_api_server.log')
  187. user_db_config = config['storage']['user']
  188. staff_db_config = config['storage']['staff']
  189. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  190. app.user_manager = user_manager
  191. wecom_db_config = config['storage']['user_relation']
  192. user_relation_manager = MySQLUserRelationManager(
  193. user_db_config['mysql'], wecom_db_config['mysql'],
  194. config['storage']['staff']['table'],
  195. user_db_config['table'],
  196. wecom_db_config['table']['staff'],
  197. wecom_db_config['table']['relation'],
  198. wecom_db_config['table']['user']
  199. )
  200. app.user_relation_manager = user_relation_manager
  201. app.history_dialogue_service = HistoryDialogueService(
  202. config['storage']['history_dialogue']['api_base_url']
  203. )
  204. app.run(debug=not args.prod, host=args.host, port=args.port)