# run.py — piaoquan_demand example entry point.
  1. """piaoquan_demand 示例的最小可运行入口。"""
  2. import asyncio
  3. import copy
  4. import importlib
  5. import json
  6. import os
  7. import sys
  8. from datetime import datetime
  9. from pathlib import Path
  10. from typing import Optional
  11. from dotenv import load_dotenv
  12. from sqlalchemy import desc, or_
  13. from examples.piaoquan_demand.config import LOG_LEVEL, ENABLED_TOOLS
  14. from examples.piaoquan_demand.db_manager import DatabaseManager
  15. from examples.piaoquan_demand.models import TopicPatternExecution
  16. from examples.piaoquan_demand.pattern_toos.pattern_service import run_mining
  17. from examples.piaoquan_demand.prepare import prepare
  18. from examples.piaoquan_demand.topic_build_agent_context import TopicBuildAgentContext
  19. from examples.piaoquan_demand.mysql import mysql_db
  20. # Clash Verge TUN 模式兼容:禁止 httpx/urllib 自动检测系统 HTTP 代理
  21. os.environ.setdefault("no_proxy", "*")
  22. # 该示例仅使用项目侧能力,禁用框架内置 skills
  23. os.environ.setdefault("AGENT_DISABLE_BUILTIN_SKILLS", "1")
  24. # 禁用内置工具自动注册,并开启严格工具白名单
  25. os.environ.setdefault("AGENT_DISABLE_BUILTIN_TOOL_REGISTRATION", "1")
  26. os.environ.setdefault("AGENT_STRICT_TOOL_SELECTION", "1")
  27. # 禁用所有侧分支(压缩/反思)
  28. os.environ.setdefault("AGENT_DISABLE_SIDE_BRANCHES", "1")
  29. # 添加项目根目录到 Python 路径
  30. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  31. load_dotenv()
  32. from agent.core.runner import AgentRunner
  33. from agent.llm import create_openrouter_llm_call
  34. from agent.llm.prompts import SimplePrompt
  35. from agent.trace import FileSystemTraceStore, Message, Trace
  36. from agent.utils import setup_logging
  37. from log_capture import build_log, log
  38. # 导入项目配置
  39. from examples.piaoquan_demand.config import DEBUG, LOG_FILE, LOG_LEVEL, RUN_CONFIG, TRACE_STORE_PATH
  40. CUSTOM_TOOL_MODULES = {
  41. # piaoquan_demand 示例:严格按工具名白名单加载对应模块
  42. "think_and_plan": "examples.piaoquan_demand.agent_tools",
  43. "get_category_tree": "examples.piaoquan_demand.topic_build_pattern_tools",
  44. "get_frequent_itemsets": "examples.piaoquan_demand.topic_build_pattern_tools",
  45. "get_itemset_detail": "examples.piaoquan_demand.topic_build_pattern_tools",
  46. "get_post_elements": "examples.piaoquan_demand.topic_build_pattern_tools",
  47. "search_elements": "examples.piaoquan_demand.topic_build_pattern_tools",
  48. "get_element_category_chain": "examples.piaoquan_demand.topic_build_pattern_tools",
  49. "get_category_detail": "examples.piaoquan_demand.topic_build_pattern_tools",
  50. "search_categories": "examples.piaoquan_demand.topic_build_pattern_tools",
  51. "get_category_elements": "examples.piaoquan_demand.topic_build_pattern_tools",
  52. "get_category_co_occurrences": "examples.piaoquan_demand.topic_build_pattern_tools",
  53. "get_element_co_occurrences": "examples.piaoquan_demand.topic_build_pattern_tools",
  54. "get_weight_score_topn": "examples.piaoquan_demand.weight_score_query_tools",
  55. "get_weight_score_by_name": "examples.piaoquan_demand.weight_score_query_tools",
  56. "create_demand_item": "examples.piaoquan_demand.demand_build_agent_tools",
  57. "create_demand_items": "examples.piaoquan_demand.demand_build_agent_tools",
  58. "write_execution_summary": "examples.piaoquan_demand.demand_build_agent_tools",
  59. }
  60. def get_execution_id_by_merge_level2(cluster_name: str):
  61. """根据二级品类和平台查询最新的 execution_id。"""
  62. session = DatabaseManager().get_session()
  63. try:
  64. start_of_today = datetime.combine(datetime.now().date(), datetime.min.time())
  65. query = session.query(TopicPatternExecution).filter(
  66. TopicPatternExecution.cluster_name == cluster_name,
  67. TopicPatternExecution.start_time >= start_of_today,
  68. )
  69. execution = query.order_by(desc(TopicPatternExecution.id)).first()
  70. if not execution:
  71. return None
  72. return execution.id
  73. finally:
  74. session.close()
  75. def resolve_model(prompt: SimplePrompt) -> str:
  76. model_from_prompt = prompt.config.get("model")
  77. if model_from_prompt:
  78. return model_from_prompt
  79. return f"anthropic/{RUN_CONFIG.model}" if "/" not in RUN_CONFIG.model else RUN_CONFIG.model
  80. def extract_assistant_text(message: Message) -> str:
  81. if message.role != "assistant":
  82. return ""
  83. content = message.content
  84. if isinstance(content, str):
  85. return content
  86. if isinstance(content, dict):
  87. text = content.get("text", "")
  88. # 即使本轮包含工具调用,也打印模型给出的文本,便于观察每一步输出
  89. if text:
  90. return text
  91. return ""
  92. def register_selected_tools(tool_names: list[str]) -> None:
  93. for tool_name in tool_names:
  94. module_path = CUSTOM_TOOL_MODULES.get(tool_name)
  95. if not module_path:
  96. raise ValueError(f"未配置工具模块映射: {tool_name}")
  97. importlib.import_module(module_path)
  98. def _join_element_names_to_name(element_names: object) -> str:
  99. """把 tool 入参 element_names 转成 demand_content.name(逗号分隔)。"""
  100. if element_names is None:
  101. return ""
  102. if isinstance(element_names, list):
  103. parts = [str(x).strip() for x in element_names if x is not None and str(x).strip()]
  104. return ",".join(parts)
  105. # 兼容异常数据:比如历史版本可能传了字符串
  106. return str(element_names).strip()
  107. def _safe_truncate(s: object, max_len: int) -> str:
  108. if s is None:
  109. return ""
  110. s_str = str(s)
  111. if max_len and len(s_str) > max_len:
  112. return s_str[:max_len]
  113. return s_str
  114. def _load_name_score_map(execution_id: int) -> dict:
  115. """读取 data/{execution_id} 下所有 JSON 的 name->score(同名取最高分)。"""
  116. data_dir = Path(__file__).parent / "data" / str(execution_id)
  117. if not data_dir.exists():
  118. return {}
  119. score_map = {}
  120. for json_path in data_dir.glob("*.json"):
  121. try:
  122. with open(json_path, "r", encoding="utf-8") as f:
  123. payload = json.load(f)
  124. except Exception:
  125. continue
  126. if not isinstance(payload, list):
  127. continue
  128. for item in payload:
  129. if not isinstance(item, dict):
  130. continue
  131. name = item.get("name")
  132. score = item.get("score")
  133. if isinstance(name, str) and isinstance(score, (int, float)):
  134. prev = score_map.get(name)
  135. score_f = float(score)
  136. if prev is None or score_f > prev:
  137. score_map[name] = score_f
  138. return score_map
  139. def _avg_score_for_joined_name(name: str, score_map: dict) -> float:
  140. """按逗号拆分 name,分别取分后求平均。"""
  141. parts = [part.strip() for part in str(name).split(",") if part and part.strip()]
  142. if not parts:
  143. return 0.0
  144. return sum(float(score_map.get(part, 0.0)) for part in parts) / len(parts)
  145. def _create_demand_task(execution_id: int) -> Optional[int]:
  146. """创建 demand_task 记录,返回任务ID。"""
  147. try:
  148. task_id = mysql_db.insert(
  149. "demand_task",
  150. {
  151. "execution_id": execution_id,
  152. "status": 0,
  153. "log": "",
  154. },
  155. )
  156. log(f"[task] 创建 demand_task 成功,task_id={task_id}, execution_id={execution_id}")
  157. return task_id
  158. except Exception as e:
  159. log(f"[task] 创建 demand_task 失败,execution_id={execution_id}, error={e}")
  160. return None
  161. def _finish_demand_task(task_id: Optional[int], status: int, task_log: str) -> None:
  162. """更新 demand_task 状态与日志。"""
  163. if not task_id:
  164. return
  165. try:
  166. mysql_db.update(
  167. "demand_task",
  168. {
  169. "status": int(status),
  170. "log": task_log or "",
  171. },
  172. "id = %s",
  173. (task_id,),
  174. )
  175. log(f"[task] 更新 demand_task 成功,task_id={task_id}, status={status}")
  176. except Exception as e:
  177. log(f"[task] 更新 demand_task 失败,task_id={task_id}, status={status}, error={e}")
  178. def write_demand_items_to_mysql(execution_id: int, merge_level2: str) -> int:
  179. """
  180. 把 result/{execution_id}/execution_id_{execution_id}_demand_items.json
  181. 写入 MySQL 表 demand_content
  182. """
  183. # create_demand_item(s) 使用 Path.cwd()/result 作为输出目录。
  184. # 为了兼容“从不同目录启动脚本”的情况,这里同时尝试 cwd 和脚本目录两种结果位置。
  185. demand_items_path = (
  186. Path.cwd()
  187. / "result"
  188. / str(execution_id)
  189. / f"execution_id_{execution_id}_demand_items.json"
  190. )
  191. if not demand_items_path.exists():
  192. alt_path = (
  193. Path(__file__).parent
  194. / "result"
  195. / str(execution_id)
  196. / f"execution_id_{execution_id}_demand_items.json"
  197. )
  198. if alt_path.exists():
  199. demand_items_path = alt_path
  200. else:
  201. log(f"[mysql] 未找到需求 JSON:{demand_items_path}(也未找到 {alt_path}),跳过写入")
  202. return 0
  203. try:
  204. with open(demand_items_path, "r", encoding="utf-8") as f:
  205. loaded = json.load(f)
  206. except Exception as e:
  207. log(f"[mysql] 读取需求 JSON 失败:{demand_items_path},error={e}")
  208. return 0
  209. items = loaded["items"] if isinstance(loaded, dict) and isinstance(loaded.get("items"), list) else loaded
  210. if not isinstance(items, list):
  211. log(f"[mysql] 需求 JSON 非数组,跳过写入:type={type(items)}")
  212. return 0
  213. score_map = _load_name_score_map(execution_id)
  214. rows: list[dict] = []
  215. for di in items:
  216. if not isinstance(di, dict):
  217. continue
  218. name = _join_element_names_to_name(di.get("element_names"))
  219. if not name:
  220. continue
  221. score = _avg_score_for_joined_name(name, score_map)
  222. reason = di.get("reason")
  223. desc_value = di.get("desc")
  224. ext_data = {"reason": reason, "desc": desc_value}
  225. rows.append(
  226. {
  227. "merge_leve2": _safe_truncate(merge_level2, 32),
  228. "name": _safe_truncate(name, 64),
  229. "score": float(score),
  230. "ext_data": json.dumps(ext_data, ensure_ascii=False),
  231. }
  232. )
  233. if not rows:
  234. log("[mysql] 生成行为空,跳过写入")
  235. return 0
  236. affected = mysql_db.insert_many("demand_content", rows)
  237. log(f"[mysql] 写入 demand_content 完成,rows={len(rows)}, affected={affected}")
  238. return len(rows)
  239. async def run_once(execution_id, merge_level2) -> str:
  240. task_id = _create_demand_task(execution_id=execution_id)
  241. task_status = 2
  242. task_log_text = ""
  243. TopicBuildAgentContext.set_execution_id(execution_id)
  244. prepare(execution_id)
  245. base_dir = Path(__file__).parent
  246. output_dir = base_dir / "output"
  247. output_dir.mkdir(exist_ok=True)
  248. setup_logging(level=LOG_LEVEL, file=LOG_FILE)
  249. register_selected_tools(ENABLED_TOOLS)
  250. prompt = SimplePrompt(base_dir / "demand.md")
  251. model = resolve_model(prompt)
  252. run_config = copy.deepcopy(RUN_CONFIG)
  253. run_config.temperature = float(prompt.config.get("temperature", run_config.temperature))
  254. run_config.max_iterations = int(prompt.config.get("max_iterations", run_config.max_iterations))
  255. run_config.tools = ENABLED_TOOLS.copy()
  256. # 禁用反思/总结经验相关流程(避免进入 reflection 侧分支)
  257. run_config.enable_research_flow = False
  258. run_config.goal_compression = "none"
  259. run_config.force_side_branch = None
  260. run_config.knowledge.enable_extraction = False
  261. run_config.knowledge.enable_completion_extraction = False
  262. run_config.knowledge.enable_injection = False
  263. run_config.trace_id = None
  264. initial_messages = prompt.build_messages(merge_level2=merge_level2)
  265. store = FileSystemTraceStore(base_path=TRACE_STORE_PATH)
  266. runner = AgentRunner(
  267. trace_store=store,
  268. llm_call=create_openrouter_llm_call(model=model),
  269. skills_dir=None,
  270. debug=DEBUG,
  271. )
  272. final_text = ""
  273. total_tokens = 0
  274. total_cost = 0.0
  275. has_completed_trace_cost = False
  276. log_file_path = output_dir / f"{execution_id}" / f"run_log_{execution_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
  277. log_file_path.parent.mkdir(parents=True, exist_ok=True)
  278. try:
  279. with build_log(execution_id) as log_buffer:
  280. async for item in runner.run(messages=initial_messages, config=run_config):
  281. if isinstance(item, Trace):
  282. if getattr(item, "status", None) == "completed":
  283. total_tokens = int(getattr(item, "total_tokens", 0) or 0)
  284. total_cost = float(getattr(item, "total_cost", 0.0) or 0.0)
  285. has_completed_trace_cost = True
  286. continue
  287. elif isinstance(item, Message):
  288. text = extract_assistant_text(item)
  289. if text:
  290. final_text = text
  291. log(f"[assistant] {text}")
  292. if not has_completed_trace_cost:
  293. total_tokens += int(getattr(item, "total_tokens", 0) or 0)
  294. total_cost += float(getattr(item, "cost", 0.0) or 0.0)
  295. if final_text:
  296. output_file = output_dir / f"{execution_id}" / "result.txt"
  297. output_file.parent.mkdir(parents=True, exist_ok=True)
  298. with open(output_file, "w", encoding="utf-8") as f:
  299. f.write(final_text)
  300. log(f"[cost] total_tokens={total_tokens}, total_cost=${total_cost:.6f}")
  301. # agent 执行完成后:把本地 result JSON 写入 MySQL 表 demand_content
  302. # element_names -> name(逗号分隔);reason/desc -> ext_data JSON;merge_leve2 -> demand_content.merge_leve2
  303. try:
  304. write_demand_items_to_mysql(execution_id=execution_id, merge_level2=merge_level2)
  305. except Exception as e:
  306. log(f"[mysql] 写入 demand_content 异常:{e}")
  307. task_log_text = log_buffer.getvalue()
  308. with open(log_file_path, "w", encoding="utf-8") as f:
  309. f.write(task_log_text)
  310. task_status = 1
  311. except Exception as e:
  312. if not task_log_text:
  313. task_log_text = f"[run] 执行异常: {e}"
  314. task_status = 2
  315. raise
  316. finally:
  317. _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text)
  318. return final_text
  319. async def main() -> None:
  320. cluster_name = '历史名人'
  321. execution_id = get_execution_id_by_merge_level2(cluster_name=cluster_name)
  322. if execution_id is None:
  323. execution_id = run_mining(cluster_name=cluster_name, merge_leve2=cluster_name)
  324. print(execution_id)
  325. # await run_once(execution_id, cluster_name)
  326. if __name__ == "__main__":
  327. asyncio.run(main())