demand_build_agent_tools.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import json
  2. from pathlib import Path
  3. from typing import Any, Dict, List, Optional
  4. from agent import tool
  5. from examples.demand.demand_agent_context import TopicBuildAgentContext
  6. from examples.demand.demand_pattern_tools import _log_tool_output, _log_tool_input
  7. def _get_result_base_dir() -> Path:
  8. """输出到“当前工作目录/result/”下。"""
  9. return Path.cwd() / "result"
  10. @tool(
  11. "存储需求到结果集。 - element_names - score(权重) - reason(原因)- desc(需求描述)"
  12. )
  13. def create_demand_item(
  14. element_names: List[str] = None,
  15. score: float = 0.0,
  16. reason: str = None,
  17. desc: str = None) -> str:
  18. """
  19. 每次调用向“execution_id 对应的本地 JSON 文件”追加一条记录。
  20. 写入对象包含以下字段:
  21. - element_names
  22. - score(权重)
  23. - reason(原因)
  24. - desc(需求描述)
  25. """
  26. execution_id: Optional[int] = TopicBuildAgentContext.get_execution_id()
  27. params: Dict[str, Any] = {
  28. "execution_id": execution_id,
  29. "score": score,
  30. "element_names": element_names,
  31. "reason": reason,
  32. "desc": desc,
  33. }
  34. _log_tool_input("create_demand_item", params)
  35. if not execution_id:
  36. return _log_tool_output("create_demand_item", "错误: 未设置 execution_id")
  37. record: Dict[str, Any] = {
  38. "element_names": element_names,
  39. "score": score,
  40. "reason": reason,
  41. "desc": desc,
  42. }
  43. # 按 execution_id 区分文件,避免不同执行互相污染。
  44. # 例如:result/{execution_id}/execution_id_{execution_id}_demand_items.json
  45. output_dir = _get_result_base_dir() / f"{execution_id}"
  46. output_path = output_dir / f"execution_id_{execution_id}_demand_items.json"
  47. output_path.parent.mkdir(parents=True, exist_ok=True)
  48. items: List[Dict[str, Any]] = []
  49. if output_path.exists():
  50. try:
  51. with open(output_path, "r", encoding="utf-8") as f:
  52. loaded = json.load(f)
  53. if isinstance(loaded, list):
  54. items = loaded
  55. elif isinstance(loaded, dict) and isinstance(loaded.get("items"), list):
  56. # 兼容可能的包装格式:{"items":[...]}
  57. items = loaded["items"]
  58. else:
  59. # 兜底:把已有内容当作单条记录追加
  60. items = [loaded]
  61. except json.JSONDecodeError:
  62. # 文件内容损坏时,不阻断执行;从空列表开始追加
  63. items = []
  64. items.append(record)
  65. with open(output_path, "w", encoding="utf-8") as f:
  66. json.dump(items, f, ensure_ascii=False, indent=2)
  67. result = json.dumps(
  68. {"success": True, "execution_id": execution_id, "written_to": str(output_path)},
  69. ensure_ascii=False,
  70. )
  71. return _log_tool_output("create_demand_item", result)
  72. @tool(
  73. "批量存储需求到结果集。 - element_names - score(权重) - reason(原因)- desc(需求描述)"
  74. )
  75. def create_demand_items(demand_items: List[Dict[str, Any]] = None) -> str:
  76. """
  77. 一次调用追加多条记录到“execution_id 对应的本地 JSON 文件”(JSON 数组)。
  78. 每条记录字段:
  79. - element_names
  80. - score(权重)
  81. - reason(原因)
  82. - desc(需求描述)
  83. """
  84. execution_id: Optional[int] = TopicBuildAgentContext.get_execution_id()
  85. params: Dict[str, Any] = {"execution_id": execution_id, "count": len(demand_items or []),
  86. "demand_items": demand_items}
  87. _log_tool_input("create_demand_items", params)
  88. if not execution_id:
  89. return _log_tool_output("create_demand_items", "错误: 未设置 execution_id")
  90. if not demand_items or not isinstance(demand_items, list):
  91. return _log_tool_output("create_demand_items", "错误: demand_items 必须为非空列表")
  92. output_dir = _get_result_base_dir() / f"{execution_id}"
  93. output_path = output_dir / f"execution_id_{execution_id}_demand_items.json"
  94. output_path.parent.mkdir(parents=True, exist_ok=True)
  95. items: List[Dict[str, Any]] = []
  96. if output_path.exists():
  97. try:
  98. with open(output_path, "r", encoding="utf-8") as f:
  99. loaded = json.load(f)
  100. if isinstance(loaded, list):
  101. items = loaded
  102. elif isinstance(loaded, dict) and isinstance(loaded.get("items"), list):
  103. items = loaded["items"]
  104. else:
  105. items = [loaded]
  106. except json.JSONDecodeError:
  107. items = []
  108. written_records: List[Dict[str, Any]] = []
  109. for i, di in enumerate(demand_items):
  110. if not isinstance(di, dict):
  111. return _log_tool_output("create_demand_items", f"错误: demand_items[{i}] 必须为对象(dict)")
  112. record = {
  113. "element_names": di.get("element_names"),
  114. "score": di.get("score", 0.0),
  115. "reason": di.get("reason"),
  116. "desc": di.get("desc"),
  117. }
  118. written_records.append(record)
  119. items.extend(written_records)
  120. with open(output_path, "w", encoding="utf-8") as f:
  121. json.dump(items, f, ensure_ascii=False, indent=2)
  122. result = json.dumps(
  123. {
  124. "success": True,
  125. "execution_id": execution_id,
  126. "written_to": str(output_path),
  127. "written_count": len(written_records),
  128. },
  129. ensure_ascii=False,
  130. )
  131. return _log_tool_output("create_demand_items", result)
  132. @tool(
  133. "写入本次执行总结(在所有分类完成后调用)。"
  134. "\n\n该工具用于把最终总结记录到本地/trace输出中(框架侧通过返回值与日志落盘)。"
  135. )
  136. def write_execution_summary(summary: str) -> str:
  137. """写入本次执行总结。在所有分类完成后调用。
  138. Args:
  139. summary: 执行总结(Markdown 格式)。
  140. Returns:
  141. JSON 字符串:
  142. - 成功:`{"success": True, "execution_id": execution_id}`
  143. - 失败:`"错误: 未设置 execution_id"`
  144. """
  145. execution_id: Optional[int] = TopicBuildAgentContext.get_execution_id()
  146. params: Dict[str, str] = {"summary": summary}
  147. _log_tool_input("write_execution_summary", params)
  148. if not execution_id:
  149. return _log_tool_output("write_execution_summary", "错误: 未设置 execution_id")
  150. result = json.dumps({"success": True, "execution_id": execution_id}, ensure_ascii=False)
  151. return _log_tool_output("write_execution_summary", result)