Ver Fonte

Quick fix of api server: format dialogue

StrayWarrior há 3 semanas atrás
pai
commit
6e9afc82d9
2 ficheiros alterados com 18 adições e 1 exclusões
  1. 2 1
      pqai_agent_server/api_server.py
  2. 16 0
      pqai_agent_server/utils/prompt_util.py

+ 2 - 1
pqai_agent_server/api_server.py

@@ -190,7 +190,8 @@ def format_data_for_prompt():
         elif format_type == 'dialogue':
             if not isinstance(content, list):
                 return wrap_response(400, msg='dialogue should be a list')
-            response = MessageReplyAgent.compose_dialogue(content)
+            from pqai_agent_server.utils.prompt_util import compose_dialogue
+            response = compose_dialogue(content)
         else:
             return wrap_response(400, msg='Invalid format_type')
         return wrap_response(200, data=response)

+ 16 - 0
pqai_agent_server/utils/prompt_util.py

@@ -1,6 +1,8 @@
 import json
 
 from datetime import datetime
+from typing import List, Dict
+
 from openai import OpenAI
 
 from pqai_agent import logging_service, chat_service
@@ -174,3 +176,17 @@ def run_response_type_prompt(req_data):
         {"role": "user", "content": prompt},
     ]
     return run_openai_chat(messages, model_name, temperature=0.2, max_tokens=128)
+
+
+def compose_dialogue(dialogue: List[Dict]) -> str:
+    role_map = {'user': '用户', 'assistant': '客服'}
+    messages = []
+    for msg in dialogue:
+        if not msg['content']:
+            continue
+        if msg['role'] not in role_map:
+            continue
+        format_dt = datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
+        msg_type = MessageType(msg.get('type', MessageType.TEXT.value)).description
+        messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
+    return '\n'.join(messages)