# run.py — (viewer header and line-number residue removed during cleanup)
"""Minimal runnable entry point for the demand example."""
import copy
import importlib
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv
from sqlalchemy import desc, or_  # NOTE(review): `or_` appears unused in this file — confirm before removing.

# NOTE(review): these project imports execute BEFORE sys.path is amended below,
# so the script presumably must be launched from the repository root — confirm.
from examples.demand.changwen_prepare import changwen_prepare
from examples.demand.config import LOG_LEVEL, ENABLED_TOOLS
from examples.demand.db_manager import DatabaseManager
from examples.demand.models import TopicPatternExecution
from examples.demand.piaoquan_prepare import prepare, piaoquan_prepare
from examples.demand.demand_agent_context import TopicBuildAgentContext
from examples.demand.mysql import mysql_db

# Clash Verge TUN-mode compatibility: stop httpx/urllib from auto-detecting the
# system HTTP proxy.
os.environ.setdefault("no_proxy", "*")
# This example uses project-side capabilities only; disable framework builtin skills.
os.environ.setdefault("AGENT_DISABLE_BUILTIN_SKILLS", "1")
# Disable automatic builtin tool registration and enforce a strict tool whitelist.
os.environ.setdefault("AGENT_DISABLE_BUILTIN_TOOL_REGISTRATION", "1")
os.environ.setdefault("AGENT_STRICT_TOOL_SELECTION", "1")
# Disable all side branches (compression / reflection).
os.environ.setdefault("AGENT_DISABLE_SIDE_BRANCHES", "1")

# Add the project root (three levels up from this file) to the Python path.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
load_dotenv()

from agent.core.runner import AgentRunner
from agent.llm import create_openrouter_llm_call
from agent.llm.prompts import SimplePrompt
from agent.trace import FileSystemTraceStore, Message, Trace
from agent.utils import setup_logging
from examples.demand.log_capture import build_log, log

