"""demand 示例的最小可运行入口。""" import asyncio import copy import importlib import json import os import sys from datetime import datetime from pathlib import Path from typing import Optional from zoneinfo import ZoneInfo from dotenv import load_dotenv from sqlalchemy import desc, or_ from examples.demand.changwen_prepare import changwen_prepare from examples.demand.config import LOG_LEVEL, ENABLED_TOOLS from examples.demand.db_manager import DatabaseManager from examples.demand.models import TopicPatternExecution from examples.demand.piaoquan_prepare import prepare, piaoquan_prepare from examples.demand.demand_agent_context import TopicBuildAgentContext from examples.demand.mysql import mysql_db # Clash Verge TUN 模式兼容:禁止 httpx/urllib 自动检测系统 HTTP 代理 os.environ.setdefault("no_proxy", "*") # 该示例仅使用项目侧能力,禁用框架内置 skills os.environ.setdefault("AGENT_DISABLE_BUILTIN_SKILLS", "1") # 禁用内置工具自动注册,并开启严格工具白名单 os.environ.setdefault("AGENT_DISABLE_BUILTIN_TOOL_REGISTRATION", "1") os.environ.setdefault("AGENT_STRICT_TOOL_SELECTION", "1") # 禁用所有侧分支(压缩/反思) os.environ.setdefault("AGENT_DISABLE_SIDE_BRANCHES", "1") # 添加项目根目录到 Python 路径 sys.path.insert(0, str(Path(__file__).parent.parent.parent)) load_dotenv() from agent.core.runner import AgentRunner from agent.llm import create_openrouter_llm_call from agent.llm.prompts import SimplePrompt from agent.trace import FileSystemTraceStore, Message, Trace from agent.utils import setup_logging from examples.demand.log_capture import build_log, log # 导入项目配置 from examples.demand.config import DEBUG, LOG_FILE, LOG_LEVEL, RUN_CONFIG, TRACE_STORE_PATH CUSTOM_TOOL_MODULES = { # demand 示例:严格按工具名白名单加载对应模块 "think_and_plan": "examples.demand.agent_tools", "get_category_tree": "examples.demand.demand_pattern_tools", "get_frequent_itemsets": "examples.demand.demand_pattern_tools", "get_itemset_detail": "examples.demand.demand_pattern_tools", "get_post_elements": "examples.demand.demand_pattern_tools", "search_elements": "examples.demand.demand_pattern_tools", "get_element_category_chain": "examples.demand.demand_pattern_tools", "get_category_detail": "examples.demand.demand_pattern_tools", "search_categories": "examples.demand.demand_pattern_tools", "get_category_elements": "examples.demand.demand_pattern_tools", "get_category_co_occurrences": "examples.demand.demand_pattern_tools", "get_element_co_occurrences": "examples.demand.demand_pattern_tools", "get_weight_score_topn": "examples.demand.weight_score_query_tools", "get_weight_score_by_name": "examples.demand.weight_score_query_tools", "create_demand_item": "examples.demand.demand_build_agent_tools", "create_demand_items": "examples.demand.demand_build_agent_tools", "write_execution_summary": "examples.demand.demand_build_agent_tools", } def get_execution_id_by_merge_level2(cluster_name: str): """根据二级品类和平台查询最新的 execution_id。""" session = DatabaseManager().get_session() try: start_of_today = datetime.combine(datetime.now().date(), datetime.min.time()) query = session.query(TopicPatternExecution).filter( TopicPatternExecution.cluster_name == cluster_name, TopicPatternExecution.start_time >= start_of_today, ) execution = query.order_by(desc(TopicPatternExecution.id)).first() if not execution: return None return execution.id finally: session.close() def resolve_model(prompt: SimplePrompt) -> str: model_from_prompt = prompt.config.get("model") if model_from_prompt: return model_from_prompt return f"anthropic/{RUN_CONFIG.model}" if "/" not in RUN_CONFIG.model else RUN_CONFIG.model def extract_assistant_text(message: Message) -> str: if message.role != "assistant": return "" content = message.content if isinstance(content, str): return content if isinstance(content, dict): text = content.get("text", "") # 即使本轮包含工具调用,也打印模型给出的文本,便于观察每一步输出 if text: return text return "" def register_selected_tools(tool_names: list[str]) -> None: for tool_name in tool_names: module_path = CUSTOM_TOOL_MODULES.get(tool_name) if not module_path: raise ValueError(f"未配置工具模块映射: {tool_name}") importlib.import_module(module_path) def _join_element_names_to_name(element_names: object) -> str: """把 tool 入参 element_names 转成 demand_content.name(逗号分隔)。""" if element_names is None: return "" if isinstance(element_names, list): parts = [str(x).strip() for x in element_names if x is not None and str(x).strip()] return ",".join(parts) # 兼容异常数据:比如历史版本可能传了字符串 return str(element_names).strip() def _safe_truncate(s: object, max_len: int) -> str: if s is None: return "" s_str = str(s) if max_len and len(s_str) > max_len: return s_str[:max_len] return s_str def _load_name_score_map(execution_id: int) -> dict: """读取 data/{execution_id} 下所有 JSON 的「名字->score」(同名取最高分)。 兼容两类数据结构: - `*_元素.json`:字段 `name` 表示名字 - `*_分类.json`:字段 `category` 表示名字 """ data_dir = Path(__file__).parent / "data" / str(execution_id) if not data_dir.exists(): return {} score_map = {} for json_path in data_dir.glob("*.json"): try: with open(json_path, "r", encoding="utf-8") as f: payload = json.load(f) except Exception: continue if not isinstance(payload, list): continue for item in payload: if not isinstance(item, dict): continue # 元素数据以 name 为主;分类数据以 category 为主。 name = item.get("name") if not isinstance(name, str) or not name: name = item.get("category") score = item.get("score") if isinstance(name, str) and isinstance(score, (int, float)): prev = score_map.get(name) score_f = float(score) if prev is None or score_f > prev: score_map[name] = score_f return score_map def _avg_score_for_joined_name(name: str, score_map: dict) -> float: """按逗号拆分 name,分别取分后求平均。""" parts = [part.strip() for part in str(name).split(",") if part and part.strip()] if not parts: return 0.0 return sum(float(score_map.get(part, 0.0)) for part in parts) / len(parts) def _create_demand_task( execution_id: int, name: Optional[str] = None, platform: Optional[str] = None, ) -> Optional[int]: """创建 demand_task 记录,返回任务ID。""" try: # 数据库字段 demand_task.name: varchar(32) if name is not None: name = str(name)[:32] # 数据库字段 demand_task.platform: varchar(32) if platform is not None: platform = str(platform)[:32] task_id = mysql_db.insert( "demand_task", { "execution_id": execution_id, "name": name, "platform": platform, "status": 0, "log": "", }, ) log(f"[task] 创建 demand_task 成功,task_id={task_id}, execution_id={execution_id}") return task_id except Exception as e: log(f"[task] 创建 demand_task 失败,execution_id={execution_id}, error={e}") return None def _finish_demand_task(task_id: Optional[int], status: int, task_log: str) -> None: """更新 demand_task 状态与日志。""" if not task_id: return try: mysql_db.update( "demand_task", { "status": int(status), "log": task_log or "", }, "id = %s", (task_id,), ) log(f"[task] 更新 demand_task 成功,task_id={task_id}, status={status}") except Exception as e: log(f"[task] 更新 demand_task 失败,task_id={task_id}, status={status}, error={e}") def write_demand_items_to_mysql(execution_id: int, merge_level2: str) -> int: """ 把 result/{execution_id}/execution_id_{execution_id}_demand_items.json 写入 MySQL 表 demand_content """ # create_demand_item(s) 使用 Path.cwd()/result 作为输出目录。 # 为了兼容“从不同目录启动脚本”的情况,这里同时尝试 cwd 和脚本目录两种结果位置。 demand_items_path = ( Path.cwd() / "result" / str(execution_id) / f"execution_id_{execution_id}_demand_items.json" ) if not demand_items_path.exists(): alt_path = ( Path(__file__).parent / "result" / str(execution_id) / f"execution_id_{execution_id}_demand_items.json" ) if alt_path.exists(): demand_items_path = alt_path else: log(f"[mysql] 未找到需求 JSON:{demand_items_path}(也未找到 {alt_path}),跳过写入") return 0 try: with open(demand_items_path, "r", encoding="utf-8") as f: loaded = json.load(f) except Exception as e: log(f"[mysql] 读取需求 JSON 失败:{demand_items_path},error={e}") return 0 items = loaded["items"] if isinstance(loaded, dict) and isinstance(loaded.get("items"), list) else loaded if not isinstance(items, list): log(f"[mysql] 需求 JSON 非数组,跳过写入:type={type(items)}") return 0 dt_value = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y%m%d") score_map = _load_name_score_map(execution_id) rows: list[dict] = [] for di in items: if not isinstance(di, dict): continue name = _join_element_names_to_name(di.get("element_names")) if not name: continue score = _avg_score_for_joined_name(name, score_map) reason = di.get("reason") desc_value = di.get("desc") suggestion = desc_value # 兼容旧字段:同时保留 ext_data(reason/desc)JSON,便于旧版消费逻辑迁移期继续使用。 ext_data = {"reason": reason, "desc": desc_value} rows.append( { "merge_leve2": _safe_truncate(merge_level2, 32), "name": _safe_truncate(name, 64), "reason": reason, "suggestion": suggestion, "score": float(score), "ext_data": json.dumps(ext_data, ensure_ascii=False), "dt": dt_value, } ) if not rows: log("[mysql] 生成行为空,跳过写入") return 0 affected = mysql_db.insert_many("demand_content", rows) log(f"[mysql] 写入 demand_content 完成,rows={len(rows)}, affected={affected}") return len(rows) async def run_once(execution_id, merge_level2, task_id: Optional[int] = None) -> str: task_log_text = "" task_status = 0 TopicBuildAgentContext.set_execution_id(execution_id) base_dir = Path(__file__).parent output_dir = base_dir / "output" output_dir.mkdir(exist_ok=True) setup_logging(level=LOG_LEVEL, file=LOG_FILE) register_selected_tools(ENABLED_TOOLS) prompt = SimplePrompt(base_dir / "demand.md") model = resolve_model(prompt) run_config = copy.deepcopy(RUN_CONFIG) run_config.temperature = float(prompt.config.get("temperature", run_config.temperature)) run_config.max_iterations = int(prompt.config.get("max_iterations", run_config.max_iterations)) run_config.tools = ENABLED_TOOLS.copy() # 禁用反思/总结经验相关流程(避免进入 reflection 侧分支) run_config.enable_research_flow = False run_config.goal_compression = "none" run_config.force_side_branch = None run_config.knowledge.enable_extraction = False run_config.knowledge.enable_completion_extraction = False run_config.knowledge.enable_injection = False run_config.trace_id = None initial_messages = prompt.build_messages(merge_level2=merge_level2) store = FileSystemTraceStore(base_path=TRACE_STORE_PATH) runner = AgentRunner( trace_store=store, llm_call=create_openrouter_llm_call(model=model), skills_dir=None, debug=DEBUG, ) final_text = "" total_tokens = 0 total_cost = 0.0 has_completed_trace_cost = False log_file_path = output_dir / f"{execution_id}" / f"run_log_{execution_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt" log_file_path.parent.mkdir(parents=True, exist_ok=True) try: with build_log(execution_id) as log_buffer: async for item in runner.run(messages=initial_messages, config=run_config): if isinstance(item, Trace): if getattr(item, "status", None) == "completed": total_tokens = int(getattr(item, "total_tokens", 0) or 0) total_cost = float(getattr(item, "total_cost", 0.0) or 0.0) has_completed_trace_cost = True continue elif isinstance(item, Message): text = extract_assistant_text(item) if text: final_text = text log(f"[assistant] {text}") if not has_completed_trace_cost: total_tokens += int(getattr(item, "total_tokens", 0) or 0) total_cost += float(getattr(item, "cost", 0.0) or 0.0) if final_text: output_file = output_dir / f"{execution_id}" / "result.txt" output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, "w", encoding="utf-8") as f: f.write(final_text) log(f"[cost] total_tokens={total_tokens}, total_cost=${total_cost:.6f}") # agent 执行完成后:把本地 result JSON 写入 MySQL 表 demand_content # element_names -> name(逗号分隔);reason -> demand_content.reason;desc -> demand_content.suggestion;dt -> demand_content.dt try: write_demand_items_to_mysql(execution_id=execution_id, merge_level2=merge_level2) except Exception as e: log(f"[mysql] 写入 demand_content 异常:{e}") task_log_text = log_buffer.getvalue() task_status = 1 except Exception as e: if not task_log_text: # 如果异常发生在 build_log 内部,尽量回收已产生的日志 try: existing = locals().get("log_buffer") if existing is not None: task_log_text = existing.getvalue() # type: ignore[attr-defined] except Exception: pass if not task_log_text: task_log_text = f"[run] 执行异常: {e}" task_status = 2 raise finally: if task_log_text: try: with open(log_file_path, "w", encoding="utf-8") as f: f.write(task_log_text) except Exception: # 兜底:即使写文件失败,也要确保 MySQL 状态被更新 pass _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text) return final_text async def main( cluster_name: str, platform_type: str, execution_id: Optional[int] = None, task_id: Optional[int] = None, ) -> dict: if execution_id is None: if platform_type == "piaoquan": execution_id = piaoquan_prepare(cluster_name) elif platform_type == "changwen": execution_id = changwen_prepare(cluster_name) else: execution_id = None if not execution_id: return {"execution_id": None, "final_text": ""} final_text = await run_once(execution_id, cluster_name, task_id=task_id) return {"execution_id": execution_id, "final_text": final_text} if __name__ == "__main__": # asyncio.run(run_once(8, '贪污腐败')) write_demand_items_to_mysql(execution_id=8, merge_level2='贪污腐败')