match_nodes.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758
  1. """
  2. 制作案例 → 制作需求 → 内容树节点匹配
  3. 流程:
  4. 1a. 逐帖子分析(带图片),提炼每个帖子的制作需求(并发)
  5. 1b. 合并去重相似需求,保留 case_id 溯源
  6. 2. 按合并后的需求语义搜索内容树 + 取父子节点
  7. 3. 每个需求独立过 LLM 挂载决策(并发)
  8. 4. 输出 case → demand → node 关系表 + 可视化
  9. """
  10. import asyncio
  11. import json
  12. import os
  13. import re
  14. import sys
  15. from pathlib import Path
  16. import httpx
  17. # 添加项目根目录
  18. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  19. from dotenv import load_dotenv
  20. load_dotenv()
  21. from agent.llm.qwen import qwen_llm_call
# ===== Configuration =====
BASE_DIR = Path(__file__).parent
# Remote content-tree service base URL. NOTE(review): the main flow currently
# searches the local category tree; this endpoint is only used by the
# search_content_tree / get_category_tree helpers.
CONTENT_TREE_BASE = "http://8.147.104.190:8001"
# Local snapshot of the category tree used for offline keyword search.
CATEGORY_TREE_PATH = BASE_DIR / "prompts" / "category_tree.json"
DEFAULT_MODEL = "qwen-plus"   # text-only model (cheaper fallback)
VISION_MODEL = "qwen-vl-max"  # multimodal model for posts that carry images
# The three matching dimensions ("substance", "form", "intent").
SOURCE_TYPES = ["实质", "形式", "意图"]
# Load the local category tree
# Module-level cache populated lazily by load_category_tree().
_CATEGORY_TREE_CACHE = None
  31. def load_category_tree() -> dict:
  32. """加载本地分类树(缓存)"""
  33. global _CATEGORY_TREE_CACHE
  34. if _CATEGORY_TREE_CACHE is None:
  35. with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
  36. _CATEGORY_TREE_CACHE = json.load(f)
  37. return _CATEGORY_TREE_CACHE
  38. def collect_all_nodes(node: dict, nodes: list, parent_path: str = ""):
  39. """递归收集所有节点,展平树结构"""
  40. if "id" in node:
  41. node_copy = {
  42. "entity_id": node["id"],
  43. "name": node["name"],
  44. "path": node.get("path", parent_path),
  45. "source_type": node.get("source_type"),
  46. "description": node.get("description") or "",
  47. "level": node.get("level"),
  48. "parent_id": node.get("parent_id"),
  49. "element_count": node.get("element_count", 0),
  50. }
  51. nodes.append(node_copy)
  52. if "children" in node:
  53. current_path = node.get("path", parent_path)
  54. for child in node["children"]:
  55. collect_all_nodes(child, nodes, current_path)
  56. def search_local_tree(query: str, source_type: str, top_k: int = 5) -> list:
  57. """在本地分类树中搜索节点"""
  58. tree = load_category_tree()
  59. all_nodes = []
  60. collect_all_nodes(tree, all_nodes)
  61. # 过滤:只保留指定维度
  62. filtered = [n for n in all_nodes if n.get("source_type") == source_type]
  63. # 简单文本匹配打分
  64. query_lower = query.lower()
  65. scored = []
  66. for node in filtered:
  67. name = node["name"].lower()
  68. desc = node["description"].lower()
  69. score = 0.0
  70. # 名称完全匹配
  71. if query_lower == name:
  72. score = 1.0
  73. # 名称包含
  74. elif query_lower in name:
  75. score = 0.8
  76. # 描述包含
  77. elif query_lower in desc:
  78. score = 0.5
  79. # 名称被包含(反向)
  80. elif name in query_lower:
  81. score = 0.6
  82. if score > 0:
  83. node["score"] = score
  84. node["entity_type"] = "category" # 本地树都是 category
  85. scored.append(node)
  86. # 按分数排序,取 top_k
  87. scored.sort(key=lambda x: x["score"], reverse=True)
  88. return scored[:top_k]
  89. # ===== Prompt 加载 =====
  90. def load_prompt(filename: str) -> dict:
  91. """加载 .prompt 文件,解析 frontmatter 和 $role$ 分段"""
  92. path = BASE_DIR / "prompts" / filename
  93. text = path.read_text(encoding="utf-8")
  94. config = {}
  95. if text.startswith("---"):
  96. _, fm, text = text.split("---", 2)
  97. for line in fm.strip().splitlines():
  98. if ":" in line:
  99. k, v = line.split(":", 1)
  100. k, v = k.strip(), v.strip()
  101. if v.replace(".", "", 1).isdigit():
  102. v = float(v) if "." in v else int(v)
  103. config[k] = v
  104. messages = []
  105. parts = re.split(r'^\$(\w+)\$\s*$', text.strip(), flags=re.MULTILINE)
  106. for i in range(1, len(parts), 2):
  107. role = parts[i].strip()
  108. content = parts[i + 1].strip() if i + 1 < len(parts) else ""
  109. messages.append({"role": role, "content": content})
  110. return {"config": config, "messages": messages}
  111. def render_messages(prompt_data: dict, variables: dict) -> list[dict]:
  112. """用变量替换 prompt 模板中的 {var} 占位符"""
  113. rendered = []
  114. for msg in prompt_data["messages"]:
  115. content = msg["content"]
  116. for k, v in variables.items():
  117. content = content.replace(f"{{{k}}}", str(v))
  118. rendered.append({"role": msg["role"], "content": content})
  119. return rendered
  120. def parse_json_response(content: str) -> list | dict:
  121. """清理 LLM 输出中的 markdown 包裹并解析 JSON"""
  122. content = content.strip()
  123. if content.startswith("```"):
  124. content = content.split("\n", 1)[1]
  125. content = content.rsplit("```", 1)[0]
  126. try:
  127. return json.loads(content)
  128. except json.JSONDecodeError as e:
  129. # 尝试自动修复:替换字符串中的未转义引号
  130. print(f"\n[JSON 解析错误] 尝试自动修复...")
  131. # 简单修复:把字符串值中的单独双引号替换为单引号
  132. import re
  133. # 匹配 "key": "value with "quote" inside"
  134. # 这个正则会找到字符串值中间的引号并替换
  135. fixed = re.sub(r'("(?:evidence|description|demand_name)":\s*")([^"]*)"([^"]*)"([^"]*")',
  136. r'\1\2\'\3\4', content)
  137. try:
  138. return json.loads(fixed)
  139. except json.JSONDecodeError:
  140. # 修复失败,打印错误位置
  141. start = max(0, e.pos - 100)
  142. end = min(len(content), e.pos + 100)
  143. print(f"\n[JSON 解析错误] 位置 {e.pos}:")
  144. print(f"...{content[start:end]}...")
  145. print(f"\n完整内容已保存到 json_error.txt")
  146. with open("json_error.txt", "w", encoding="utf-8") as f:
  147. f.write(content)
  148. raise
  149. # ===== 内容树 API =====
  150. async def search_content_tree(client: httpx.AsyncClient, query: str, source_type: str, top_k: int = 5) -> list:
  151. params = {
  152. "q": query, "source_type": source_type, "entity_type": "all",
  153. "top_k": top_k, "use_description": "true",
  154. "include_ancestors": "false", "descendant_depth": 0,
  155. }
  156. resp = await client.get(f"{CONTENT_TREE_BASE}/api/agent/search", params=params)
  157. resp.raise_for_status()
  158. return resp.json().get("results", [])
  159. async def get_category_tree(client: httpx.AsyncClient, entity_id: int, source_type: str) -> dict:
  160. params = {
  161. "source_type": source_type, "include_ancestors": "true", "descendant_depth": 1,
  162. }
  163. resp = await client.get(f"{CONTENT_TREE_BASE}/api/agent/search/category/{entity_id}", params=params)
  164. resp.raise_for_status()
  165. return resp.json()
  166. # ===== Step 1a:逐帖子提需求(支持图片) =====
  167. def build_case_content(case: dict) -> list[dict]:
  168. """
  169. 构造单个帖子的多模态消息内容。
  170. 返回 OpenAI vision 格式的 content 数组。
  171. """
  172. parts = []
  173. # 文本部分(兼容多种字段名)
  174. title = case.get("title") or case.get("video_title") or case.get("post_title", "未知")
  175. text_lines = [f"## 案例:{title}"]
  176. text_lines.append(f"来源:{case.get('source', '')}")
  177. if case.get("user_input"):
  178. text_lines.append(f"用户输入:{json.dumps(case['user_input'], ensure_ascii=False)}")
  179. if case.get("output_description"):
  180. text_lines.append(f"输出效果:{case['output_description']}")
  181. if case.get("key_findings"):
  182. text_lines.append(f"关键发现:{case['key_findings']}")
  183. parts.append({"type": "text", "text": "\n".join(text_lines)})
  184. # 图片部分(兼容多种字段名)
  185. images = case.get("images") or case.get("effect_images", [])
  186. for img_url in images[:4]: # 最多4张图,控制 token
  187. parts.append({
  188. "type": "image_url",
  189. "image_url": {"url": img_url},
  190. })
  191. return parts
  192. async def extract_demands_for_case(case: dict) -> dict:
  193. """对单个帖子提取需求"""
  194. prompt_data = load_prompt("step1_extract_demands.prompt")
  195. model = prompt_data["config"].get("model", VISION_MODEL)
  196. temperature = prompt_data["config"].get("temperature", 0.3)
  197. case_content = build_case_content(case)
  198. has_images = any(p["type"] == "image_url" for p in case_content)
  199. # 如果没有图片,降级用文本模型省成本
  200. if not has_images:
  201. model = DEFAULT_MODEL
  202. # 纯文本模式:content 用字符串即可
  203. text_only = "\n".join(p["text"] for p in case_content if p["type"] == "text")
  204. messages = render_messages(prompt_data, {"case_content": text_only})
  205. else:
  206. # 多模态模式:user message 的 content 用数组
  207. messages = []
  208. for msg in prompt_data["messages"]:
  209. if msg["role"] == "user":
  210. # 把 prompt 模板文本放在图片前面
  211. template_text = msg["content"].replace("{case_content}", "")
  212. content = case_content.copy()
  213. if template_text.strip():
  214. content.insert(0, {"type": "text", "text": template_text})
  215. messages.append({"role": "user", "content": content})
  216. else:
  217. messages.append(msg)
  218. try:
  219. result = await qwen_llm_call(messages, model=model, temperature=temperature)
  220. demands = parse_json_response(result["content"])
  221. case_id = case.get("case_id", 0)
  222. # 给每个需求打上 case_id
  223. for d in demands:
  224. d["source_case_id"] = case_id
  225. return {"case_id": case_id, "title": case.get("title", ""), "demands": demands}
  226. except Exception as e:
  227. case_id = case.get("case_id", 0)
  228. print(f" [!] case {case_id} 提取失败: {e}")
  229. return {"case_id": case_id, "title": case.get("title", ""), "demands": [], "error": str(e)}
  230. # ===== Step 1b:合并去重 =====
  231. async def merge_demands(all_case_demands: list[dict]) -> list[dict]:
  232. """用 LLM 合并相似需求,超过30个时分批处理"""
  233. # 构造输入:展平所有 case 的 demands,标注来源
  234. flat = []
  235. for cd in all_case_demands:
  236. for d in cd["demands"]:
  237. flat.append({
  238. "case_id": cd["case_id"],
  239. "case_title": cd["title"],
  240. "demand_name": d["demand_name"],
  241. "description": d["description"],
  242. "search_keywords": d["search_keywords"],
  243. "evidence": d.get("evidence", ""),
  244. })
  245. if not flat:
  246. return []
  247. prompt_data = load_prompt("step1b_merge_demands.prompt")
  248. model = prompt_data["config"].get("model", DEFAULT_MODEL)
  249. temperature = prompt_data["config"].get("temperature", 0.3)
  250. # 分批处理:每批最多30个需求
  251. BATCH_SIZE = 30
  252. if len(flat) <= BATCH_SIZE:
  253. # 不需要分批
  254. all_demands_json = json.dumps(flat, ensure_ascii=False, indent=2)
  255. messages = render_messages(prompt_data, {"all_demands_json": all_demands_json})
  256. result = await qwen_llm_call(messages, model=model, temperature=temperature)
  257. merged = parse_json_response(result["content"])
  258. # 统一转换 source_case_ids 为字符串
  259. for m in merged:
  260. m["source_case_ids"] = [str(cid) for cid in m.get("source_case_ids", [])]
  261. return merged
  262. else:
  263. # 分批合并
  264. print(f" 需求数量 {len(flat)} 超过 {BATCH_SIZE},分批合并...")
  265. batches = [flat[i:i + BATCH_SIZE] for i in range(0, len(flat), BATCH_SIZE)]
  266. batch_results = []
  267. for i, batch in enumerate(batches, 1):
  268. print(f" 批次 {i}/{len(batches)}: {len(batch)} 个需求")
  269. batch_json = json.dumps(batch, ensure_ascii=False, indent=2)
  270. messages = render_messages(prompt_data, {"all_demands_json": batch_json})
  271. result = await qwen_llm_call(messages, model=model, temperature=temperature)
  272. batch_merged = parse_json_response(result["content"])
  273. batch_results.extend(batch_merged)
  274. # 如果分批后的结果仍然很多,再做一次最终合并
  275. if len(batch_results) > BATCH_SIZE:
  276. print(f" 批次合并后仍有 {len(batch_results)} 个需求,进行最终合并...")
  277. final_json = json.dumps(batch_results, ensure_ascii=False, indent=2)
  278. messages = render_messages(prompt_data, {"all_demands_json": final_json})
  279. result = await qwen_llm_call(messages, model=model, temperature=temperature)
  280. final_merged = parse_json_response(result["content"])
  281. # 统一转换 source_case_ids 为字符串
  282. for m in final_merged:
  283. m["source_case_ids"] = [str(cid) for cid in m.get("source_case_ids", [])]
  284. return final_merged
  285. else:
  286. # 统一转换 source_case_ids 为字符串
  287. for m in batch_results:
  288. m["source_case_ids"] = [str(cid) for cid in m.get("source_case_ids", [])]
  289. return batch_results
  290. # ===== Step 2:语义搜索 + 取父子节点 =====
  291. async def step2_search_and_expand(demands: list[dict]) -> dict:
  292. """语义搜索本地分类树 + 取父子节点"""
  293. all_results = {}
  294. tree = load_category_tree()
  295. for demand in demands:
  296. name = demand["demand_name"]
  297. keywords = demand["search_keywords"]
  298. all_results[name] = {}
  299. for source_type in SOURCE_TYPES:
  300. nodes = []
  301. seen = set()
  302. for kw in keywords:
  303. # 使用本地搜索替代 API
  304. results = search_local_tree(kw, source_type, top_k=3)
  305. for r in results:
  306. eid = r.get("entity_id")
  307. if eid in seen:
  308. continue
  309. seen.add(eid)
  310. node = {
  311. "entity_id": eid,
  312. "name": r.get("name", ""),
  313. "entity_type": "category",
  314. "score": r.get("score", 0),
  315. "description": r.get("description", ""),
  316. "path": r.get("path", ""),
  317. }
  318. # 本地树暂不支持动态取父子,先留空
  319. node["ancestors"] = []
  320. node["children"] = []
  321. nodes.append(node)
  322. if nodes:
  323. all_results[name][source_type] = nodes
  324. return all_results
  325. # ===== Step 3:逐需求挂载决策(并发) =====
  326. def build_single_demand_context(demand: dict, nodes_by_dim: dict) -> str:
  327. parts = []
  328. parts.append(f"## 需求:{demand['demand_name']}")
  329. parts.append(f"描述:{demand['description']}")
  330. parts.append(f"来源帖子:case_id={demand.get('source_case_ids', [])}")
  331. if demand.get("evidence"):
  332. parts.append(f"数据依据:{demand['evidence']}")
  333. if not nodes_by_dim:
  334. parts.append("(未搜索到相关节点)")
  335. return "\n".join(parts)
  336. for source_type, nodes in nodes_by_dim.items():
  337. parts.append(f"\n### {source_type}维度匹配节点:")
  338. for n in nodes:
  339. line = f"- [{n['entity_type']}] entity_id={n['entity_id']} \"{n['name']}\" score={n['score']:.2f}"
  340. if n.get("description"):
  341. line += f" 描述: {n['description']}"
  342. parts.append(line)
  343. if n.get("ancestors"):
  344. path = " > ".join(a["name"] for a in n["ancestors"])
  345. parts.append(f" 父链: {path}")
  346. if n.get("children"):
  347. kids = ", ".join(c["name"] for c in n["children"][:10])
  348. parts.append(f" 子节点: {kids}")
  349. return "\n".join(parts)
  350. async def mount_single_demand(demand: dict, nodes_by_dim: dict) -> dict:
  351. prompt_data = load_prompt("step3_mount_decision.prompt")
  352. model = prompt_data["config"].get("model", DEFAULT_MODEL)
  353. temperature = prompt_data["config"].get("temperature", 0.3)
  354. node_context = build_single_demand_context(demand, nodes_by_dim)
  355. messages = render_messages(prompt_data, {"node_context": node_context})
  356. result = await qwen_llm_call(messages, model=model, temperature=temperature, max_tokens=4096)
  357. # 解析结构化 JSON 输出
  358. try:
  359. decision = parse_json_response(result["content"])
  360. except Exception as e:
  361. print(f" [!] 挂载决策解析失败 ({demand['demand_name']}): {e}")
  362. decision = {"demand_name": demand["demand_name"], "mounted_nodes": [], "notes": f"解析失败: {e}"}
  363. return {
  364. "demand_name": demand["demand_name"],
  365. "source_case_ids": demand.get("source_case_ids", []),
  366. "decision": decision,
  367. }
  368. async def step3_mount_decisions(demands: list[dict], search_results: dict) -> list[dict]:
  369. tasks = []
  370. for demand in demands:
  371. nodes_by_dim = search_results.get(demand["demand_name"], {})
  372. tasks.append(mount_single_demand(demand, nodes_by_dim))
  373. return await asyncio.gather(*tasks)
  374. # ===== 可视化:生成关系表 + HTML =====
  375. def build_relation_table(all_case_demands: list[dict], merged_demands: list[dict], decisions: list[dict]) -> list[dict]:
  376. """构建 case → demand → node 关系表"""
  377. # 建立 demand_name → decision 的映射
  378. decision_map = {d["demand_name"]: d for d in decisions}
  379. rows = []
  380. for md in merged_demands:
  381. dn = md["demand_name"]
  382. case_ids = md.get("source_case_ids", [])
  383. dec = decision_map.get(dn, {})
  384. rows.append({
  385. "demand_name": dn,
  386. "description": md["description"],
  387. "source_case_ids": case_ids,
  388. "mount_decision": dec.get("decision", ""),
  389. })
  390. return rows