# Project configuration.
# NOTE(review): LOG_LEVEL is imported a second time here (also imported above).
from examples.demand.config import DEBUG, LOG_FILE, LOG_LEVEL, RUN_CONFIG, TRACE_STORE_PATH
# demand example: tools are loaded strictly by this name whitelist.
# Maps each enabled tool name to the module whose import registers it;
# consumed by register_selected_tools().
CUSTOM_TOOL_MODULES = {
    "think_and_plan": "examples.demand.agent_tools",
    "get_category_tree": "examples.demand.demand_pattern_tools",
    "get_frequent_itemsets": "examples.demand.demand_pattern_tools",
    "get_itemset_detail": "examples.demand.demand_pattern_tools",
    "get_post_elements": "examples.demand.demand_pattern_tools",
    "search_elements": "examples.demand.demand_pattern_tools",
    "get_element_category_chain": "examples.demand.demand_pattern_tools",
    "get_category_detail": "examples.demand.demand_pattern_tools",
    "search_categories": "examples.demand.demand_pattern_tools",
    "get_category_elements": "examples.demand.demand_pattern_tools",
    "get_category_co_occurrences": "examples.demand.demand_pattern_tools",
    "get_element_co_occurrences": "examples.demand.demand_pattern_tools",
    "get_weight_score_topn": "examples.demand.weight_score_query_tools",
    "get_weight_score_by_name": "examples.demand.weight_score_query_tools",
    "create_demand_item": "examples.demand.demand_build_agent_tools",
    "create_demand_items": "examples.demand.demand_build_agent_tools",
    "write_execution_summary": "examples.demand.demand_build_agent_tools",
}
  59. def get_execution_id_by_merge_level2(cluster_name: str):
  60. """根据二级品类和平台查询最新的 execution_id。"""
  61. session = DatabaseManager().get_session()
  62. try:
  63. start_of_today = datetime.combine(datetime.now().date(), datetime.min.time())
  64. query = session.query(TopicPatternExecution).filter(
  65. TopicPatternExecution.cluster_name == cluster_name,
  66. TopicPatternExecution.start_time >= start_of_today,
  67. )
  68. execution = query.order_by(desc(TopicPatternExecution.id)).first()
  69. if not execution:
  70. return None
  71. return execution.id
  72. finally:
  73. session.close()
  74. def resolve_model(prompt: SimplePrompt) -> str:
  75. model_from_prompt = prompt.config.get("model")
  76. if model_from_prompt:
  77. return model_from_prompt
  78. return f"anthropic/{RUN_CONFIG.model}" if "/" not in RUN_CONFIG.model else RUN_CONFIG.model
  79. def extract_assistant_text(message: Message) -> str:
  80. if message.role != "assistant":
  81. return ""
  82. content = message.content
  83. if isinstance(content, str):
  84. return content
  85. if isinstance(content, dict):
  86. text = content.get("text", "")
  87. # 即使本轮包含工具调用,也打印模型给出的文本,便于观察每一步输出
  88. if text:
  89. return text
  90. return ""
  91. def register_selected_tools(tool_names: list[str]) -> None:
  92. for tool_name in tool_names:
  93. module_path = CUSTOM_TOOL_MODULES.get(tool_name)
  94. if not module_path:
  95. raise ValueError(f"未配置工具模块映射: {tool_name}")
  96. importlib.import_module(module_path)
  97. def _join_element_names_to_name(element_names: object) -> str:
  98. """把 tool 入参 element_names 转成 demand_content.name(逗号分隔)。"""
  99. if element_names is None:
  100. return ""
  101. if isinstance(element_names, list):
  102. parts = [str(x).strip() for x in element_names if x is not None and str(x).strip()]
  103. return ",".join(parts)
  104. # 兼容异常数据:比如历史版本可能传了字符串
  105. return str(element_names).strip()
  106. def _safe_truncate(s: object, max_len: int) -> str:
  107. if s is None:
  108. return ""
  109. s_str = str(s)
  110. if max_len and len(s_str) > max_len:
  111. return s_str[:max_len]
  112. return s_str
  113. def _load_name_score_map(execution_id: int) -> dict:
  114. """读取 data/{execution_id} 下所有 JSON 的 name->score(同名取最高分)。"""
  115. data_dir = Path(__file__).parent / "data" / str(execution_id)
  116. if not data_dir.exists():
  117. return {}
  118. score_map = {}
  119. for json_path in data_dir.glob("*.json"):
  120. try:
  121. with open(json_path, "r", encoding="utf-8") as f:
  122. payload = json.load(f)
  123. except Exception:
  124. continue
  125. if not isinstance(payload, list):
  126. continue
  127. for item in payload:
  128. if not isinstance(item, dict):
  129. continue
  130. name = item.get("name")
  131. score = item.get("score")
  132. if isinstance(name, str) and isinstance(score, (int, float)):
  133. prev = score_map.get(name)
  134. score_f = float(score)
  135. if prev is None or score_f > prev:
  136. score_map[name] = score_f
  137. return score_map
  138. def _avg_score_for_joined_name(name: str, score_map: dict) -> float:
  139. """按逗号拆分 name,分别取分后求平均。"""
  140. parts = [part.strip() for part in str(name).split(",") if part and part.strip()]
  141. if not parts:
  142. return 0.0
  143. return sum(float(score_map.get(part, 0.0)) for part in parts) / len(parts)
  144. def _create_demand_task(
  145. execution_id: int,
  146. name: Optional[str] = None,
  147. platform: Optional[str] = None,
  148. ) -> Optional[int]:
  149. """创建 demand_task 记录,返回任务ID。"""
  150. try:
  151. # 数据库字段 demand_task.name: varchar(32)
  152. if name is not None:
  153. name = str(name)[:32]
  154. # 数据库字段 demand_task.platform: varchar(32)
  155. if platform is not None:
  156. platform = str(platform)[:32]
  157. task_id = mysql_db.insert(
  158. "demand_task",
  159. {
  160. "execution_id": execution_id,
  161. "name": name,
  162. "platform": platform,
  163. "status": 0,
  164. "log": "",
  165. },
  166. )
  167. log(f"[task] 创建 demand_task 成功,task_id={task_id}, execution_id={execution_id}")
  168. return task_id
  169. except Exception as e:
  170. log(f"[task] 创建 demand_task 失败,execution_id={execution_id}, error={e}")
  171. return None
  172. def _finish_demand_task(task_id: Optional[int], status: int, task_log: str) -> None:
  173. """更新 demand_task 状态与日志。"""
  174. if not task_id:
  175. return
  176. try:
  177. mysql_db.update(
  178. "demand_task",
  179. {
  180. "status": int(status),
  181. "log": task_log or "",
  182. },
  183. "id = %s",
  184. (task_id,),
  185. )
  186. log(f"[task] 更新 demand_task 成功,task_id={task_id}, status={status}")
  187. except Exception as e:
  188. log(f"[task] 更新 demand_task 失败,task_id={task_id}, status={status}, error={e}")
