| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
"""Minimal runnable entry point for the piaoquan_demand example."""
import asyncio
import copy
import importlib
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv
from sqlalchemy import desc, or_

# NOTE(review): these examples.* imports execute BEFORE the sys.path.insert
# below, so they only resolve if the process is already launched with the
# project root on sys.path — confirm the intended invocation.
from examples.piaoquan_demand.config import LOG_LEVEL, ENABLED_TOOLS
from examples.piaoquan_demand.db_manager import DatabaseManager
from examples.piaoquan_demand.models import TopicPatternExecution
from examples.piaoquan_demand.pattern_toos.pattern_service import run_mining
from examples.piaoquan_demand.prepare import prepare
from examples.piaoquan_demand.topic_build_agent_context import TopicBuildAgentContext
from examples.piaoquan_demand.mysql import mysql_db

# Clash Verge TUN mode compatibility: stop httpx/urllib from auto-detecting a system HTTP proxy.
os.environ.setdefault("no_proxy", "*")
# This example only uses project-side capabilities; disable the framework's built-in skills.
os.environ.setdefault("AGENT_DISABLE_BUILTIN_SKILLS", "1")
# Disable auto-registration of built-in tools and enforce a strict tool whitelist.
os.environ.setdefault("AGENT_DISABLE_BUILTIN_TOOL_REGISTRATION", "1")
os.environ.setdefault("AGENT_STRICT_TOOL_SELECTION", "1")
# Disable all side branches (compression / reflection).
os.environ.setdefault("AGENT_DISABLE_SIDE_BRANCHES", "1")

# Add the project root (three levels above this file) to the Python path.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
load_dotenv()

from agent.core.runner import AgentRunner
from agent.llm import create_openrouter_llm_call
from agent.llm.prompts import SimplePrompt
from agent.trace import FileSystemTraceStore, Message, Trace
from agent.utils import setup_logging
from log_capture import build_log, log

# Import project configuration.
# NOTE(review): LOG_LEVEL is imported a second time here (also above) — harmless but redundant.
from examples.piaoquan_demand.config import DEBUG, LOG_FILE, LOG_LEVEL, RUN_CONFIG, TRACE_STORE_PATH
# piaoquan_demand example: strict whitelist mapping tool name -> dotted path of
# the module that registers it; only whitelisted tools are ever imported.
CUSTOM_TOOL_MODULES = {
    name: f"examples.piaoquan_demand.{module}"
    for module, names in (
        ("agent_tools", ("think_and_plan",)),
        (
            "topic_build_pattern_tools",
            (
                "get_category_tree",
                "get_frequent_itemsets",
                "get_itemset_detail",
                "get_post_elements",
                "search_elements",
                "get_element_category_chain",
                "get_category_detail",
                "search_categories",
                "get_category_elements",
                "get_category_co_occurrences",
                "get_element_co_occurrences",
            ),
        ),
        ("weight_score_query_tools", ("get_weight_score_topn", "get_weight_score_by_name")),
        (
            "demand_build_agent_tools",
            ("create_demand_item", "create_demand_items", "write_execution_summary"),
        ),
    )
    for name in names
}
def get_execution_id_by_merge_level2(cluster_name: str):
    """Return the id of today's latest TopicPatternExecution for *cluster_name*.

    Filters executions started since midnight today and picks the one with the
    highest id. Returns None when no matching execution exists.
    """
    db_session = DatabaseManager().get_session()
    try:
        today_start = datetime.combine(datetime.now().date(), datetime.min.time())
        latest = (
            db_session.query(TopicPatternExecution)
            .filter(
                TopicPatternExecution.cluster_name == cluster_name,
                TopicPatternExecution.start_time >= today_start,
            )
            .order_by(desc(TopicPatternExecution.id))
            .first()
        )
        return latest.id if latest else None
    finally:
        db_session.close()
def resolve_model(prompt: SimplePrompt) -> str:
    """Pick the model id: the prompt's config override wins, else RUN_CONFIG.model.

    A bare model name without a provider prefix defaults to the anthropic/ provider.
    """
    configured = prompt.config.get("model")
    if configured:
        return configured
    fallback = RUN_CONFIG.model
    return fallback if "/" in fallback else f"anthropic/{fallback}"
def extract_assistant_text(message: Message) -> str:
    """Return the assistant-authored text of *message*, or "" when there is none."""
    if message.role != "assistant":
        return ""
    body = message.content
    if isinstance(body, str):
        return body
    # Even when this turn contains tool calls, surface any model-provided text
    # so each step's output can be observed.
    if isinstance(body, dict):
        text = body.get("text", "")
        return text if text else ""
    return ""
def register_selected_tools(tool_names: list[str]) -> None:
    """Import the module backing each whitelisted tool name.

    Raises:
        ValueError: when a tool name has no entry in CUSTOM_TOOL_MODULES.
    """
    for name in tool_names:
        module_path = CUSTOM_TOOL_MODULES.get(name)
        if not module_path:
            raise ValueError(f"未配置工具模块映射: {name}")
        importlib.import_module(module_path)
