process_messages.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 处理消息文件,生成结构化的JSON数据
  5. """
  6. import json
  7. import os
  8. from pathlib import Path
  9. from typing import Dict, List, Any, Optional
  10. from collections import defaultdict
  def load_all_messages(messages_dir: str) -> List[Dict[str, Any]]:
      """Load every ``*.json`` file directly inside *messages_dir*, ordered by ``sequence``.

      Files are read in filename order and then stably sorted by their ``sequence``
      field; a missing or null ``sequence`` sorts as 0, and ties keep filename order.
      A file that cannot be opened or decoded, or whose top-level JSON value is not
      an object, is skipped with a warning instead of aborting the whole run.

      Args:
          messages_dir: directory to scan (non-recursive).

      Returns:
          The decoded message dicts, sorted by ``sequence``.
      """
      messages: List[Dict[str, Any]] = []
      for json_file in sorted(Path(messages_dir).glob("*.json")):
          try:
              with open(json_file, 'r', encoding='utf-8') as f:
                  data = json.load(f)
          except (OSError, ValueError, RecursionError) as e:
              # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
              # RecursionError covers pathologically nested documents.
              print(f"警告: 无法读取文件 {json_file}: {e}")
              continue
          if not isinstance(data, dict):
              # Later stages call .get() on every message; a list/str/number here
              # would crash the sort below, so skip it explicitly.
              print(f"警告: 文件 {json_file} 的顶层不是JSON对象,已跳过")
              continue
          messages.append(data)
      # Treat a null sequence like a missing one so the sort never compares None with int.
      messages.sort(key=lambda m: 0 if m.get('sequence') is None else m['sequence'])
      return messages
  def extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
      """Return the ``tool_calls`` list carried by a message's ``content``.

      Only dict content can carry tool calls. A missing, null, or non-list
      ``tool_calls`` value yields an empty list, so callers can always iterate
      the result.
      """
      if not isinstance(content, dict):
          return []
      tool_calls = content.get('tool_calls')
      return tool_calls if isinstance(tool_calls, list) else []
  def find_tool_result(tool_call_id: str, messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
      """Return the first ``role == 'tool'`` message whose ``tool_call_id`` equals *tool_call_id*.

      Messages are scanned in list order; None is returned when nothing matches.
      """
      candidates = (
          candidate
          for candidate in messages
          if candidate.get('role') == 'tool' and candidate.get('tool_call_id') == tool_call_id
      )
      return next(candidates, None)
  def _truncate_title(text: str, limit: int = 60) -> str:
      """Return *text* if it has at most *limit* characters, else its first *limit* characters plus "..."."""
      return text[:limit] + "..." if len(text) > limit else text


  def _tool_result_summary(tool_result: Dict[str, Any]) -> Dict[str, Any]:
      """Condense a ``role == 'tool'`` message into the ``result`` entry of a tool-call node."""
      content = tool_result.get('content')
      if isinstance(content, dict):
          # Structured tool output: pull the tool name and payload out of the dict.
          tool_name = content.get('tool_name')
          payload = content.get('result')
      else:
          # Non-dict content is passed through unchanged as the result.
          tool_name = None
          payload = content
      return {
          "sequence": tool_result.get('sequence'),
          "tool_name": tool_name,
          "result": payload,
          "status": tool_result.get('status'),
      }


  def _build_tool_node(tool_call: Dict[str, Any], messages: List[Dict[str, Any]]) -> Dict[str, Any]:
      """Turn one entry of an assistant's ``tool_calls`` into a ``tool_call`` child node."""
      tool_call_id = tool_call.get('id')
      # ``function`` may be absent or explicitly null; treat both as an empty dict.
      function = tool_call.get('function') or {}
      tool_name = function.get('name', 'unknown')
      tool_args = function.get('arguments', '{}')
      # Arguments are normally a JSON-encoded string. Keep the raw value when it is
      # not valid JSON (ValueError) or not a str/bytes at all (TypeError).
      try:
          tool_args_parsed = json.loads(tool_args)
      except (TypeError, ValueError):
          tool_args_parsed = tool_args
      node = {
          "type": "tool_call",
          "tool_call_id": tool_call_id,
          "tool_name": tool_name,
          "arguments": tool_args_parsed,
          "raw_arguments": tool_args,
      }
      # A call without an id cannot be paired reliably: looking up None would match
      # any tool message that also lacks ``tool_call_id``.
      if tool_call_id is not None:
          tool_result = find_tool_result(tool_call_id, messages)
          if tool_result is not None:
              node["result"] = _tool_result_summary(tool_result)
      return node


  def _title_from_children(children: Optional[List[Dict[str, Any]]]) -> str:
      """Derive a fallback title from the last child's tool result, or "" if there is none.

      Uses the first line after "Summary:" in the result text when that line is
      non-empty, otherwise the whole result text; either is truncated to 60 chars.
      """
      if not children:
          return ""
      last_result = children[-1].get("result")
      if not last_result or not last_result.get("result"):
          return ""
      result_text = str(last_result["result"])
      marker = "Summary:"
      marker_at = result_text.find(marker)
      if marker_at != -1:
          summary = result_text[marker_at + len(marker):].strip().split("\n")[0].strip()
          if summary:
              return _truncate_title(summary)
      return _truncate_title(result_text)


  def format_message(msg: Dict[str, Any], messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
      """Convert one raw message into the structured record written to the output JSON.

      Args:
          msg: a raw message dict as loaded from a message file.
          messages: all loaded messages; searched for the ``role == 'tool'``
              messages that answer this message's tool calls.

      Returns:
          None for ``role == 'tool'`` messages, whose output is embedded in the
          calling assistant message's ``children`` instead. Otherwise a dict with
          ``sequence``, ``role``, ``parent_sequence``, ``status``; ``content`` and
          ``text`` when content is a str or dict; ``title`` (at most 60 chars plus
          "..."); any non-null ``tokens``/``prompt_tokens``/``completion_tokens``/
          ``cost``; and, for assistant messages with tool calls, ``children`` with
          one ``tool_call`` node per call.
      """
      role = msg.get('role')
      # Tool messages are dropped by the caller, so skip formatting them at all.
      if role == 'tool':
          return None

      result: Dict[str, Any] = {
          "sequence": msg.get('sequence'),
          "role": role,
          "parent_sequence": msg.get('parent_sequence'),
          "status": msg.get('status'),
      }

      # Key insertion order is kept as before (it shows up in the JSON output):
      # str content -> content, text; dict content -> text, content.
      content = msg.get('content')
      if isinstance(content, str):
          result["content"] = content
          result["text"] = content
      elif isinstance(content, dict):
          result["text"] = content.get('text', '')
          result["content"] = content

      text = result.get('text', '')
      result["title"] = _truncate_title(str(text)) if text else ""

      for key in ('tokens', 'prompt_tokens', 'completion_tokens', 'cost'):
          if msg.get(key) is not None:
              result[key] = msg[key]

      if role == 'assistant':
          tool_calls = extract_tool_calls(content)
          if tool_calls:
              result["children"] = [_build_tool_node(call, messages) for call in tool_calls]

      # With no usable text (e.g. a pure tool-calling turn), borrow a title from
      # the last tool call's result.
      if not result["title"].strip():
          fallback = _title_from_children(result.get("children"))
          if fallback:
              result["title"] = fallback

      return result
  def process_messages(messages_dir: str, output_path: str):
      """Load raw messages, structure them, and write the result as pretty-printed JSON.

      Args:
          messages_dir: directory holding the ``*.json`` message files.
          output_path: file to write; missing parent directories are created.

      Returns:
          The list of structured (non-tool) message records that was written.

      Raises:
          ValueError: if *messages_dir* does not exist or is not a directory.
      """
      source_dir = Path(messages_dir).resolve()
      target_file = Path(output_path).resolve()

      if not source_dir.exists():
          raise ValueError(f"输入目录不存在: {source_dir}")
      if not source_dir.is_dir():
          raise ValueError(f"输入路径不是目录: {source_dir}")

      print(f"正在读取消息文件从: {source_dir}")
      raw_messages = load_all_messages(str(source_dir))
      print(f"共读取 {len(raw_messages)} 条消息")

      # format_message returns None for tool messages (their output already lives
      # in the calling assistant's children), so those are filtered out here.
      formatted = (format_message(raw, raw_messages) for raw in raw_messages)
      structured = [record for record in formatted if record is not None]

      target_file.parent.mkdir(parents=True, exist_ok=True)
      with open(target_file, 'w', encoding='utf-8') as out:
          json.dump(structured, out, ensure_ascii=False, indent=2)
      print(f"结构化数据已保存到: {target_file}")
      print(f"共处理 {len(structured)} 条消息")

      with_children = len([record for record in structured if record.get('children')])
      print(f"包含工具调用的消息数: {with_children}")
      return structured
  if __name__ == "__main__":
      import argparse
      import sys

      # The script originally ran only against this one hard-coded trace; these stay
      # the defaults, so running with no arguments behaves as before.
      default_messages_dir = '/Users/shimeng/Desktop/py/Agent/examples/content_needs_generation/.trace/6bddb982-21db-4cbc-b064-8a568ce0791d/messages'
      default_output_path = '/Users/shimeng/Desktop/py/Agent/examples/content_needs_generation/.trace/6bddb982-21db-4cbc-b064-8a568ce0791d/output.json'

      parser = argparse.ArgumentParser(
          description="Convert a directory of JSON message files into one structured JSON file.")
      parser.add_argument("messages_dir", nargs="?", default=default_messages_dir,
                          help="directory containing the *.json message files")
      parser.add_argument("output_path", nargs="?", default=default_output_path,
                          help="path of the structured JSON file to write")
      args = parser.parse_args()

      try:
          process_messages(args.messages_dir, args.output_path)
      except Exception as e:  # top-level boundary: report any failure and exit non-zero
          print(f"错误: {e}", file=sys.stderr)
          sys.exit(1)