def write_demand_items_to_mysql(execution_id: int, merge_level2: str) -> int:
    """Load result/{execution_id}/execution_id_{execution_id}_demand_items.json
    and insert its items into the MySQL table ``demand_content``.

    Returns the number of rows written; 0 on any recoverable failure
    (missing file, unreadable JSON, wrong shape, no usable items).
    """
    # create_demand_item(s) uses Path.cwd()/result as its output directory.
    # To stay compatible with launching the script from a different working
    # directory, try both the cwd and the script directory as result locations.
    demand_items_path = (
        Path.cwd()
        / "result"
        / str(execution_id)
        / f"execution_id_{execution_id}_demand_items.json"
    )
    if not demand_items_path.exists():
        alt_path = (
            Path(__file__).parent
            / "result"
            / str(execution_id)
            / f"execution_id_{execution_id}_demand_items.json"
        )
        if alt_path.exists():
            demand_items_path = alt_path
        else:
            log(f"[mysql] 未找到需求 JSON:{demand_items_path}(也未找到 {alt_path}),跳过写入")
            return 0
    try:
        with open(demand_items_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except Exception as e:
        log(f"[mysql] 读取需求 JSON 失败:{demand_items_path},error={e}")
        return 0
    # Accept either a bare list or an {"items": [...]} wrapper.
    items = loaded["items"] if isinstance(loaded, dict) and isinstance(loaded.get("items"), list) else loaded
    if not isinstance(items, list):
        log(f"[mysql] 需求 JSON 非数组,跳过写入:type={type(items)}")
        return 0
    score_map = _load_name_score_map(execution_id)
    rows: list[dict] = []
    for di in items:
        if not isinstance(di, dict):
            continue
        # element_names -> comma-joined demand_content.name; skip empty names.
        name = _join_element_names_to_name(di.get("element_names"))
        if not name:
            continue
        score = _avg_score_for_joined_name(name, score_map)
        reason = di.get("reason")
        desc_value = di.get("desc")
        ext_data = {"reason": reason, "desc": desc_value}
        rows.append(
            {
                # NOTE(review): "merge_leve2" (sic) presumably matches the real
                # column name in demand_content — confirm before "fixing" it.
                "merge_leve2": _safe_truncate(merge_level2, 32),
                "name": _safe_truncate(name, 64),
                "score": float(score),
                "ext_data": json.dumps(ext_data, ensure_ascii=False),
            }
        )
    if not rows:
        log("[mysql] 生成行为空,跳过写入")
        return 0
    affected = mysql_db.insert_many("demand_content", rows)
    log(f"[mysql] 写入 demand_content 完成,rows={len(rows)}, affected={affected}")
    return len(rows)
async def run_once(execution_id, merge_level2, task_id: Optional[int] = None) -> str:
    """Run the demand agent once for *execution_id* and persist its outputs.

    Side effects: registers tools, streams the agent run, writes the final
    assistant text to output/{execution_id}/result.txt, pushes result JSON into
    MySQL, writes the captured run log to a file, and (when *task_id* is given)
    finalizes the demand_task row. Returns the final assistant text ("" if none).
    """
    task_log_text = ""
    task_status = 0  # 0 = unfinished, 1 = success, 2 = failure (see _finish_demand_task)
    TopicBuildAgentContext.set_execution_id(execution_id)
    prepare(execution_id)
    base_dir = Path(__file__).parent
    output_dir = base_dir / "output"
    output_dir.mkdir(exist_ok=True)
    setup_logging(level=LOG_LEVEL, file=LOG_FILE)
    register_selected_tools(ENABLED_TOOLS)
    prompt = SimplePrompt(base_dir / "demand.md")
    model = resolve_model(prompt)
    # Deep-copy so per-run overrides never mutate the shared RUN_CONFIG.
    run_config = copy.deepcopy(RUN_CONFIG)
    run_config.temperature = float(prompt.config.get("temperature", run_config.temperature))
    run_config.max_iterations = int(prompt.config.get("max_iterations", run_config.max_iterations))
    run_config.tools = ENABLED_TOOLS.copy()
    # Disable reflection / experience-summary flows (avoid the reflection side branch).
    run_config.enable_research_flow = False
    run_config.goal_compression = "none"
    run_config.force_side_branch = None
    run_config.knowledge.enable_extraction = False
    run_config.knowledge.enable_completion_extraction = False
    run_config.knowledge.enable_injection = False
    run_config.trace_id = None
    initial_messages = prompt.build_messages(merge_level2=merge_level2)
    store = FileSystemTraceStore(base_path=TRACE_STORE_PATH)
    runner = AgentRunner(
        trace_store=store,
        llm_call=create_openrouter_llm_call(model=model),
        skills_dir=None,
        debug=DEBUG,
    )
    final_text = ""
    total_tokens = 0
    total_cost = 0.0
    # Once a completed Trace reports authoritative totals, stop summing per message.
    has_completed_trace_cost = False
    log_file_path = output_dir / f"{execution_id}" / f"run_log_{execution_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
    log_file_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        with build_log(execution_id) as log_buffer:
            async for item in runner.run(messages=initial_messages, config=run_config):
                if isinstance(item, Trace):
                    if getattr(item, "status", None) == "completed":
                        total_tokens = int(getattr(item, "total_tokens", 0) or 0)
                        total_cost = float(getattr(item, "total_cost", 0.0) or 0.0)
                        has_completed_trace_cost = True
                    continue
                elif isinstance(item, Message):
                    text = extract_assistant_text(item)
                    if text:
                        # Keep only the latest assistant text as the final answer.
                        final_text = text
                        log(f"[assistant] {text}")
                    if not has_completed_trace_cost:
                        total_tokens += int(getattr(item, "total_tokens", 0) or 0)
                        total_cost += float(getattr(item, "cost", 0.0) or 0.0)
            if final_text:
                output_file = output_dir / f"{execution_id}" / "result.txt"
                output_file.parent.mkdir(parents=True, exist_ok=True)
                with open(output_file, "w", encoding="utf-8") as f:
                    f.write(final_text)
            log(f"[cost] total_tokens={total_tokens}, total_cost=${total_cost:.6f}")
            # After the agent finishes: push the local result JSON into the
            # MySQL table demand_content.
            # element_names -> name (comma-joined); reason/desc -> ext_data JSON;
            # merge_leve2 -> demand_content.merge_leve2
            try:
                write_demand_items_to_mysql(execution_id=execution_id, merge_level2=merge_level2)
            except Exception as e:
                log(f"[mysql] 写入 demand_content 异常:{e}")
            task_log_text = log_buffer.getvalue()
            task_status = 1
    except Exception as e:
        if not task_log_text:
            # If the exception happened inside build_log, salvage whatever log
            # output was already produced.
            try:
                existing = locals().get("log_buffer")
                if existing is not None:
                    task_log_text = existing.getvalue()  # type: ignore[attr-defined]
            except Exception:
                pass
        if not task_log_text:
            task_log_text = f"[run] 执行异常: {e}"
        task_status = 2
        raise
    finally:
        if task_log_text:
            try:
                with open(log_file_path, "w", encoding="utf-8") as f:
                    f.write(task_log_text)
            except Exception:
                # Best effort: even if writing the log file fails, the MySQL
                # status must still be updated below.
                pass
        _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text)
    return final_text
  342. async def main(
  343. cluster_name: str,
  344. platform_type: str,
  345. execution_id: Optional[int] = None,
  346. task_id: Optional[int] = None,
  347. ) -> dict:
  348. if execution_id is None:
  349. if platform_type == "piaoquan":
  350. execution_id = piaoquan_prepare(cluster_name)
  351. elif platform_type == "changwen":
  352. execution_id = changwen_prepare(cluster_name)
  353. else:
  354. execution_id = None
  355. if not execution_id:
  356. return {"execution_id": None, "final_text": ""}
  357. final_text = await run_once(execution_id, cluster_name, task_id=task_id)
  358. return {"execution_id": execution_id, "final_text": final_text}
if __name__ == "__main__":
    # Ad-hoc manual entry: only prepares piaoquan data for the "历史名人"
    # cluster; it does NOT run the agent (the main() call is commented out).
    piaoquan_prepare('历史名人')
    # asyncio.run(main('小阳看天下', 'changwen'))
    # NOTE(review): re-enabling the line above would require `import asyncio`,
    # which this file does not currently import — confirm.