- def _join_element_names_to_name(element_names: object) -> str:
- """把 tool 入参 element_names 转成 demand_content.name(逗号分隔)。"""
- if element_names is None:
- return ""
- if isinstance(element_names, list):
- parts = [str(x).strip() for x in element_names if x is not None and str(x).strip()]
- return ",".join(parts)
- # 兼容异常数据:比如历史版本可能传了字符串
- return str(element_names).strip()
- def _safe_truncate(s: object, max_len: int) -> str:
- if s is None:
- return ""
- s_str = str(s)
- if max_len and len(s_str) > max_len:
- return s_str[:max_len]
- return s_str
- def _load_name_score_map(execution_id: int) -> dict:
- """读取 data/{execution_id} 下所有 JSON 的 name->score(同名取最高分)。"""
- data_dir = Path(__file__).parent / "data" / str(execution_id)
- if not data_dir.exists():
- return {}
- score_map = {}
- for json_path in data_dir.glob("*.json"):
- try:
- with open(json_path, "r", encoding="utf-8") as f:
- payload = json.load(f)
- except Exception:
- continue
- if not isinstance(payload, list):
- continue
- for item in payload:
- if not isinstance(item, dict):
- continue
- name = item.get("name")
- score = item.get("score")
- if isinstance(name, str) and isinstance(score, (int, float)):
- prev = score_map.get(name)
- score_f = float(score)
- if prev is None or score_f > prev:
- score_map[name] = score_f
- return score_map
- def _avg_score_for_joined_name(name: str, score_map: dict) -> float:
- """按逗号拆分 name,分别取分后求平均。"""
- parts = [part.strip() for part in str(name).split(",") if part and part.strip()]
- if not parts:
- return 0.0
- return sum(float(score_map.get(part, 0.0)) for part in parts) / len(parts)
def _create_demand_task(execution_id: int) -> Optional[int]:
    """Insert a demand_task row for *execution_id* and return its id.

    Any failure (including a failing success-log call) is swallowed and
    reported via log(); None is returned in that case.
    """
    row = {"execution_id": execution_id, "status": 0, "log": ""}
    try:
        task_id = mysql_db.insert("demand_task", row)
        log(f"[task] 创建 demand_task 成功,task_id={task_id}, execution_id={execution_id}")
        return task_id
    except Exception as e:
        log(f"[task] 创建 demand_task 失败,execution_id={execution_id}, error={e}")
        return None
- def _finish_demand_task(task_id: Optional[int], status: int, task_log: str) -> None:
- """更新 demand_task 状态与日志。"""
- if not task_id:
- return
- try:
- mysql_db.update(
- "demand_task",
- {
- "status": int(status),
- "log": task_log or "",
- },
- "id = %s",
- (task_id,),
- )
- log(f"[task] 更新 demand_task 成功,task_id={task_id}, status={status}")
- except Exception as e:
- log(f"[task] 更新 demand_task 失败,task_id={task_id}, status={status}, error={e}")
def write_demand_items_to_mysql(execution_id: int, merge_level2: str) -> int:
    """Load result/{execution_id}/execution_id_{execution_id}_demand_items.json
    and insert its items into the MySQL table demand_content.

    Returns the number of rows written (0 when nothing was inserted).
    """
    # create_demand_item(s) writes under Path.cwd()/result. To tolerate the
    # script being launched from a different working directory, probe both the
    # cwd and this file's directory for the result JSON.
    rel_parts = ("result", str(execution_id), f"execution_id_{execution_id}_demand_items.json")
    demand_items_path = Path.cwd().joinpath(*rel_parts)
    if not demand_items_path.exists():
        alt_path = Path(__file__).parent.joinpath(*rel_parts)
        if not alt_path.exists():
            log(f"[mysql] 未找到需求 JSON:{demand_items_path}(也未找到 {alt_path}),跳过写入")
            return 0
        demand_items_path = alt_path
    try:
        loaded = json.loads(demand_items_path.read_text(encoding="utf-8"))
    except Exception as e:
        log(f"[mysql] 读取需求 JSON 失败:{demand_items_path},error={e}")
        return 0
    # Accept either {"items": [...]} or a bare list.
    if isinstance(loaded, dict) and isinstance(loaded.get("items"), list):
        items = loaded["items"]
    else:
        items = loaded
    if not isinstance(items, list):
        log(f"[mysql] 需求 JSON 非数组,跳过写入:type={type(items)}")
        return 0
    score_map = _load_name_score_map(execution_id)
    rows: list[dict] = []
    for entry in items:
        if not isinstance(entry, dict):
            continue
        name = _join_element_names_to_name(entry.get("element_names"))
        if not name:
            continue
        ext_data = {"reason": entry.get("reason"), "desc": entry.get("desc")}
        rows.append(
            {
                # "merge_leve2" is the literal column name expected by demand_content.
                "merge_leve2": _safe_truncate(merge_level2, 32),
                "name": _safe_truncate(name, 64),
                "score": float(_avg_score_for_joined_name(name, score_map)),
                "ext_data": json.dumps(ext_data, ensure_ascii=False),
            }
        )
    if not rows:
        log("[mysql] 生成行为空,跳过写入")
        return 0
    affected = mysql_db.insert_many("demand_content", rows)
    log(f"[mysql] 写入 demand_content 完成,rows={len(rows)}, affected={affected}")
    return len(rows)
