Quellcode durchsuchen

Update api_server: add formatForPrompt

StrayWarrior vor 3 Wochen
Ursprung
Commit
355c8fd49a
1 geänderte Dateien mit 27 neuen und 0 gelöschten Zeilen
  1. 27 0
      pqai_agent_server/api_server.py

+ 27 - 0
pqai_agent_server/api_server.py

@@ -10,8 +10,10 @@ from argparse import ArgumentParser
 from pqai_agent import configs
 
 from pqai_agent import logging_service, chat_service, prompt_templates
+from pqai_agent.agents.message_reply_agent import MessageReplyAgent
 from pqai_agent.history_dialogue_service import HistoryDialogueService
 from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
+from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
 from pqai_agent_server.const import AgentApiConst
 from pqai_agent_server.models import MySQLSessionManager
 from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
@@ -171,6 +173,31 @@ def run_prompt():
         logger.error(e)
         return wrap_response(500, msg='Error: {}'.format(e))
 
+@app.route('/api/formatForPrompt', methods=['POST'])
+def format_data_for_prompt():
+    try:
+        req_data = request.json
+        content = req_data['content']
+        format_type = req_data['format_type']
+        if format_type == 'staff_profile':
+            if not isinstance(content, dict):
+                return wrap_response(400, msg='staff_profile should be a dict')
+            response = format_agent_profile(content)
+        elif format_type == 'user_profile':
+            if not isinstance(content, dict):
+                return wrap_response(400, msg='user_profile should be a dict')
+            response = format_user_profile(content)
+        elif format_type == 'dialogue':
+            if not isinstance(content, list):
+                return wrap_response(400, msg='dialogue should be a list')
+            response = MessageReplyAgent.compose_dialogue(content)
+        else:
+            return wrap_response(400, msg='Invalid format_type')
+        return wrap_response(200, data=response)
+    except Exception as e:
+        logger.error(e)
+        return wrap_response(500, msg='Error: {}'.format(e))
+
 
 @app.route("/api/healthCheck", methods=["GET"])
 def health_check():