weight_score_query_tools.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 权重分查询工具
  5. 从 examples/piaoquan_demand/data/{execution_id} 目录读取权重分 JSON,
  6. 支持按元素/分类查询 TopN,以及按名称列表批量查询权重分。
  7. """
  8. import json
  9. from pathlib import Path
  10. from agent import tool
  11. from examples.piaoquan_demand.topic_build_agent_context import TopicBuildAgentContext
  12. from examples.piaoquan_demand.topic_build_pattern_tools import _log_tool_input, _log_tool_output
# Accepted values for the ``level`` argument, checked by ``_validate_params``:
# "元素" (element) and "分类" (category).
_VALID_LEVELS = {"元素", "分类"}
# Accepted values for the ``dimension`` argument, checked by
# ``_validate_params``: "实质" (substance), "形式" (form) and "意图" (intent).
_VALID_DIMENSIONS = {"实质", "形式", "意图"}
  15. def _get_weight_file_path(level: str, dimension: str) -> Path:
  16. """根据参数构造权重数据文件路径。"""
  17. execution_id = TopicBuildAgentContext.get_execution_id()
  18. if execution_id is None:
  19. raise ValueError("未设置 execution_id,请先在 TopicBuildAgentContext 中设置")
  20. filename = f"{dimension}_{level}.json"
  21. base_dir = Path(__file__).parent / "data" / str(execution_id)
  22. return base_dir / filename
  23. def _load_weight_data(level: str, dimension: str) -> list[dict]:
  24. """读取并返回权重数据列表。"""
  25. file_path = _get_weight_file_path(level=level, dimension=dimension)
  26. if not file_path.exists():
  27. raise FileNotFoundError(f"权重数据文件不存在: {file_path}")
  28. with file_path.open("r", encoding="utf-8") as f:
  29. data = json.load(f)
  30. if not isinstance(data, list):
  31. raise ValueError(f"权重数据格式错误,期望 list,实际为: {type(data).__name__}")
  32. return data
  33. def _validate_params(level: str, dimension: str):
  34. """校验通用参数。"""
  35. if level not in _VALID_LEVELS:
  36. raise ValueError(f"level 参数非法: {level},可选值: {sorted(_VALID_LEVELS)}")
  37. if dimension not in _VALID_DIMENSIONS:
  38. raise ValueError(f"dimension 参数非法: {dimension},可选值: {sorted(_VALID_DIMENSIONS)}")
  39. @tool("查询元素或分类权重分排名区间。参数:level(元素/分类)、dimension(实质/形式/意图)、start(起始排名,含,从1开始)、end(结束排名,含,从1开始)。")
  40. def get_weight_score_topn(level: str, dimension: str, start: int = 1, end: int = 10) -> str:
  41. """查询元素或分类权重分排名区间。
  42. Args:
  43. level: 查询层级,元素 或 分类。
  44. dimension: 查询维度,实质 / 形式 / 意图。
  45. start: 起始排名(包含),从 1 开始。
  46. end: 结束排名(包含),从 1 开始。
  47. Returns:
  48. JSON 字符串,包含查询参数、总量和区间数据。
  49. """
  50. execution_id = TopicBuildAgentContext.get_execution_id()
  51. params = {
  52. "execution_id": execution_id,
  53. "level": level,
  54. "dimension": dimension,
  55. "start": start,
  56. "end": end,
  57. }
  58. _log_tool_input("get_weight_score_topn", params)
  59. try:
  60. _validate_params(level=level, dimension=dimension)
  61. if start < 1 or end < 1:
  62. return _log_tool_output(
  63. "get_weight_score_topn",
  64. f"错误: start/end 必须为大于等于 1 的整数,当前值: start={start}, end={end}",
  65. )
  66. if start > end:
  67. return _log_tool_output(
  68. "get_weight_score_topn",
  69. f"错误: start 不能大于 end,当前值: start={start}, end={end}",
  70. )
  71. data = _load_weight_data(level=level, dimension=dimension)
  72. sorted_data = sorted(data, key=lambda x: float(x.get("score", 0)), reverse=True)
  73. # 用户输入为 1-based 且 end 为包含边界,需转换为 Python 切片
  74. ranged_items = sorted_data[start - 1 : end]
  75. result = {
  76. "level": level,
  77. "dimension": dimension,
  78. "start": start,
  79. "end": end,
  80. "total_count": len(data),
  81. "matched_count": len(ranged_items),
  82. "items": ranged_items,
  83. }
  84. return _log_tool_output("get_weight_score_topn", json.dumps(result, ensure_ascii=False, indent=2))
  85. except Exception as e:
  86. return _log_tool_output("get_weight_score_topn", f"查询失败: {e}")
  87. @tool("批量查询指定名称的权重分。参数:level(元素/分类)、dimension(实质/形式/意图)、names(名称列表)。")
  88. def get_weight_score_by_name(level: str, dimension: str, names: list[str]) -> str:
  89. """批量查询指定名称的权重分。
  90. Args:
  91. level: 查询层级,元素 或 分类。
  92. dimension: 查询维度,实质 / 形式 / 意图。
  93. names: 要查询的名称列表(元素名或分类名),顺序与返回 results 一一对应。
  94. Returns:
  95. JSON 字符串,含每个名称的 matched_count 与 items。
  96. """
  97. execution_id = TopicBuildAgentContext.get_execution_id()
  98. params = {
  99. "execution_id": execution_id,
  100. "level": level,
  101. "dimension": dimension,
  102. "names": names,
  103. }
  104. _log_tool_input("get_weight_score_by_name", params)
  105. try:
  106. _validate_params(level=level, dimension=dimension)
  107. if not names:
  108. return _log_tool_output("get_weight_score_by_name", "错误: names 不能为空列表")
  109. if not isinstance(names, list):
  110. return _log_tool_output("get_weight_score_by_name", f"错误: names 必须为列表,当前类型: {type(names).__name__}")
  111. stripped: list[str] = []
  112. for i, n in enumerate(names):
  113. if n is None or (isinstance(n, str) and not n.strip()):
  114. return _log_tool_output(
  115. "get_weight_score_by_name",
  116. f"错误: names[{i}] 不能为空",
  117. )
  118. stripped.append(str(n).strip())
  119. data = _load_weight_data(level=level, dimension=dimension)
  120. if level == "元素":
  121. key = "name"
  122. else:
  123. key = "category"
  124. results = []
  125. for target_name in stripped:
  126. matched = [item for item in data if str(item.get(key, "")).strip() == target_name]
  127. results.append(
  128. {
  129. "name": target_name,
  130. "matched_count": len(matched),
  131. "items": matched,
  132. }
  133. )
  134. result = {
  135. "level": level,
  136. "dimension": dimension,
  137. "query_count": len(stripped),
  138. "results": results,
  139. }
  140. return _log_tool_output("get_weight_score_by_name", json.dumps(result, ensure_ascii=False, indent=2))
  141. except Exception as e:
  142. return _log_tool_output("get_weight_score_by_name", f"查询失败: {e}")