def generate_html_visualization(cases_data: dict, all_case_demands: list[dict],
                                merged_demands: list[dict], decisions: list[dict],
                                search_results: dict) -> str:
    """Generate a compact case→demand tree diagram as a standalone HTML page.

    Mounted content-tree nodes are rendered as tag text inside each demand
    box rather than as separate graph nodes. Returns the full HTML document
    as a string (vis-network is loaded from a CDN at view time).

    Note: *all_case_demands* and *search_results* are currently unused here;
    they are kept for signature compatibility with the caller.
    """
    # Build the full case_map, tolerating several alternative field names.
    case_map = {}
    for c in cases_data.get("cases", []):
        case_map[c["case_id"]] = {
            "title": c.get("title") or c.get("video_title") or c.get("post_title", ""),
            "images": c.get("images") or c.get("effect_images", []),
            "link": c.get("source_link") or c.get("video_url") or c.get("post_url", "")
        }
    # Build decision_map: demand_name → mounted_nodes (empty when the
    # decision failed to parse and is not a dict).
    decision_map = {}
    for d in decisions:
        dec = d.get("decision", {})
        if isinstance(dec, dict):
            decision_map[d["demand_name"]] = dec.get("mounted_nodes", [])
        else:
            decision_map[d["demand_name"]] = []
    nodes_js = []
    edges_js = []
    node_id = 0
    # One graph node per demand, plus an independent case node for every
    # source case of that demand (cases are duplicated across demands).
    for md in merged_demands:
        demand_id = node_id
        node_id += 1
        # Compose the demand node's display content.
        dn = md["demand_name"]
        desc = md["description"][:100] + "..." if len(md["description"]) > 100 else md["description"]
        # Pull the finally selected nodes from the mounting decision.
        mounted = decision_map.get(dn, [])
        node_tags = [f"{n['name']}({n.get('source_type', '')})" for n in mounted]
        tags_html = " | ".join(node_tags) if node_tags else "无挂载节点"
        # Demand node label (includes the mounted-node tags).
        demand_label = f"{dn}\n\n[挂载] {tags_html}"
        demand_title = f"{dn}\n\n{desc}\n\n挂载节点: {tags_html}"
        nodes_js.append({
            "id": demand_id,
            "label": demand_label,
            "title": demand_title,
            "group": "demand",
            "level": 1,
            "shape": "box",
            "font": {"size": 12}
        })
        # Create an independent case node for each source case.
        for cid in md.get("source_case_ids", []):
            case_id = node_id
            node_id += 1
            # Type coercion: case_map keys may be strings or integers.
            cid_str = str(cid)
            case_info = case_map.get(cid_str) or case_map.get(cid, {"title": f"Case {cid}", "images": [], "link": ""})
            # The label shows only the (truncated) title.
            title_short = case_info['title'][:50] + "..." if len(case_info['title']) > 50 else case_info['title']
            case_label = f"Case {cid}\n{title_short}"
            # Build an HTML tooltip with the first image (no inline link;
            # navigation happens via the click handler below).
            img_html = ""
            if case_info['images']:
                img_url = case_info['images'][0]
                img_html = f'<img src="{img_url}" style="max-width:300px; max-height:200px; display:block; margin:10px 0;"/>'
            case_title = f'<div style="max-width:350px;"><b>Case {cid}: {case_info["title"]}</b>{img_html}<br/><i>点击节点查看原帖</i></div>'
            nodes_js.append({
                "id": case_id,
                "label": case_label,
                "title": case_title,
                "url": case_info['link'],  # stored for the click handler
                "group": "case",
                "level": 0,
                "shape": "box"
            })
            # case → demand edge
            edges_js.append({"from": case_id, "to": demand_id})
    html = f"""<!DOCTYPE html>
<html><head>
<meta charset="utf-8">
<title>Case → Demand 树状图</title>
<script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
<style>
body {{ margin: 0; font-family: sans-serif; }}
#graph {{ width: 100%; height: 90vh; border: 1px solid #ddd; }}
#legend {{ padding: 10px; display: flex; gap: 20px; align-items: center; background: #f5f5f5; }}
.legend-item {{ display: flex; align-items: center; gap: 5px; }}
.dot {{ width: 14px; height: 14px; border-radius: 3px; }}
</style>
</head><body>
<div id="legend">
<b>树状图:</b>
<span class="legend-item"><span class="dot" style="background:#97C2FC"></span> 帖子 (Case)</span>
<span class="legend-item"><span class="dot" style="background:#FFB366"></span> 需求 (Demand + 匹配节点)</span>
</div>
<div id="graph"></div>
<script>
var nodesData = {json.dumps(nodes_js, ensure_ascii=False)};
var edgesData = {json.dumps(edges_js, ensure_ascii=False)};
// 将 title 字符串转为 DOM 元素,让 vis.js 渲染 HTML
nodesData.forEach(function(node) {{
if (node.title && typeof node.title === 'string' && node.title.includes('<')) {{
var container = document.createElement('div');
container.innerHTML = node.title;
node.title = container;
}}
}});
var nodes = new vis.DataSet(nodesData);
var edges = new vis.DataSet(edgesData);
var container = document.getElementById("graph");
var data = {{ nodes: nodes, edges: edges }};
var options = {{
layout: {{
hierarchical: {{
direction: "UD",
sortMethod: "directed",
levelSeparation: 180,
nodeSpacing: 100
}}
}},
groups: {{
case: {{
shape: "box",
color: {{ background: "#97C2FC", border: "#2B7CE9" }},
font: {{ size: 11 }},
widthConstraint: {{ minimum: 200, maximum: 350 }}
}},
demand: {{
shape: "box",
color: {{ background: "#FFB366", border: "#FF8C00" }},
font: {{ size: 12 }},
widthConstraint: {{ minimum: 200, maximum: 400 }}
}}
}},
edges: {{ arrows: "to", smooth: {{ type: "cubicBezier" }} }},
physics: {{ enabled: false }},
interaction: {{ hover: true, tooltipDelay: 100 }}
}};
var network = new vis.Network(container, data, options);
// 点击 case 节点跳转到原帖
network.on("click", function(params) {{
if (params.nodes.length > 0) {{
var nodeId = params.nodes[0];
var nodeData = nodes.get(nodeId);
if (nodeData.url) {{
window.open(nodeData.url, '_blank');
}}
}}
}});
</script>
</body></html>"""
    return html
  539. # ===== 主流程 =====
