|
@@ -4,6 +4,7 @@ import random
|
|
|
import traceback
|
|
|
|
|
|
from datetime import datetime
|
|
|
+from typing import Dict, List, Tuple
|
|
|
|
|
|
from openai import OpenAI
|
|
|
from tqdm import tqdm
|
|
@@ -12,6 +13,7 @@ from pqai_agent.database import MySQLManager
|
|
|
from pqai_agent.agents.message_push_agent import MessagePushAgent
|
|
|
from pqai_agent.logging_service import logger
|
|
|
from pqai_agent import configs, logging_service
|
|
|
+from pqai_agent.mq_message import MessageType
|
|
|
|
|
|
logging_service.setup_root_logger()
|
|
|
|
|
@@ -91,12 +93,13 @@ def get_dialogue_history_by_id(staff_id, dialogue_id_tuple):
|
|
|
from qywx_chat_history
|
|
|
where id in %s;
|
|
|
"""
|
|
|
+
|
|
|
conversation_list = mysql_client.select(sql=sql, cursor_type=DictCursor, args=(dialogue_id_tuple,))
|
|
|
history_conversation = [
|
|
|
{
|
|
|
"content": i['content'],
|
|
|
"role": "assistant" if i['sender'] == staff_id else "user",
|
|
|
- "timestamp": int(i['sendtime'] / 1000)
|
|
|
+ "timestamp": i['sendtime']
|
|
|
} for i in conversation_list
|
|
|
]
|
|
|
return history_conversation
|
|
@@ -195,6 +198,23 @@ def generate_reply_dataset():
|
|
|
# f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
|
|
|
|
|
|
|
|
|
+def compose_dialogue(dialogue: List[Dict], timestamp_type: str='ms') -> str:
|
|
|
+ role_map = {'user': '用户', 'assistant': '客服'}
|
|
|
+ messages = []
|
|
|
+ for msg in dialogue:
|
|
|
+ if not msg['content']:
|
|
|
+ continue
|
|
|
+ if msg['role'] not in role_map:
|
|
|
+ continue
|
|
|
+ if timestamp_type == 'ms':
|
|
|
+ format_dt = datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ else:
|
|
|
+ format_dt = datetime.fromtimestamp(msg['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ msg_type = msg.get('type', MessageType.TEXT).description
|
|
|
+ messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
|
|
|
+ return '\n'.join(messages)
|
|
|
+
|
|
|
+
|
|
|
def generate_push_dataset():
|
|
|
|
|
|
fetch_query = f"""
|
|
@@ -204,9 +224,8 @@ def generate_push_dataset():
|
|
|
"""
|
|
|
data_set = mysql_client.select(fetch_query, cursor_type=DictCursor)
|
|
|
filter_conversation = [i for i in data_set if len(json.loads(i['conversation'])) >= 20]
|
|
|
- print(len(filter_conversation))
|
|
|
|
|
|
- samples =random.sample(filter_conversation, 300)
|
|
|
+ samples =random.sample(filter_conversation, 100)
|
|
|
|
|
|
# init message push agent
|
|
|
agent = MessagePushAgent()
|
|
@@ -217,18 +236,20 @@ def generate_push_dataset():
|
|
|
user_profile = json.loads(user_profile)
|
|
|
conversation = get_dialogue_history_by_id(
|
|
|
sample["staff_id"],
|
|
|
- tuple(sample["conversation"])
|
|
|
+ tuple(json.loads(sample["conversation"]))
|
|
|
)
|
|
|
conversation.append(
|
|
|
{
|
|
|
"content": sample["content"],
|
|
|
"role": "assistant",
|
|
|
- "timestamp": sample["send_time"]
|
|
|
+ "timestamp": sample["send_time"] * 1000,
|
|
|
+ # "type": 1
|
|
|
}
|
|
|
)
|
|
|
conversation = sorted(conversation, key=lambda i: i['timestamp'], reverse=False)
|
|
|
+
|
|
|
last_timestamp = int(conversation[-1]["timestamp"])
|
|
|
- push_time = last_timestamp + 24 * 3600
|
|
|
+ push_time = int(last_timestamp / 1000) + 24 * 3600
|
|
|
push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
|
|
|
try:
|
|
|
push_message = agent.generate_message(
|
|
@@ -246,25 +267,28 @@ def generate_push_dataset():
|
|
|
"current_datetime": push_dt,
|
|
|
"avatar": None
|
|
|
},
|
|
|
- dialogue_history=conversation,
|
|
|
- timestamp_type="s"
|
|
|
+ dialogue_history=conversation
|
|
|
)
|
|
|
- insert_query = f"""
|
|
|
- insert into internal_conversation_data
|
|
|
- (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
|
|
|
- values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
|
|
|
- """
|
|
|
- mysql_client.execute(insert_query, args=(
|
|
|
- '2',
|
|
|
- sample["staff_id"],
|
|
|
- sample["user_id"],
|
|
|
- '2025-06-16',
|
|
|
- sample["conversation"],
|
|
|
- push_message,
|
|
|
- push_time,
|
|
|
- 1,
|
|
|
- sample["user_active_rate"]
|
|
|
- ))
|
|
|
+ if not push_message:
|
|
|
+ print("push message error")
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ insert_query = f"""
|
|
|
+ insert into internal_conversation_data
|
|
|
+ (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
|
|
|
+ values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
|
|
|
+ """
|
|
|
+ mysql_client.execute(insert_query, args=(
|
|
|
+ '2',
|
|
|
+ sample["staff_id"],
|
|
|
+ sample["user_id"],
|
|
|
+ '2025-06-16',
|
|
|
+ sample["conversation"],
|
|
|
+ push_message,
|
|
|
+ push_time,
|
|
|
+ 1,
|
|
|
+ sample["user_active_rate"]
|
|
|
+ ))
|
|
|
except Exception as e:
|
|
|
print("error", e)
|
|
|
print(traceback.format_exc())
|