"""piaoquan_demand 示例的最小可运行入口。""" import asyncio import copy import importlib import json import os import sys from datetime import datetime from pathlib import Path from typing import Optional from dotenv import load_dotenv from sqlalchemy import desc, or_ from examples.piaoquan_demand.config import LOG_LEVEL, ENABLED_TOOLS from examples.piaoquan_demand.db_manager import DatabaseManager from examples.piaoquan_demand.models import TopicPatternExecution from examples.piaoquan_demand.pattern_toos.pattern_service import run_mining from examples.piaoquan_demand.prepare import prepare from examples.piaoquan_demand.topic_build_agent_context import TopicBuildAgentContext from examples.piaoquan_demand.mysql import mysql_db # Clash Verge TUN 模式兼容:禁止 httpx/urllib 自动检测系统 HTTP 代理 os.environ.setdefault("no_proxy", "*") # 该示例仅使用项目侧能力,禁用框架内置 skills os.environ.setdefault("AGENT_DISABLE_BUILTIN_SKILLS", "1") # 禁用内置工具自动注册,并开启严格工具白名单 os.environ.setdefault("AGENT_DISABLE_BUILTIN_TOOL_REGISTRATION", "1") os.environ.setdefault("AGENT_STRICT_TOOL_SELECTION", "1") # 禁用所有侧分支(压缩/反思) os.environ.setdefault("AGENT_DISABLE_SIDE_BRANCHES", "1") # 添加项目根目录到 Python 路径 sys.path.insert(0, str(Path(__file__).parent.parent.parent)) load_dotenv() from agent.core.runner import AgentRunner from agent.llm import create_openrouter_llm_call from agent.llm.prompts import SimplePrompt from agent.trace import FileSystemTraceStore, Message, Trace from agent.utils import setup_logging from log_capture import build_log, log # 导入项目配置 from examples.piaoquan_demand.config import DEBUG, LOG_FILE, LOG_LEVEL, RUN_CONFIG, TRACE_STORE_PATH CUSTOM_TOOL_MODULES = { # piaoquan_demand 示例:严格按工具名白名单加载对应模块 "think_and_plan": "examples.piaoquan_demand.agent_tools", "get_category_tree": "examples.piaoquan_demand.topic_build_pattern_tools", "get_frequent_itemsets": "examples.piaoquan_demand.topic_build_pattern_tools", "get_itemset_detail": "examples.piaoquan_demand.topic_build_pattern_tools", "get_post_elements": "examples.piaoquan_demand.topic_build_pattern_tools", "search_elements": "examples.piaoquan_demand.topic_build_pattern_tools", "get_element_category_chain": "examples.piaoquan_demand.topic_build_pattern_tools", "get_category_detail": "examples.piaoquan_demand.topic_build_pattern_tools", "search_categories": "examples.piaoquan_demand.topic_build_pattern_tools", "get_category_elements": "examples.piaoquan_demand.topic_build_pattern_tools", "get_category_co_occurrences": "examples.piaoquan_demand.topic_build_pattern_tools", "get_element_co_occurrences": "examples.piaoquan_demand.topic_build_pattern_tools", "get_weight_score_topn": "examples.piaoquan_demand.weight_score_query_tools", "get_weight_score_by_name": "examples.piaoquan_demand.weight_score_query_tools", "create_demand_item": "examples.piaoquan_demand.demand_build_agent_tools", "create_demand_items": "examples.piaoquan_demand.demand_build_agent_tools", "write_execution_summary": "examples.piaoquan_demand.demand_build_agent_tools", } def get_execution_id_by_merge_level2(cluster_name: str): """根据二级品类和平台查询最新的 execution_id。""" session = DatabaseManager().get_session() try: start_of_today = datetime.combine(datetime.now().date(), datetime.min.time()) query = session.query(TopicPatternExecution).filter( TopicPatternExecution.cluster_name == cluster_name, TopicPatternExecution.start_time >= start_of_today, ) execution = query.order_by(desc(TopicPatternExecution.id)).first() if not execution: return None return execution.id finally: session.close() def resolve_model(prompt: SimplePrompt) -> str: model_from_prompt = prompt.config.get("model") if model_from_prompt: return model_from_prompt return f"anthropic/{RUN_CONFIG.model}" if "/" not in RUN_CONFIG.model else RUN_CONFIG.model def extract_assistant_text(message: Message) -> str: if message.role != "assistant": return "" content = message.content if isinstance(content, str): return content if isinstance(content, dict): text = content.get("text", "") # 即使本轮包含工具调用,也打印模型给出的文本,便于观察每一步输出 if text: return text return "" def register_selected_tools(tool_names: list[str]) -> None: for tool_name in tool_names: module_path = CUSTOM_TOOL_MODULES.get(tool_name) if not module_path: raise ValueError(f"未配置工具模块映射: {tool_name}") importlib.import_module(module_path) def _join_element_names_to_name(element_names: object) -> str: """把 tool 入参 element_names 转成 demand_content.name(逗号分隔)。""" if element_names is None: return "" if isinstance(element_names, list): parts = [str(x).strip() for x in element_names if x is not None and str(x).strip()] return ",".join(parts) # 兼容异常数据:比如历史版本可能传了字符串 return str(element_names).strip() def _safe_truncate(s: object, max_len: int) -> str: if s is None: return "" s_str = str(s) if max_len and len(s_str) > max_len: return s_str[:max_len] return s_str def _load_name_score_map(execution_id: int) -> dict: """读取 data/{execution_id} 下所有 JSON 的 name->score(同名取最高分)。""" data_dir = Path(__file__).parent / "data" / str(execution_id) if not data_dir.exists(): return {} score_map = {} for json_path in data_dir.glob("*.json"): try: with open(json_path, "r", encoding="utf-8") as f: payload = json.load(f) except Exception: continue if not isinstance(payload, list): continue for item in payload: if not isinstance(item, dict): continue name = item.get("name") score = item.get("score") if isinstance(name, str) and isinstance(score, (int, float)): prev = score_map.get(name) score_f = float(score) if prev is None or score_f > prev: score_map[name] = score_f return score_map def _avg_score_for_joined_name(name: str, score_map: dict) -> float: """按逗号拆分 name,分别取分后求平均。""" parts = [part.strip() for part in str(name).split(",") if part and part.strip()] if not parts: return 0.0 return sum(float(score_map.get(part, 0.0)) for part in parts) / len(parts) def _create_demand_task(execution_id: int) -> Optional[int]: """创建 demand_task 记录,返回任务ID。""" try: task_id = mysql_db.insert( "demand_task", { "execution_id": execution_id, "status": 0, "log": "", }, ) log(f"[task] 创建 demand_task 成功,task_id={task_id}, execution_id={execution_id}") return task_id except Exception as e: log(f"[task] 创建 demand_task 失败,execution_id={execution_id}, error={e}") return None def _finish_demand_task(task_id: Optional[int], status: int, task_log: str) -> None: """更新 demand_task 状态与日志。""" if not task_id: return try: mysql_db.update( "demand_task", { "status": int(status), "log": task_log or "", }, "id = %s", (task_id,), ) log(f"[task] 更新 demand_task 成功,task_id={task_id}, status={status}") except Exception as e: log(f"[task] 更新 demand_task 失败,task_id={task_id}, status={status}, error={e}") def write_demand_items_to_mysql(execution_id: int, merge_level2: str) -> int: """ 把 result/{execution_id}/execution_id_{execution_id}_demand_items.json 写入 MySQL 表 demand_content """ # create_demand_item(s) 使用 Path.cwd()/result 作为输出目录。 # 为了兼容“从不同目录启动脚本”的情况,这里同时尝试 cwd 和脚本目录两种结果位置。 demand_items_path = ( Path.cwd() / "result" / str(execution_id) / f"execution_id_{execution_id}_demand_items.json" ) if not demand_items_path.exists(): alt_path = ( Path(__file__).parent / "result" / str(execution_id) / f"execution_id_{execution_id}_demand_items.json" ) if alt_path.exists(): demand_items_path = alt_path else: log(f"[mysql] 未找到需求 JSON:{demand_items_path}(也未找到 {alt_path}),跳过写入") return 0 try: with open(demand_items_path, "r", encoding="utf-8") as f: loaded = json.load(f) except Exception as e: log(f"[mysql] 读取需求 JSON 失败:{demand_items_path},error={e}") return 0 items = loaded["items"] if isinstance(loaded, dict) and isinstance(loaded.get("items"), list) else loaded if not isinstance(items, list): log(f"[mysql] 需求 JSON 非数组,跳过写入:type={type(items)}") return 0 score_map = _load_name_score_map(execution_id) rows: list[dict] = [] for di in items: if not isinstance(di, dict): continue name = _join_element_names_to_name(di.get("element_names")) if not name: continue score = _avg_score_for_joined_name(name, score_map) reason = di.get("reason") desc_value = di.get("desc") ext_data = {"reason": reason, "desc": desc_value} rows.append( { "merge_leve2": _safe_truncate(merge_level2, 32), "name": _safe_truncate(name, 64), "score": float(score), "ext_data": json.dumps(ext_data, ensure_ascii=False), } ) if not rows: log("[mysql] 生成行为空,跳过写入") return 0 affected = mysql_db.insert_many("demand_content", rows) log(f"[mysql] 写入 demand_content 完成,rows={len(rows)}, affected={affected}") return len(rows) async def run_once(execution_id, merge_level2) -> str: task_id = _create_demand_task(execution_id=execution_id) task_status = 2 task_log_text = "" TopicBuildAgentContext.set_execution_id(execution_id) prepare(execution_id) base_dir = Path(__file__).parent output_dir = base_dir / "output" output_dir.mkdir(exist_ok=True) setup_logging(level=LOG_LEVEL, file=LOG_FILE) register_selected_tools(ENABLED_TOOLS) prompt = SimplePrompt(base_dir / "demand.md") model = resolve_model(prompt) run_config = copy.deepcopy(RUN_CONFIG) run_config.temperature = float(prompt.config.get("temperature", run_config.temperature)) run_config.max_iterations = int(prompt.config.get("max_iterations", run_config.max_iterations)) run_config.tools = ENABLED_TOOLS.copy() # 禁用反思/总结经验相关流程(避免进入 reflection 侧分支) run_config.enable_research_flow = False run_config.goal_compression = "none" run_config.force_side_branch = None run_config.knowledge.enable_extraction = False run_config.knowledge.enable_completion_extraction = False run_config.knowledge.enable_injection = False run_config.trace_id = None initial_messages = prompt.build_messages(merge_level2=merge_level2) store = FileSystemTraceStore(base_path=TRACE_STORE_PATH) runner = AgentRunner( trace_store=store, llm_call=create_openrouter_llm_call(model=model), skills_dir=None, debug=DEBUG, ) final_text = "" total_tokens = 0 total_cost = 0.0 has_completed_trace_cost = False log_file_path = output_dir / f"{execution_id}" / f"run_log_{execution_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt" log_file_path.parent.mkdir(parents=True, exist_ok=True) try: with build_log(execution_id) as log_buffer: async for item in runner.run(messages=initial_messages, config=run_config): if isinstance(item, Trace): if getattr(item, "status", None) == "completed": total_tokens = int(getattr(item, "total_tokens", 0) or 0) total_cost = float(getattr(item, "total_cost", 0.0) or 0.0) has_completed_trace_cost = True continue elif isinstance(item, Message): text = extract_assistant_text(item) if text: final_text = text log(f"[assistant] {text}") if not has_completed_trace_cost: total_tokens += int(getattr(item, "total_tokens", 0) or 0) total_cost += float(getattr(item, "cost", 0.0) or 0.0) if final_text: output_file = output_dir / f"{execution_id}" / "result.txt" output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, "w", encoding="utf-8") as f: f.write(final_text) log(f"[cost] total_tokens={total_tokens}, total_cost=${total_cost:.6f}") # agent 执行完成后:把本地 result JSON 写入 MySQL 表 demand_content # element_names -> name(逗号分隔);reason/desc -> ext_data JSON;merge_leve2 -> demand_content.merge_leve2 try: write_demand_items_to_mysql(execution_id=execution_id, merge_level2=merge_level2) except Exception as e: log(f"[mysql] 写入 demand_content 异常:{e}") task_log_text = log_buffer.getvalue() with open(log_file_path, "w", encoding="utf-8") as f: f.write(task_log_text) task_status = 1 except Exception as e: if not task_log_text: task_log_text = f"[run] 执行异常: {e}" task_status = 2 raise finally: _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text) return final_text async def main() -> None: cluster_name = '历史名人' execution_id = get_execution_id_by_merge_level2(cluster_name=cluster_name) if execution_id is None: execution_id = run_mining(cluster_name=cluster_name, merge_leve2=cluster_name) print(execution_id) # await run_once(execution_id, cluster_name) if __name__ == "__main__": asyncio.run(main())