async def run(cases_path: str):
    """Main pipeline: load cases, extract/merge demands, search, mount, report.

    Reads a cases JSON file, runs steps 1a/1b/2/3 and writes
    ``match_nodes_result.json`` plus ``match_nodes_graph.html`` next to it.
    """
    with open(cases_path, "r", encoding="utf-8") as f:
        cases_data = json.load(f)
    cases = cases_data.get("cases", [])
    topic = cases_data.get("topic", "未知主题")
    # Filter: keep only cases that have images (tolerate both field names).
    def has_images(case):
        imgs = case.get("images") or case.get("effect_images") or []
        return len(imgs) > 0
    cases_with_images = [c for c in cases if has_images(c)]
    skipped = len(cases) - len(cases_with_images)
    print("=" * 60)
    print(f"输入: {cases_path}")
    print(f"主题: {topic} (总计 {len(cases)} 个案例)")
    if skipped > 0:
        print(f"过滤: 跳过 {skipped} 个纯视频案例,保留 {len(cases_with_images)} 个图文案例")
    print("=" * 60)
    if not cases_with_images:
        print("\n错误: 没有图文案例可处理")
        return
    # Step 1a: extract demands per post (concurrently).
    print(f"\n第 1a 步:逐帖子提取需求({len(cases_with_images)} 个并发)...")
    tasks = [extract_demands_for_case(case) for case in cases_with_images]
    all_case_demands = await asyncio.gather(*tasks)
    total_raw = sum(len(cd["demands"]) for cd in all_case_demands)
    for cd in all_case_demands:
        n = len(cd["demands"])
        names = ", ".join(d["demand_name"] for d in cd["demands"][:3])
        suffix = "..." if n > 3 else ""
        print(f" Case {cd['case_id']}: {n} 个需求 [{names}{suffix}]")
    print(f" 共提取 {total_raw} 个原始需求")
    # Step 1b: merge and deduplicate demands across cases.
    print(f"\n第 1b 步:合并去重...")
    merged_demands = await merge_demands(all_case_demands)
    print(f" 合并后 {len(merged_demands)} 个需求:")
    for i, md in enumerate(merged_demands, 1):
        print(f" {i}. {md['demand_name']} (来自 case {md.get('source_case_ids', [])})")
    # Step 2: search candidate nodes for each merged demand.
    print(f"\n第 2 步:语义搜索内容树 + 获取父子节点...")
    search_results = await step2_search_and_expand(merged_demands)
    for dn, dims in search_results.items():
        total = sum(len(ns) for ns in dims.values())
        print(f" [{dn}] {total} 个节点")
    # Step 3: mounting decisions (one concurrent LLM call per demand).
    print(f"\n第 3 步:LLM 挂载决策({len(merged_demands)} 个需求并发)...")
    decisions = await step3_mount_decisions(merged_demands, search_results)
    for d in decisions:
        print(f"\n{'─' * 40}")
        print(f"【{d['demand_name']}】(来自 case {d['source_case_ids']})")
        print(d["decision"])
    # Build the relation table for the report.
    relation_table = build_relation_table(all_case_demands, merged_demands, decisions)
    # Build the case info map for visualization.
    # Keys are normalized to strings to avoid int/str lookup mismatches.
    case_info_map = {}
    for c in cases_with_images:
        case_info_map[str(c["case_id"])] = {
            "title": c.get("title") or c.get("video_title") or c.get("post_title", ""),
            "images": c.get("images") or c.get("effect_images", []),
            "link": c.get("source_link") or c.get("video_url") or c.get("post_url", "")
        }
    # Persist the full result next to the input file.
    output_dir = Path(cases_path).parent
    output_file = output_dir / "match_nodes_result.json"
    output_data = {
        "topic": topic,
        "case_info": case_info_map,  # case metadata for downstream tooling
        "per_case_demands": [
            {"case_id": cd["case_id"], "title": cd["title"],
             "demands": [d["demand_name"] for d in cd["demands"]]}
            for cd in all_case_demands
        ],
        "merged_demands": merged_demands,
        "search_results": search_results,
        "mount_decisions": decisions,
        "relation_table": relation_table,
    }
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"\n结果已保存到: {output_file}")
    # Generate the HTML visualization alongside the JSON result.
    html = generate_html_visualization(cases_data, all_case_demands, merged_demands, decisions, search_results)
    html_file = output_dir / "match_nodes_graph.html"
    with open(html_file, "w", encoding="utf-8") as f:
        f.write(html)
    print(f"可视化已保存到: {html_file}")
if __name__ == "__main__":
    # CLI entry: expects the path to a cases JSON file as the only argument.
    if len(sys.argv) < 2:
        print("用法: python match_nodes.py <cases.json 路径>")
        print("示例: python match_nodes.py outputs/midjourney_0/02_cases.json")
        sys.exit(1)
    # Bypass any configured HTTP proxy for direct service access
    # (only if no_proxy is not already set in the environment).
    os.environ.setdefault("no_proxy", "*")
    asyncio.run(run(sys.argv[1]))