| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 处理消息文件,生成结构化的JSON数据
- """
- import json
- import os
- from pathlib import Path
- from typing import Dict, List, Any, Optional
- from collections import defaultdict
def load_all_messages(messages_dir: str) -> List[Dict[str, Any]]:
    """Load every ``*.json`` message file in a directory, ordered by sequence.

    Files are visited in sorted path order. A file that cannot be read or
    parsed, or whose top-level JSON value is not an object, is reported with
    a warning and skipped so that one bad file does not abort the whole run.

    Args:
        messages_dir: Directory holding one JSON document per message.

    Returns:
        The loaded message dicts sorted by their ``sequence`` value. A missing
        or null ``sequence`` sorts as 0; ties keep path order (stable sort).
    """
    messages: List[Dict[str, Any]] = []

    # Only JSON files are considered.
    for json_file in sorted(Path(messages_dir).glob("*.json")):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"警告: 无法读取文件 {json_file}: {e}")
            continue

        # Downstream code calls .get() on every message, so anything other
        # than a JSON object (list, string, number, ...) cannot be used.
        if not isinstance(data, dict):
            print(f"警告: 无法读取文件 {json_file}: 顶层JSON不是对象")
            continue

        messages.append(data)

    # `or 0` also maps an explicit null sequence to 0; comparing None with an
    # int would otherwise raise TypeError during the sort.
    messages.sort(key=lambda m: m.get('sequence') or 0)
    return messages
def extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
    """Return the ``tool_calls`` list stored in a message's content.

    Args:
        content: The ``content`` value of a message; any type is accepted.

    Returns:
        ``content['tool_calls']`` when ``content`` is a dict and that value is
        a list; otherwise an empty list. This includes the case where the key
        is present but null, so callers always receive an iterable list.
    """
    if not isinstance(content, dict):
        return []
    tool_calls = content.get('tool_calls')
    return tool_calls if isinstance(tool_calls, list) else []
def find_tool_result(tool_call_id: str, messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Find the tool-result message that answers a given tool call.

    Args:
        tool_call_id: Identifier of the tool call to look up.
        messages: All loaded messages to search, in order.

    Returns:
        The first message whose ``role`` is ``'tool'`` and whose
        ``tool_call_id`` equals ``tool_call_id``, or ``None`` when there is no
        such message or when ``tool_call_id`` itself is ``None``.
    """
    # A call without an id must not pair up with a tool message that simply
    # lacks a tool_call_id (None == None would otherwise match).
    if tool_call_id is None:
        return None

    return next(
        (
            msg
            for msg in messages
            if msg.get('role') == 'tool' and msg.get('tool_call_id') == tool_call_id
        ),
        None,
    )
# Maximum number of characters kept in a generated title before "..." is added.
_TITLE_MAX_LEN = 60


def _truncate_title(text: str, limit: int = _TITLE_MAX_LEN) -> str:
    """Return ``text`` cut to ``limit`` characters, adding "..." only when cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _build_tool_node(tool_call: Dict[str, Any], messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Describe one tool call and, when one is found, its matching tool result."""
    tool_call_id = tool_call.get('id')

    # Tolerate a missing, null or malformed "function" entry.
    function = tool_call.get('function')
    if not isinstance(function, dict):
        function = {}
    tool_name = function.get('name', 'unknown')
    tool_args = function.get('arguments', '{}')

    # Arguments are normally a JSON string; keep the raw value when it is not
    # valid JSON (ValueError) or not a string at all (TypeError).
    try:
        tool_args_parsed = json.loads(tool_args)
    except (TypeError, ValueError):
        tool_args_parsed = tool_args

    tool_node = {
        "type": "tool_call",
        "tool_call_id": tool_call_id,
        "tool_name": tool_name,
        "arguments": tool_args_parsed,
        "raw_arguments": tool_args,
    }

    tool_result = find_tool_result(tool_call_id, messages)
    if tool_result:
        payload = tool_result.get('content')
        structured = isinstance(payload, dict)
        tool_node["result"] = {
            "sequence": tool_result.get('sequence'),
            "tool_name": payload.get('tool_name') if structured else None,
            # Non-dict content is passed through as the result itself.
            "result": payload.get('result') if structured else payload,
            "status": tool_result.get('status'),
        }
    return tool_node


def _title_from_tool_result(children: List[Dict[str, Any]]) -> Optional[str]:
    """Derive a title from the last child's tool result, or None if it has none."""
    if not children:
        return None
    outcome = (children[-1].get("result") or {}).get("result")
    if not outcome:
        return None
    result_text = str(outcome)
    if not result_text:
        return None

    # Prefer the first line that follows a "Summary:" marker.
    _, marker, tail = result_text.partition("Summary:")
    if marker:
        summary = tail.strip().split("\n", 1)[0].strip()
        if summary:
            return _truncate_title(summary)

    # No usable summary: fall back to the start of the raw result text.
    return _truncate_title(result_text)