async def run_once(execution_id, merge_level2) -> str:
    """Run the demand-building agent once for *execution_id* and persist results.

    Creates a demand_task row up front, streams the agent run while capturing
    logs, writes the final assistant text and run log under output/{execution_id},
    pushes generated demand items into MySQL, and finally stamps the
    demand_task row with status 1 (success) or 2 (failure).

    Returns the last assistant text produced by the run ("" if none).
    """
    task_id = _create_demand_task(execution_id=execution_id)
    # Default to status 2 (failure) until the run completes successfully.
    task_status = 2
    task_log_text = ""
    TopicBuildAgentContext.set_execution_id(execution_id)
    prepare(execution_id)
    base_dir = Path(__file__).parent
    output_dir = base_dir / "output"
    output_dir.mkdir(exist_ok=True)
    setup_logging(level=LOG_LEVEL, file=LOG_FILE)
    register_selected_tools(ENABLED_TOOLS)
    prompt = SimplePrompt(base_dir / "demand.md")
    model = resolve_model(prompt)
    # Work on a deep copy so per-run overrides never mutate the shared RUN_CONFIG.
    run_config = copy.deepcopy(RUN_CONFIG)
    run_config.temperature = float(prompt.config.get("temperature", run_config.temperature))
    run_config.max_iterations = int(prompt.config.get("max_iterations", run_config.max_iterations))
    run_config.tools = ENABLED_TOOLS.copy()
    # Disable reflection / experience-summary flows (avoid entering the reflection side branch).
    run_config.enable_research_flow = False
    run_config.goal_compression = "none"
    run_config.force_side_branch = None
    run_config.knowledge.enable_extraction = False
    run_config.knowledge.enable_completion_extraction = False
    run_config.knowledge.enable_injection = False
    run_config.trace_id = None
    initial_messages = prompt.build_messages(merge_level2=merge_level2)
    store = FileSystemTraceStore(base_path=TRACE_STORE_PATH)
    runner = AgentRunner(
        trace_store=store,
        llm_call=create_openrouter_llm_call(model=model),
        skills_dir=None,
        debug=DEBUG,
    )
    final_text = ""
    total_tokens = 0
    total_cost = 0.0
    has_completed_trace_cost = False
    log_file_path = output_dir / f"{execution_id}" / f"run_log_{execution_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
    log_file_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        with build_log(execution_id) as log_buffer:
            async for item in runner.run(messages=initial_messages, config=run_config):
                if isinstance(item, Trace):
                    # Prefer the authoritative totals from a completed Trace over
                    # per-message accumulation.
                    if getattr(item, "status", None) == "completed":
                        total_tokens = int(getattr(item, "total_tokens", 0) or 0)
                        total_cost = float(getattr(item, "total_cost", 0.0) or 0.0)
                        has_completed_trace_cost = True
                    continue
                elif isinstance(item, Message):
                    text = extract_assistant_text(item)
                    if text:
                        final_text = text
                        log(f"[assistant] {text}")
                    if not has_completed_trace_cost:
                        total_tokens += int(getattr(item, "total_tokens", 0) or 0)
                        total_cost += float(getattr(item, "cost", 0.0) or 0.0)
            if final_text:
                output_file = output_dir / f"{execution_id}" / "result.txt"
                output_file.parent.mkdir(parents=True, exist_ok=True)
                with open(output_file, "w", encoding="utf-8") as f:
                    f.write(final_text)
            log(f"[cost] total_tokens={total_tokens}, total_cost=${total_cost:.6f}")
            # After the agent finishes: push the local result JSON into the MySQL
            # table demand_content.
            # element_names -> name (comma separated); reason/desc -> ext_data JSON;
            # merge_leve2 -> demand_content.merge_leve2
            try:
                write_demand_items_to_mysql(execution_id=execution_id, merge_level2=merge_level2)
            except Exception as e:
                log(f"[mysql] 写入 demand_content 异常:{e}")
            task_log_text = log_buffer.getvalue()
        with open(log_file_path, "w", encoding="utf-8") as f:
            f.write(task_log_text)
        task_status = 1
    except Exception as e:
        if not task_log_text:
            task_log_text = f"[run] 执行异常: {e}"
        task_status = 2
        raise
    finally:
        # Always stamp the demand_task row, whether the run succeeded or raised.
        _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text)
    return final_text
async def main() -> None:
    """Entry point: resolve today's execution for the cluster, mining a new one if absent."""
    cluster = "历史名人"
    exec_id = get_execution_id_by_merge_level2(cluster_name=cluster)
    if exec_id is None:
        # No execution recorded for today — mine a fresh one.
        exec_id = run_mining(cluster_name=cluster, merge_leve2=cluster)
    print(exec_id)
    # await run_once(exec_id, cluster)


if __name__ == "__main__":
    asyncio.run(main())
|