#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Weight-score query tools.

Reads weight-score JSON files from the ``data/{execution_id}`` directory next to
this module (examples/piaoquan_demand/data/{execution_id}). Supports querying a
rank range (TopN) of elements/categories by score, and batch lookup of weight
scores by a list of names.
"""

import json
from pathlib import Path

from agent import tool
from examples.piaoquan_demand.topic_build_agent_context import TopicBuildAgentContext
from examples.piaoquan_demand.topic_build_pattern_tools import _log_tool_input, _log_tool_output

_VALID_LEVELS = {"元素", "分类"}
_VALID_DIMENSIONS = {"实质", "形式", "意图"}

# Record field that holds the name to match against, per query level
# ("元素" = element -> "name", "分类" = category -> "category").
_NAME_KEY_BY_LEVEL = {"元素": "name", "分类": "category"}


def _get_weight_file_path(level: str, dimension: str) -> Path:
    """Build the path of the weight-score JSON file for ``level``/``dimension``.

    The path is ``<module dir>/data/<execution_id>/<dimension>_<level>.json``.

    Raises:
        ValueError: If no execution_id is set on TopicBuildAgentContext.
    """
    execution_id = TopicBuildAgentContext.get_execution_id()
    if execution_id is None:
        raise ValueError("未设置 execution_id,请先在 TopicBuildAgentContext 中设置")
    filename = f"{dimension}_{level}.json"
    base_dir = Path(__file__).parent / "data" / str(execution_id)
    return base_dir / filename


def _load_weight_data(level: str, dimension: str) -> list[dict]:
    """Load and return the list of weight records for ``level``/``dimension``.

    Raises:
        FileNotFoundError: If the data file does not exist.
        ValueError: If the JSON top-level value is not a list.
    """
    file_path = _get_weight_file_path(level=level, dimension=dimension)
    if not file_path.exists():
        raise FileNotFoundError(f"权重数据文件不存在: {file_path}")
    with file_path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(f"权重数据格式错误,期望 list,实际为: {type(data).__name__}")
    return data


def _validate_params(level: str, dimension: str) -> None:
    """Validate the ``level`` and ``dimension`` arguments shared by all tools.

    Raises:
        ValueError: If either value is not one of the allowed choices.
    """
    if level not in _VALID_LEVELS:
        raise ValueError(f"level 参数非法: {level},可选值: {sorted(_VALID_LEVELS)}")
    if dimension not in _VALID_DIMENSIONS:
        raise ValueError(f"dimension 参数非法: {dimension},可选值: {sorted(_VALID_DIMENSIONS)}")


def _is_int(value: object) -> bool:
    """Return True for real integers; bools are excluded even though they subclass int."""
    return isinstance(value, int) and not isinstance(value, bool)


@tool("查询元素或分类权重分排名区间。参数:level(元素/分类)、dimension(实质/形式/意图)、start(起始排名,含,从1开始)、end(结束排名,含,从1开始)。")
def get_weight_score_topn(level: str, dimension: str, start: int = 1, end: int = 10) -> str:
    """Query a rank range of element/category weight scores.

    Records are ranked by ``score`` in descending order (a missing ``score``
    counts as 0), and the 1-based, inclusive rank range [start, end] is returned.

    Args:
        level: Query level, "元素" (element) or "分类" (category).
        dimension: Query dimension, "实质" / "形式" / "意图".
        start: First rank to return (inclusive), starting from 1.
        end: Last rank to return (inclusive), starting from 1.

    Returns:
        Text passed through ``_log_tool_output``. On success this is a JSON
        string with the query parameters, the total record count and the
        records in the range. Otherwise it is an error message.
    """
    tool_name = "get_weight_score_topn"
    params = {
        "execution_id": TopicBuildAgentContext.get_execution_id(),
        "level": level,
        "dimension": dimension,
        "start": start,
        "end": end,
    }
    _log_tool_input(tool_name, params)
    try:
        _validate_params(level=level, dimension=dimension)
        # Check the type explicitly: the message promises integers, and a float or
        # str would otherwise fail later with an opaque comparison/slice TypeError.
        if not (_is_int(start) and _is_int(end)) or start < 1 or end < 1:
            return _log_tool_output(
                tool_name,
                f"错误: start/end 必须为大于等于 1 的整数,当前值: start={start}, end={end}",
            )
        if start > end:
            return _log_tool_output(
                tool_name,
                f"错误: start 不能大于 end,当前值: start={start}, end={end}",
            )

        data = _load_weight_data(level=level, dimension=dimension)
        sorted_data = sorted(data, key=lambda x: float(x.get("score", 0)), reverse=True)
        # Ranks are 1-based with an inclusive end; convert to a Python slice.
        ranged_items = sorted_data[start - 1 : end]

        result = {
            "level": level,
            "dimension": dimension,
            "start": start,
            "end": end,
            "total_count": len(data),
            "matched_count": len(ranged_items),
            "items": ranged_items,
        }
        return _log_tool_output(tool_name, json.dumps(result, ensure_ascii=False, indent=2))
    except Exception as e:
        # Tool boundary: report any failure back to the caller as text.
        return _log_tool_output(tool_name, f"查询失败: {e}")


@tool("批量查询指定名称的权重分。参数:level(元素/分类)、dimension(实质/形式/意图)、names(名称列表)。")
def get_weight_score_by_name(level: str, dimension: str, names: list[str]) -> str:
    """Look up weight scores for a batch of names.

    Names are compared after stripping surrounding whitespace. Element queries
    match the ``name`` field and category queries match the ``category`` field.

    Args:
        level: Query level, "元素" (element) or "分类" (category).
        dimension: Query dimension, "实质" / "形式" / "意图".
        names: Element or category names to look up. ``results`` in the output
            has one entry per name, in the same order.

    Returns:
        Text passed through ``_log_tool_output``. On success this is a JSON
        string with ``matched_count`` and ``items`` for each name. Otherwise it
        is an error message.
    """
    tool_name = "get_weight_score_by_name"
    params = {
        "execution_id": TopicBuildAgentContext.get_execution_id(),
        "level": level,
        "dimension": dimension,
        "names": names,
    }
    _log_tool_input(tool_name, params)
    try:
        _validate_params(level=level, dimension=dimension)
        # Check the type before emptiness, so that e.g. None or "" are reported
        # as the wrong type rather than as an "empty list".
        if not isinstance(names, list):
            return _log_tool_output(tool_name, f"错误: names 必须为列表,当前类型: {type(names).__name__}")
        if not names:
            return _log_tool_output(tool_name, "错误: names 不能为空列表")

        stripped: list[str] = []
        for i, n in enumerate(names):
            if n is None or (isinstance(n, str) and not n.strip()):
                return _log_tool_output(tool_name, f"错误: names[{i}] 不能为空")
            stripped.append(str(n).strip())

        data = _load_weight_data(level=level, dimension=dimension)
        key = _NAME_KEY_BY_LEVEL[level]

        # Index the records by normalized name once, instead of rescanning the
        # whole dataset per queried name. Record order within a name is kept.
        index: dict[str, list[dict]] = {}
        for item in data:
            index.setdefault(str(item.get(key, "")).strip(), []).append(item)

        results = []
        for target_name in stripped:
            matched = index.get(target_name, [])
            results.append(
                {
                    "name": target_name,
                    "matched_count": len(matched),
                    "items": matched,
                }
            )

        result = {
            "level": level,
            "dimension": dimension,
            "query_count": len(stripped),
            "results": results,
        }
        return _log_tool_output(tool_name, json.dumps(result, ensure_ascii=False, indent=2))
    except Exception as e:
        # Tool boundary: report any failure back to the caller as text.
        return _log_tool_output(tool_name, f"查询失败: {e}")