| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 权重分查询工具
- 从 examples/piaoquan_demand/data/{execution_id} 目录读取权重分 JSON,
- 支持按元素/分类查询 TopN,以及按名称列表批量查询权重分。
- """
- import json
- from pathlib import Path
- from agent import tool
- from examples.piaoquan_demand.topic_build_agent_context import TopicBuildAgentContext
- from examples.piaoquan_demand.topic_build_pattern_tools import _log_tool_input, _log_tool_output
# Whitelist of accepted ``level`` values, checked by ``_validate_params``.
# frozenset so this module-level constant cannot be mutated accidentally.
_VALID_LEVELS = frozenset({"元素", "分类"})
# Whitelist of accepted ``dimension`` values, checked by ``_validate_params``.
_VALID_DIMENSIONS = frozenset({"实质", "形式", "意图"})
def _get_weight_file_path(level: str, dimension: str) -> Path:
    """Build the path of the weight-score JSON file for ``level``/``dimension``.

    The path is ``<this module's directory>/data/<execution_id>/<dimension>_<level>.json``,
    where ``execution_id`` is read from ``TopicBuildAgentContext``. The file is not
    checked for existence here.

    Raises:
        ValueError: if no execution_id has been set on the context.
    """
    run_id = TopicBuildAgentContext.get_execution_id()
    if run_id is None:
        raise ValueError("未设置 execution_id,请先在 TopicBuildAgentContext 中设置")
    data_root = Path(__file__).parent.joinpath("data", str(run_id))
    return data_root.joinpath(f"{dimension}_{level}.json")
def _load_weight_data(level: str, dimension: str) -> list[dict]:
    """Read the weight-score records for ``level``/``dimension`` from disk.

    Returns:
        The decoded top-level JSON list from the weight file.

    Raises:
        ValueError: if no execution_id is set (raised by the path helper), or if
            the file's top-level JSON value is not a list (``json.JSONDecodeError``,
            a ``ValueError`` subclass, is raised for malformed JSON).
        FileNotFoundError: if the weight file does not exist.
    """
    weight_file = _get_weight_file_path(level=level, dimension=dimension)
    if not weight_file.exists():
        raise FileNotFoundError(f"权重数据文件不存在: {weight_file}")
    payload = json.loads(weight_file.read_text(encoding="utf-8"))
    if isinstance(payload, list):
        return payload
    raise ValueError(f"权重数据格式错误,期望 list,实际为: {type(payload).__name__}")
def _validate_params(level: str, dimension: str):
    """Ensure ``level`` and ``dimension`` are among their whitelisted values.

    ``level`` is checked first, so when both are invalid only the ``level``
    error is reported.

    Raises:
        ValueError: naming the offending argument and listing the allowed values.
    """
    level_ok = level in _VALID_LEVELS
    if not level_ok:
        raise ValueError(f"level 参数非法: {level},可选值: {sorted(_VALID_LEVELS)}")
    dimension_ok = dimension in _VALID_DIMENSIONS
    if not dimension_ok:
        raise ValueError(f"dimension 参数非法: {dimension},可选值: {sorted(_VALID_DIMENSIONS)}")
def _weight_score(item: dict) -> float:
    """Return ``item["score"]`` as a float ranking key.

    A missing ``score`` key and a JSON ``null`` score both rank as 0.0, so a
    single null no longer makes ``float(None)`` raise and fail the whole query.
    Other non-numeric values still raise (TypeError/ValueError).
    """
    raw = item.get("score")
    return 0.0 if raw is None else float(raw)


@tool("查询元素或分类权重分排名区间。参数:level(元素/分类)、dimension(实质/形式/意图)、start(起始排名,含,从1开始)、end(结束排名,含,从1开始)。")
def get_weight_score_topn(level: str, dimension: str, start: int = 1, end: int = 10) -> str:
    """Return the weight-score items ranked ``start``..``end`` (1-based, inclusive).

    Items are ranked by ``score``, highest first; equal scores keep their order
    from the data file. A missing or ``null`` score ranks as 0.

    Args:
        level: Query level, "元素" (element) or "分类" (category).
        dimension: Query dimension, "实质" / "形式" / "意图".
        start: First rank to include, counting from 1.
        end: Last rank to include, counting from 1.

    Returns:
        Pretty-printed JSON containing the query parameters, ``total_count``
        (records in the file), ``matched_count`` and ``items`` for the range;
        ``matched_count`` is smaller than ``end - start + 1`` when the range runs
        past the end of the data. Argument-validation and data-loading failures
        are reported as a plain error-message string rather than raised.
    """
    execution_id = TopicBuildAgentContext.get_execution_id()
    params = {
        "execution_id": execution_id,
        "level": level,
        "dimension": dimension,
        "start": start,
        "end": end,
    }
    _log_tool_input("get_weight_score_topn", params)
    try:
        _validate_params(level=level, dimension=dimension)
        if start < 1 or end < 1:
            return _log_tool_output(
                "get_weight_score_topn",
                f"错误: start/end 必须为大于等于 1 的整数,当前值: start={start}, end={end}",
            )
        if start > end:
            return _log_tool_output(
                "get_weight_score_topn",
                f"错误: start 不能大于 end,当前值: start={start}, end={end}",
            )
        data = _load_weight_data(level=level, dimension=dimension)
        # sorted() is stable even with reverse=True, so ties keep file order.
        sorted_data = sorted(data, key=_weight_score, reverse=True)
        # The caller's range is 1-based with an inclusive end; convert to a slice.
        ranged_items = sorted_data[start - 1 : end]
        result = {
            "level": level,
            "dimension": dimension,
            "start": start,
            "end": end,
            "total_count": len(data),
            "matched_count": len(ranged_items),
            "items": ranged_items,
        }
        return _log_tool_output("get_weight_score_topn", json.dumps(result, ensure_ascii=False, indent=2))
    except Exception as e:
        # Tool boundary: any failure is reported back as the returned string.
        return _log_tool_output("get_weight_score_topn", f"查询失败: {e}")
def _match_key(item: dict, key: str) -> str:
    """Return ``item[key]`` stringified and stripped, for exact name matching.

    A missing key and a JSON ``null`` both normalize to "", which never matches
    a query name because query names must be non-blank. (Previously a ``null``
    became the string "None" and could match a query for "None".)
    """
    raw = item.get(key)
    return "" if raw is None else str(raw).strip()


@tool("批量查询指定名称的权重分。参数:level(元素/分类)、dimension(实质/形式/意图)、names(名称列表)。")
def get_weight_score_by_name(level: str, dimension: str, names: list[str]) -> str:
    """Look up the weight-score records for each name in ``names``.

    Names are stripped and compared exactly against the record's ``name`` field
    (level "元素") or ``category`` field (level "分类").

    Args:
        level: Query level, "元素" (element) or "分类" (category).
        dimension: Query dimension, "实质" / "形式" / "意图".
        names: Element or category names to look up; ``results`` in the output
            follows the same order, one entry per name (duplicates included).

    Returns:
        Pretty-printed JSON with ``level``, ``dimension``, ``query_count`` and
        ``results``; each result holds the ``name``, its ``matched_count`` and the
        matching ``items`` in file order. Argument-validation and data-loading
        failures are reported as a plain error-message string rather than raised.
    """
    execution_id = TopicBuildAgentContext.get_execution_id()
    params = {
        "execution_id": execution_id,
        "level": level,
        "dimension": dimension,
        "names": names,
    }
    _log_tool_input("get_weight_score_by_name", params)
    try:
        _validate_params(level=level, dimension=dimension)
        # Check the type before emptiness so inputs such as None or "" report a
        # type error instead of being misreported as an empty list.
        if not isinstance(names, list):
            return _log_tool_output("get_weight_score_by_name", f"错误: names 必须为列表,当前类型: {type(names).__name__}")
        if not names:
            return _log_tool_output("get_weight_score_by_name", "错误: names 不能为空列表")
        stripped: list[str] = []
        for i, n in enumerate(names):
            if n is None or (isinstance(n, str) and not n.strip()):
                return _log_tool_output(
                    "get_weight_score_by_name",
                    f"错误: names[{i}] 不能为空",
                )
            stripped.append(str(n).strip())
        data = _load_weight_data(level=level, dimension=dimension)
        # Elements are matched on "name", categories on "category".
        key = "name" if level == "元素" else "category"
        # Index the records once (O(len(data) + len(names))) instead of
        # rescanning the whole file per name; appending in file order keeps
        # each name's matches in file order.
        index: dict[str, list[dict]] = {}
        for item in data:
            index.setdefault(_match_key(item, key), []).append(item)
        results = []
        for target_name in stripped:
            matched = index.get(target_name, [])
            results.append(
                {
                    "name": target_name,
                    "matched_count": len(matched),
                    "items": matched,
                }
            )
        result = {
            "level": level,
            "dimension": dimension,
            "query_count": len(stripped),
            "results": results,
        }
        return _log_tool_output("get_weight_score_by_name", json.dumps(result, ensure_ascii=False, indent=2))
    except Exception as e:
        # Tool boundary: any failure is reported back as the returned string.
        return _log_tool_output("get_weight_score_by_name", f"查询失败: {e}")
|