process_messages.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 处理消息文件,生成结构化的JSON数据
  5. """
  6. import json
  7. import os
  8. from pathlib import Path
  9. from typing import Dict, List, Any, Optional
  10. from collections import defaultdict
  11. def load_all_messages(messages_dir: str) -> List[Dict[str, Any]]:
  12. """加载所有JSON消息文件"""
  13. messages = []
  14. messages_path = Path(messages_dir)
  15. # 只处理JSON文件
  16. for json_file in sorted(messages_path.glob("*.json")):
  17. try:
  18. with open(json_file, 'r', encoding='utf-8') as f:
  19. data = json.load(f)
  20. messages.append(data)
  21. except Exception as e:
  22. print(f"警告: 无法读取文件 {json_file}: {e}")
  23. # 按sequence排序
  24. messages.sort(key=lambda x: x.get('sequence', 0))
  25. return messages
  26. def extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
  27. """从content中提取tool_calls"""
  28. if isinstance(content, dict):
  29. return content.get('tool_calls', [])
  30. return []
  31. def find_tool_result(tool_call_id: str, messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
  32. """根据tool_call_id查找对应的tool结果消息"""
  33. for msg in messages:
  34. if msg.get('role') == 'tool' and msg.get('tool_call_id') == tool_call_id:
  35. return msg
  36. return None
  37. def format_message(msg: Dict[str, Any], messages: List[Dict[str, Any]]) -> Dict[str, Any]:
  38. """格式化单个消息为结构化数据"""
  39. result = {
  40. "sequence": msg.get('sequence'),
  41. "role": msg.get('role'),
  42. "message_id": msg.get('message_id'),
  43. "parent_sequence": msg.get('parent_sequence'),
  44. "status": msg.get('status'),
  45. "goal_id": msg.get('goal_id'),
  46. "created_at": msg.get('created_at'),
  47. }
  48. # 处理content
  49. content = msg.get('content')
  50. if isinstance(content, str):
  51. result["content"] = content
  52. result["text"] = content
  53. elif isinstance(content, dict):
  54. result["text"] = content.get('text', '')
  55. result["content"] = content
  56. # 处理description
  57. if msg.get('description'):
  58. result["description"] = msg.get('description')
  59. # 处理tokens信息
  60. if msg.get('tokens') is not None:
  61. result["tokens"] = msg.get('tokens')
  62. if msg.get('prompt_tokens') is not None:
  63. result["prompt_tokens"] = msg.get('prompt_tokens')
  64. if msg.get('completion_tokens') is not None:
  65. result["completion_tokens"] = msg.get('completion_tokens')
  66. if msg.get('cost') is not None:
  67. result["cost"] = msg.get('cost')
  68. # 如果是assistant消息且有tool_calls,添加children
  69. if msg.get('role') == 'assistant':
  70. tool_calls = extract_tool_calls(content)
  71. if tool_calls:
  72. result["children"] = []
  73. for tool_call in tool_calls:
  74. tool_call_id = tool_call.get('id')
  75. tool_name = tool_call.get('function', {}).get('name', 'unknown')
  76. tool_args = tool_call.get('function', {}).get('arguments', '{}')
  77. # 尝试解析arguments
  78. try:
  79. tool_args_parsed = json.loads(tool_args)
  80. except:
  81. tool_args_parsed = tool_args
  82. tool_node = {
  83. "type": "tool_call",
  84. "tool_call_id": tool_call_id,
  85. "tool_name": tool_name,
  86. "arguments": tool_args_parsed,
  87. "raw_arguments": tool_args,
  88. }
  89. # 查找对应的tool结果
  90. tool_result = find_tool_result(tool_call_id, messages)
  91. if tool_result:
  92. tool_node["result"] = {
  93. "sequence": tool_result.get('sequence'),
  94. "tool_name": tool_result.get('content', {}).get('tool_name') if isinstance(
  95. tool_result.get('content'), dict) else None,
  96. "result": tool_result.get('content', {}).get('result') if isinstance(tool_result.get('content'),
  97. dict) else tool_result.get(
  98. 'content'),
  99. "status": tool_result.get('status'),
  100. "created_at": tool_result.get('created_at'),
  101. }
  102. result["children"].append(tool_node)
  103. # 如果是tool消息,添加工具相关信息
  104. if msg.get('role') == 'tool':
  105. result["tool_call_id"] = msg.get('tool_call_id')
  106. if isinstance(content, dict):
  107. result["tool_name"] = content.get('tool_name')
  108. result["tool_result"] = content.get('result')
  109. return result
  110. def process_messages(messages_dir: str, output_path: str):
  111. """处理所有消息并生成结构化数据"""
  112. messages_dir_path = Path(messages_dir).resolve()
  113. output_file_path = Path(output_path).resolve()
  114. if not messages_dir_path.exists():
  115. raise ValueError(f"输入目录不存在: {messages_dir_path}")
  116. if not messages_dir_path.is_dir():
  117. raise ValueError(f"输入路径不是目录: {messages_dir_path}")
  118. print(f"正在读取消息文件从: {messages_dir_path}")
  119. messages = load_all_messages(str(messages_dir_path))
  120. print(f"共读取 {len(messages)} 条消息")
  121. # 格式化所有消息
  122. structured_messages = []
  123. for msg in messages:
  124. formatted = format_message(msg, messages)
  125. structured_messages.append(formatted)
  126. # 确保输出目录存在
  127. output_file_path.parent.mkdir(parents=True, exist_ok=True)
  128. # 保存结果
  129. with open(output_file_path, 'w', encoding='utf-8') as f:
  130. json.dump(structured_messages, f, ensure_ascii=False, indent=2)
  131. print(f"结构化数据已保存到: {output_file_path}")
  132. print(f"共处理 {len(structured_messages)} 条消息")
  133. # 统计信息
  134. tool_calls_count = sum(1 for msg in structured_messages if msg.get('children'))
  135. print(f"包含工具调用的消息数: {tool_calls_count}")
  136. return structured_messages
  137. if __name__ == "__main__":
  138. # 使用定义的变量
  139. try:
  140. input = ''
  141. output = ''
  142. process_messages(input, output)
  143. except Exception as e:
  144. print(f"错误: {e}")
  145. exit(1)