ソースを参照

Fix api_server: prompt context

StrayWarrior 1 週間 前
コミット
c6537c97a6
1 ファイル変更10 行追加1 行削除
  1. 10 1
      api_server.py

+ 10 - 1
api_server.py

@@ -1,6 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
+import logging
 from calendar import prmonth
 
 import werkzeug.exceptions
@@ -21,6 +22,7 @@ from user_manager import MySQLUserManager, MySQLUserRelationManager
 from user_profile_extractor import UserProfileExtractor
 
 app = Flask('agent_api_server')
+logger = logging_service.logger
 
 def wrap_response(code, msg=None, data=None):
     resp = {
@@ -196,6 +198,8 @@ def run_chat_prompt(req_data):
     prompt_context['current_time_period'] = DialogueManager.get_time_context(current_hour)
     prompt_context['current_hour'] = current_hour
     prompt_context['if_first_interaction'] = False if dialogue_history else True
+    last_message = dialogue_history[-1] if dialogue_history else {'role': 'assistant'}
+    prompt_context['if_active_greeting'] = False if last_message['role'] == 'user' else True
 
     current_time_str = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
     system_prompt = {
@@ -210,6 +214,7 @@ def run_chat_prompt(req_data):
 def run_prompt():
     try:
         req_data = request.json
+        logger.debug(req_data)
         scene = req_data['scene']
         if scene == 'profile_extractor':
             response = run_extractor_prompt(req_data)
@@ -218,10 +223,12 @@ def run_prompt():
             response = run_chat_prompt(req_data)
             return wrap_response(200, data=response.choices[0].message.content)
     except Exception as e:
+        logger.error(e)
         return wrap_response(500, msg='Error: {}'.format(e))
 
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
+    logger.error(e)
     return wrap_response(400, msg='Bad Request: {}'.format(e.description))
 
 
@@ -230,10 +237,12 @@ if __name__ == '__main__':
     parser.add_argument('--prod', action='store_true')
     parser.add_argument('--host', default='127.0.0.1')
     parser.add_argument('--port', type=int, default=8083)
+    parser.add_argument('--log-level', default='INFO')
     args = parser.parse_args()
 
     config = configs.get()
-    logging_service.setup_root_logger(logfile_name='agent_api_server.log')
+    logging_level = logging.getLevelName(args.log_level)
+    logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
 
     user_db_config = config['storage']['user']
     staff_db_config = config['storage']['staff']