def format_message(msg: Dict[str, Any], messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Convert one raw message into its structured output record.

    Args:
        msg: The raw message to format.
        messages: All loaded messages; searched for the tool results that
            answer this message's tool calls.

    Returns:
        ``None`` for messages whose role is ``'tool'`` (their results are
        embedded in the calling assistant message's ``children``). Otherwise a
        dict with ``sequence``, ``role``, ``parent_sequence``, ``status`` and
        ``title``; ``content``/``text`` when the content is a string or dict;
        any of ``tokens``, ``prompt_tokens``, ``completion_tokens``, ``cost``
        that are not None; and ``children`` (one node per tool call) for
        assistant messages that made tool calls.
    """
    # Tool messages never produce their own record, so skip all the work.
    if msg.get('role') == 'tool':
        return None

    result = {
        "sequence": msg.get('sequence'),
        "role": msg.get('role'),
        "parent_sequence": msg.get('parent_sequence'),
        "status": msg.get('status'),
    }

    # Content is either a plain string or a dict carrying a "text" field;
    # any other type contributes neither "content" nor "text".
    content = msg.get('content')
    if isinstance(content, str):
        result["content"] = content
        result["text"] = content
    elif isinstance(content, dict):
        result["text"] = content.get('text', '')
        result["content"] = content

    # Title: the text itself, truncated. Non-string text is stringified for
    # the title only, so len()/slicing cannot fail on it.
    text = result.get('text', '')
    if text:
        result["title"] = _truncate_title(text if isinstance(text, str) else str(text))
    else:
        result["title"] = ""

    # Copy usage/cost figures only when they are present.
    for key in ('tokens', 'prompt_tokens', 'completion_tokens', 'cost'):
        value = msg.get(key)
        if value is not None:
            result[key] = value

    # Assistant messages get one child node per tool call.
    if msg.get('role') == 'assistant':
        tool_calls = extract_tool_calls(content)
        if tool_calls:
            result["children"] = [
                _build_tool_node(tool_call, messages) for tool_call in tool_calls
            ]

    # With no usable text, borrow a title from the last tool call's result.
    if not result["title"].strip():
        fallback_title = _title_from_tool_result(result.get("children", []))
        if fallback_title is not None:
            result["title"] = fallback_title

    return result
def process_messages(messages_dir: str, output_path: str):
    """Read a directory of message files and write them out as structured JSON.

    Args:
        messages_dir: Directory containing the raw ``*.json`` message files.
        output_path: File to write the structured JSON array to; missing
            parent directories are created.

    Returns:
        The list of structured message records that was written.

    Raises:
        ValueError: If ``messages_dir`` does not exist or is not a directory.
    """
    source_dir = Path(messages_dir).resolve()
    target_file = Path(output_path).resolve()

    if not source_dir.exists():
        raise ValueError(f"输入目录不存在: {source_dir}")
    if not source_dir.is_dir():
        raise ValueError(f"输入路径不是目录: {source_dir}")

    print(f"正在读取消息文件从: {source_dir}")
    messages = load_all_messages(str(source_dir))
    print(f"共读取 {len(messages)} 条消息")

    # format_message returns None for tool messages; their results already
    # live in the children of the assistant message that issued the call.
    structured_messages = [
        record
        for record in (format_message(raw, messages) for raw in messages)
        if record is not None
    ]

    # Make sure the destination directory exists before writing.
    target_file.parent.mkdir(parents=True, exist_ok=True)
    with open(target_file, 'w', encoding='utf-8') as out:
        json.dump(structured_messages, out, ensure_ascii=False, indent=2)

    print(f"结构化数据已保存到: {target_file}")
    print(f"共处理 {len(structured_messages)} 条消息")

    # Summary statistic: how many records carry at least one tool call.
    with_tool_calls = len(
        [record for record in structured_messages if record.get('children')]
    )
    print(f"包含工具调用的消息数: {with_tool_calls}")

    return structured_messages
if __name__ == "__main__":
    # Usage: script.py [messages_dir] [output_path]
    # Both arguments are optional; the defaults below are used when omitted.
    import sys

    default_messages_dir = '/Users/shimeng/Desktop/py/Agent/examples/content_needs_generation/.trace/6bddb982-21db-4cbc-b064-8a568ce0791d/messages'
    default_output_path = '/Users/shimeng/Desktop/py/Agent/examples/content_needs_generation/.trace/6bddb982-21db-4cbc-b064-8a568ce0791d/output.json'

    # Named so as not to shadow the builtin input().
    messages_dir = sys.argv[1] if len(sys.argv) > 1 else default_messages_dir
    output_path = sys.argv[2] if len(sys.argv) > 2 else default_output_path

    try:
        process_messages(messages_dir, output_path)
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"错误: {e}")
        # sys.exit is always available; the bare exit() helper is injected by
        # the site module and is missing under `python -S` or frozen builds.
        sys.exit(1)
|