| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758 |
- """
- 制作案例 → 制作需求 → 内容树节点匹配
- 流程:
- 1a. 逐帖子分析(带图片),提炼每个帖子的制作需求(并发)
- 1b. 合并去重相似需求,保留 case_id 溯源
- 2. 按合并后的需求语义搜索内容树 + 取父子节点
- 3. 每个需求独立过 LLM 挂载决策(并发)
- 4. 输出 case → demand → node 关系表 + 可视化
- """
- import asyncio
- import json
- import os
- import re
- import sys
- from pathlib import Path
- import httpx
- # 添加项目根目录
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
- from dotenv import load_dotenv
- load_dotenv()
- from agent.llm.qwen import qwen_llm_call
- # ===== 配置 =====
- BASE_DIR = Path(__file__).parent
- CONTENT_TREE_BASE = "http://8.147.104.190:8001"
- CATEGORY_TREE_PATH = BASE_DIR / "prompts" / "category_tree.json"
- DEFAULT_MODEL = "qwen-plus"
- VISION_MODEL = "qwen-vl-max"
- SOURCE_TYPES = ["实质", "形式", "意图"]
- # 加载本地分类树
_CATEGORY_TREE_CACHE = None


def load_category_tree() -> dict:
    """Return the local category tree, reading it from disk at most once.

    The first call parses the JSON file at ``CATEGORY_TREE_PATH`` and stores
    the result in the module-level ``_CATEGORY_TREE_CACHE``; subsequent calls
    return that stored object.
    """
    global _CATEGORY_TREE_CACHE
    if _CATEGORY_TREE_CACHE is not None:
        return _CATEGORY_TREE_CACHE
    with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as fh:
        _CATEGORY_TREE_CACHE = json.load(fh)
    return _CATEGORY_TREE_CACHE
def collect_all_nodes(node: dict, nodes: list, parent_path: str = ""):
    """Flatten a category tree into ``nodes`` with a depth-first walk.

    Args:
        node: Current tree node. Only nodes that carry an ``"id"`` key are
            recorded; nodes without one are still traversed.
        nodes: Output list, extended in place with one flat dict per node.
        parent_path: Path inherited from the parent, used when ``node`` has
            no ``"path"`` of its own.
    """
    # A node's own path wins; otherwise it inherits the parent's.
    own_path = node.get("path", parent_path)

    if "id" in node:
        nodes.append({
            "entity_id": node["id"],
            "name": node["name"],
            "path": own_path,
            "source_type": node.get("source_type"),
            # Normalise a missing/None description to an empty string.
            "description": node.get("description") or "",
            "level": node.get("level"),
            "parent_id": node.get("parent_id"),
            "element_count": node.get("element_count", 0),
        })

    for child in node.get("children", ()):
        collect_all_nodes(child, nodes, own_path)
def search_local_tree(query: str, source_type: str, top_k: int = 5) -> list:
    """Search the local category tree with simple substring scoring.

    Nodes are restricted to ``source_type`` and scored case-insensitively,
    taking the best rule that applies:

    * 1.0 - the query equals the node name
    * 0.8 - the query is contained in the node name
    * 0.6 - the node name is contained in the query
    * 0.5 - the query is contained in the node description

    Args:
        query: Keyword to look for. An empty query matches nothing.
        source_type: Dimension to search; only nodes whose ``source_type``
            equals this value are considered.
        top_k: Maximum number of results.

    Returns:
        Up to ``top_k`` flattened node dicts, best score first, each with
        added ``"score"`` and ``"entity_type"`` (always ``"category"``) keys.
    """
    needle = query.lower()
    if not needle:
        # "" is a substring of every string, so an empty query would
        # otherwise "match" every node by name.
        return []

    candidates: list = []
    collect_all_nodes(load_category_tree(), candidates)

    scored = []
    for node in candidates:
        if node.get("source_type") != source_type:
            continue
        name = node["name"].lower()
        desc = node["description"].lower()

        # Rules are tested from highest to lowest score so that a node which
        # satisfies several of them receives the best one.
        if needle == name:
            score = 1.0
        elif needle in name:
            score = 0.8
        elif name and name in needle:
            # `name and` guards against empty names, which are a substring
            # of every query.
            score = 0.6
        elif needle in desc:
            score = 0.5
        else:
            continue

        node["score"] = score
        node["entity_type"] = "category"  # the local tree only holds categories
        scored.append(node)

    # list.sort is stable, so equally scored nodes keep tree order.
    scored.sort(key=lambda item: item["score"], reverse=True)
    return scored[:top_k]
- # ===== Prompt 加载 =====
def load_prompt(filename: str) -> dict:
    """Load a ``.prompt`` file and split it into config and role messages.

    The file may begin with a ``---`` delimited front-matter block of
    ``key: value`` lines. Values that look like plain non-negative numbers
    are converted to ``int`` or ``float``; everything else stays a string.
    The remainder is split on lines of the form ``$role$``; each such marker
    starts a message with that role.

    Args:
        filename: File name inside ``BASE_DIR / "prompts"``.

    Returns:
        ``{"config": dict, "messages": [{"role": str, "content": str}, ...]}``
    """
    path = BASE_DIR / "prompts" / filename
    text = path.read_text(encoding="utf-8")

    config = {}
    if text.startswith("---"):
        pieces = text.split("---", 2)
        # Only treat the block as front matter when it is properly closed;
        # an unterminated "---" would otherwise fail tuple unpacking.
        if len(pieces) == 3:
            _, front_matter, text = pieces
            for line in front_matter.strip().splitlines():
                if ":" not in line:
                    continue
                key, value = (part.strip() for part in line.split(":", 1))
                # isdecimal() (unlike isdigit()) only accepts characters that
                # int()/float() can actually parse.
                if value.replace(".", "", 1).isdecimal():
                    value = float(value) if "." in value else int(value)
                config[key] = value

    # With one capture group re.split yields
    # [preamble, role1, body1, role2, body2, ...]; the preamble is dropped.
    sections = re.split(r'^\$(\w+)\$\s*$', text.strip(), flags=re.MULTILINE)
    messages = [
        {"role": role.strip(), "content": body.strip()}
        for role, body in zip(sections[1::2], sections[2::2])
    ]
    return {"config": config, "messages": messages}
def render_messages(prompt_data: dict, variables: dict) -> list[dict]:
    """Substitute ``{name}`` placeholders in every prompt message.

    Args:
        prompt_data: Dict with a ``"messages"`` list of role/content dicts.
        variables: Mapping of placeholder name to value; values are
            converted with ``str()``. Placeholders without a matching
            variable are left untouched.

    Returns:
        A new list of ``{"role", "content"}`` dicts; the input messages are
        not modified.
    """
    def fill(template):
        for key, value in variables.items():
            template = template.replace(f"{{{key}}}", str(value))
        return template

    return [
        {"role": message["role"], "content": fill(message["content"])}
        for message in prompt_data["messages"]
    ]
- def parse_json_response(content: str) -> list | dict:
- """清理 LLM 输出中的 markdown 包裹并解析 JSON"""
- content = content.strip()
- if content.startswith("```"):
- content = content.split("\n", 1)[1]
- content = content.rsplit("```", 1)[0]
- try:
- return json.loads(content)
- except json.JSONDecodeError as e:
- # 尝试自动修复:替换字符串中的未转义引号
- print(f"\n[JSON 解析错误] 尝试自动修复...")
- # 简单修复:把字符串值中的单独双引号替换为单引号
- import re
- # 匹配 "key": "value with "quote" inside"
- # 这个正则会找到字符串值中间的引号并替换
- fixed = re.sub(r'("(?:evidence|description|demand_name)":\s*")([^"]*)"([^"]*)"([^"]*")',
- r'\1\2\'\3\4', content)
- try:
- return json.loads(fixed)
- except json.JSONDecodeError:
- # 修复失败,打印错误位置
- start = max(0, e.pos - 100)
- end = min(len(content), e.pos + 100)
- print(f"\n[JSON 解析错误] 位置 {e.pos}:")
- print(f"...{content[start:end]}...")
- print(f"\n完整内容已保存到 json_error.txt")
- with open("json_error.txt", "w", encoding="utf-8") as f:
- f.write(content)
- raise
- # ===== 内容树 API =====
async def search_content_tree(client: httpx.AsyncClient, query: str, source_type: str, top_k: int = 5) -> list:
    """Query the remote content-tree search endpoint.

    Sends a GET to ``{CONTENT_TREE_BASE}/api/agent/search`` and returns the
    ``"results"`` list of the JSON reply (empty list when the key is absent).

    Raises:
        httpx.HTTPStatusError: If the response status is an error.
    """
    response = await client.get(
        f"{CONTENT_TREE_BASE}/api/agent/search",
        params={
            "q": query,
            "source_type": source_type,
            "entity_type": "all",
            "top_k": top_k,
            "use_description": "true",
            "include_ancestors": "false",
            "descendant_depth": 0,
        },
    )
    response.raise_for_status()
    payload = response.json()
    return payload.get("results", [])
async def get_category_tree(client: httpx.AsyncClient, entity_id: int, source_type: str) -> dict:
    """Fetch one category from the remote content tree.

    Requests ``/api/agent/search/category/{entity_id}`` with ancestors
    included and a descendant depth of 1, and returns the decoded JSON body.

    Raises:
        httpx.HTTPStatusError: If the response status is an error.
    """
    url = f"{CONTENT_TREE_BASE}/api/agent/search/category/{entity_id}"
    response = await client.get(
        url,
        params={
            "source_type": source_type,
            "include_ancestors": "true",
            "descendant_depth": 1,
        },
    )
    response.raise_for_status()
    return response.json()
- # ===== Step 1a:逐帖子提需求(支持图片) =====
def build_case_content(case: dict) -> list[dict]:
    """Build the multimodal message content for a single case.

    Args:
        case: Case record. The title is read from ``title``, ``video_title``
            or ``post_title``; images from ``images`` or ``effect_images``.
            ``source``, ``user_input``, ``output_description`` and
            ``key_findings`` are included in the text when present.

    Returns:
        A content array in the OpenAI vision format: one ``"text"`` part
        followed by at most four ``"image_url"`` parts.
    """
    # Trailing `or` so that a key present with None / "" still falls back.
    title = (
        case.get("title")
        or case.get("video_title")
        or case.get("post_title")
        or "未知"
    )
    text_lines = [
        f"## 案例：{title}",
        f"来源：{case.get('source', '')}",
    ]
    if case.get("user_input"):
        text_lines.append(f"用户输入：{json.dumps(case['user_input'], ensure_ascii=False)}")
    if case.get("output_description"):
        text_lines.append(f"输出效果：{case['output_description']}")
    if case.get("key_findings"):
        text_lines.append(f"关键发现：{case['key_findings']}")

    parts = [{"type": "text", "text": "\n".join(text_lines)}]

    # `or []` keeps slicing safe when both image fields are missing or None.
    images = case.get("images") or case.get("effect_images") or []
    # At most four images, to keep token usage under control.
    parts.extend(
        {"type": "image_url", "image_url": {"url": img_url}}
        for img_url in images[:4]
    )
    return parts
async def extract_demands_for_case(case: dict) -> dict:
    """Extract production demands from a single case via the LLM.

    Cases with images are sent as a multimodal user message to the model
    configured in the prompt file (default ``VISION_MODEL``); cases without
    images are rendered as plain text and sent to ``DEFAULT_MODEL``.

    Returns:
        ``{"case_id", "title", "demands"}``; each demand is tagged with
        ``source_case_id``. If the LLM call or response handling raises, the
        error is printed and ``demands`` is empty with an extra ``"error"``
        string.
    """
    prompt_data = load_prompt("step1_extract_demands.prompt")
    settings = prompt_data["config"]
    model = settings.get("model", VISION_MODEL)
    temperature = settings.get("temperature", 0.3)

    case_content = build_case_content(case)

    if any(part["type"] == "image_url" for part in case_content):
        # Multimodal mode: the user message content is an array of parts.
        messages = []
        for msg in prompt_data["messages"]:
            if msg["role"] != "user":
                messages.append(msg)
                continue
            blocks = list(case_content)
            # The template text (placeholder removed) goes before the case parts.
            preamble = msg["content"].replace("{case_content}", "")
            if preamble.strip():
                blocks.insert(0, {"type": "text", "text": preamble})
            messages.append({"role": "user", "content": blocks})
    else:
        # No images: fall back to the text model to save cost, and pass the
        # case as a plain string through the normal template rendering.
        model = DEFAULT_MODEL
        plain_text = "\n".join(
            part["text"] for part in case_content if part["type"] == "text"
        )
        messages = render_messages(prompt_data, {"case_content": plain_text})

    case_id = case.get("case_id", 0)
    title = case.get("title", "")
    try:
        reply = await qwen_llm_call(messages, model=model, temperature=temperature)
        demands = parse_json_response(reply["content"])
        # Stamp every demand with the case it came from.
        for demand in demands:
            demand["source_case_id"] = case_id
    except Exception as e:
        print(f" [!] case {case_id} 提取失败: {e}")
        return {"case_id": case_id, "title": title, "demands": [], "error": str(e)}
    return {"case_id": case_id, "title": title, "demands": demands}
- # ===== Step 1b:合并去重 =====
async def merge_demands(all_case_demands: list[dict]) -> list[dict]:
    """Merge similar demands with the LLM, batching when there are many.

    All per-case demands are flattened (keeping case id and title). Up to
    30 demands are merged in a single LLM call. Larger inputs are merged in
    batches of 30, and if the combined batch output still exceeds 30 items
    it is sent through one final merge call.

    Returns:
        The merged demands as returned by the LLM, with every
        ``source_case_ids`` list normalised to strings. Empty list when
        there is nothing to merge.
    """
    flat = [
        {
            "case_id": cd["case_id"],
            "case_title": cd["title"],
            "demand_name": d["demand_name"],
            "description": d["description"],
            "search_keywords": d["search_keywords"],
            "evidence": d.get("evidence", ""),
        }
        for cd in all_case_demands
        for d in cd["demands"]
    ]
    if not flat:
        return []

    prompt_data = load_prompt("step1b_merge_demands.prompt")
    model = prompt_data["config"].get("model", DEFAULT_MODEL)
    temperature = prompt_data["config"].get("temperature", 0.3)

    async def merge_once(items):
        # One LLM round-trip: serialise, render the prompt, call, parse.
        payload = json.dumps(items, ensure_ascii=False, indent=2)
        messages = render_messages(prompt_data, {"all_demands_json": payload})
        reply = await qwen_llm_call(messages, model=model, temperature=temperature)
        return parse_json_response(reply["content"])

    # At most 30 demands per LLM call.
    BATCH_SIZE = 30
    if len(flat) <= BATCH_SIZE:
        merged = await merge_once(flat)
    else:
        print(f" 需求数量 {len(flat)} 超过 {BATCH_SIZE}，分批合并...")
        batches = [flat[start:start + BATCH_SIZE] for start in range(0, len(flat), BATCH_SIZE)]
        merged = []
        for index, batch in enumerate(batches, 1):
            print(f" 批次 {index}/{len(batches)}: {len(batch)} 个需求")
            merged.extend(await merge_once(batch))
        # Still too many after per-batch merging: one final merge over all.
        if len(merged) > BATCH_SIZE:
            print(f" 批次合并后仍有 {len(merged)} 个需求，进行最终合并...")
            merged = await merge_once(merged)

    # Normalise source_case_ids to strings on whichever result is returned.
    for item in merged:
        item["source_case_ids"] = [str(cid) for cid in item.get("source_case_ids", [])]
    return merged
- # ===== Step 2:语义搜索 + 取父子节点 =====
async def step2_search_and_expand(demands: list[dict]) -> dict:
    """Search the local category tree for every demand keyword.

    For each demand and each dimension in ``SOURCE_TYPES`` the top 3 local
    matches of every keyword are collected, de-duplicated by ``entity_id``.

    Returns:
        ``{demand_name: {source_type: [node, ...]}}``; dimensions without any
        hit are omitted. ``ancestors`` and ``children`` are always empty
        lists because the local tree lookup does not expand them yet.
    """
    # Same call the original made up front; its return value is not used here.
    load_category_tree()

    all_results = {}
    for demand in demands:
        per_dimension = all_results[demand["demand_name"]] = {}
        keywords = demand["search_keywords"]

        for source_type in SOURCE_TYPES:
            seen_ids = set()
            nodes = []
            for keyword in keywords:
                # Local search replaces the remote API.
                for hit in search_local_tree(keyword, source_type, top_k=3):
                    entity_id = hit.get("entity_id")
                    if entity_id in seen_ids:
                        continue
                    seen_ids.add(entity_id)
                    nodes.append({
                        "entity_id": entity_id,
                        "name": hit.get("name", ""),
                        "entity_type": "category",
                        "score": hit.get("score", 0),
                        "description": hit.get("description", ""),
                        "path": hit.get("path", ""),
                        "ancestors": [],
                        "children": [],
                    })
            if nodes:
                per_dimension[source_type] = nodes
    return all_results
- # ===== Step 3:逐需求挂载决策(并发) =====
def build_single_demand_context(demand: dict, nodes_by_dim: dict) -> str:
    """Render one demand and its candidate nodes as prompt text.

    Args:
        demand: Merged demand with ``demand_name`` and ``description``;
            ``source_case_ids`` and ``evidence`` are optional.
        nodes_by_dim: ``{source_type: [node, ...]}`` search hits. When empty,
            a "no nodes found" line is emitted instead of node listings.

    Returns:
        Newline-joined text: a demand header, then per dimension one line per
        node plus optional ancestor-chain and child-name lines (at most ten
        children are listed).
    """
    lines = [
        f"## 需求：{demand['demand_name']}",
        f"描述：{demand['description']}",
        f"来源帖子：case_id={demand.get('source_case_ids', [])}",
    ]
    if demand.get("evidence"):
        lines.append(f"数据依据：{demand['evidence']}")

    if not nodes_by_dim:
        lines.append("（未搜索到相关节点）")
    else:
        for source_type, nodes in nodes_by_dim.items():
            lines.append(f"\n### {source_type}维度匹配节点：")
            for node in nodes:
                entry = f"- [{node['entity_type']}] entity_id={node['entity_id']} \"{node['name']}\" score={node['score']:.2f}"
                description = node.get("description")
                if description:
                    entry += f" 描述: {description}"
                lines.append(entry)

                ancestors = node.get("ancestors")
                if ancestors:
                    chain = " > ".join(item["name"] for item in ancestors)
                    lines.append(f" 父链: {chain}")

                children = node.get("children")
                if children:
                    child_names = ", ".join(item["name"] for item in children[:10])
                    lines.append(f" 子节点: {child_names}")
    return "\n".join(lines)
async def mount_single_demand(demand: dict, nodes_by_dim: dict) -> dict:
    """Ask the LLM where a single demand should be mounted.

    Renders the ``step3_mount_decision.prompt`` template with the demand's
    node context and parses the reply as JSON. If parsing fails, the error
    is printed and a placeholder decision with no mounted nodes is used.

    Returns:
        ``{"demand_name", "source_case_ids", "decision"}``.
    """
    prompt_data = load_prompt("step3_mount_decision.prompt")
    settings = prompt_data["config"]

    context_text = build_single_demand_context(demand, nodes_by_dim)
    messages = render_messages(prompt_data, {"node_context": context_text})
    reply = await qwen_llm_call(
        messages,
        model=settings.get("model", DEFAULT_MODEL),
        temperature=settings.get("temperature", 0.3),
        max_tokens=4096,
    )

    demand_name = demand["demand_name"]
    # The model is expected to answer with structured JSON.
    try:
        decision = parse_json_response(reply["content"])
    except Exception as e:
        print(f" [!] 挂载决策解析失败 ({demand_name}): {e}")
        decision = {"demand_name": demand_name, "mounted_nodes": [], "notes": f"解析失败: {e}"}

    return {
        "demand_name": demand_name,
        "source_case_ids": demand.get("source_case_ids", []),
        "decision": decision,
    }
async def step3_mount_decisions(demands: list[dict], search_results: dict) -> list[dict]:
    """Run the mount decision for every demand concurrently.

    Each demand is paired with its entry in ``search_results`` (an empty
    dict when the demand has none). Results are returned in the same order
    as ``demands``.
    """
    pending = (
        mount_single_demand(demand, search_results.get(demand["demand_name"], {}))
        for demand in demands
    )
    return await asyncio.gather(*pending)
- # ===== 可视化:生成关系表 + HTML =====
def build_relation_table(all_case_demands: list[dict], merged_demands: list[dict], decisions: list[dict]) -> list[dict]:
    """Build the case -> demand -> node relation rows.

    Args:
        all_case_demands: Not used by this function.
        merged_demands: Merged demands; one output row is produced per item.
        decisions: Mount decisions, matched to demands by ``demand_name``.

    Returns:
        One dict per merged demand with ``demand_name``, ``description``,
        ``source_case_ids`` and ``mount_decision`` (empty string when no
        decision exists for the demand).
    """
    by_name = {}
    for decision in decisions:
        by_name[decision["demand_name"]] = decision

    return [
        {
            "demand_name": merged["demand_name"],
            "description": merged["description"],
            "source_case_ids": merged.get("source_case_ids", []),
            "mount_decision": by_name.get(merged["demand_name"], {}).get("decision", ""),
        }
        for merged in merged_demands
    ]
def generate_html_visualization(cases_data: dict, all_case_demands: list[dict],
                                merged_demands: list[dict], decisions: list[dict],
                                search_results: dict) -> str:
    """Render a case -> demand hierarchy as a standalone vis-network page.

    Every merged demand becomes a box (level 1) labelled with its name and
    the nodes chosen by its mount decision. Every source case of a demand
    becomes its own box (level 0) with an edge to that demand; a case that
    feeds several demands is therefore drawn once per demand. Case boxes
    carry an HTML tooltip (title plus first image) and open the original
    post when clicked.

    Args:
        cases_data: Parsed cases file; its ``"cases"`` list supplies titles,
            images and links.
        all_case_demands: Not used by this function.
        merged_demands: Demands with ``demand_name``, ``description`` and
            optional ``source_case_ids``.
        decisions: Mount decisions; ``decision["mounted_nodes"]`` supplies
            the node tags.
        search_results: Not used by this function.

    Returns:
        The complete HTML document as a string.
    """
    # Function-scope import: only this block needs it.
    from html import escape

    # Key by str(case_id): source_case_ids are normalised to strings during
    # merging, so raw integer keys would never be found.
    case_map = {}
    for c in cases_data.get("cases", []):
        case_map[str(c["case_id"])] = {
            "title": c.get("title") or c.get("video_title") or c.get("post_title") or "",
            "images": c.get("images") or c.get("effect_images") or [],
            "link": c.get("source_link") or c.get("video_url") or c.get("post_url") or "",
        }

    # demand_name -> mounted nodes; anything that is not a list counts as none.
    decision_map = {}
    for d in decisions:
        dec = d.get("decision", {})
        mounted_nodes = dec.get("mounted_nodes") if isinstance(dec, dict) else None
        decision_map[d["demand_name"]] = mounted_nodes if isinstance(mounted_nodes, list) else []

    nodes_js = []
    edges_js = []
    node_id = 0

    for md in merged_demands:
        demand_id = node_id
        node_id += 1

        dn = md["demand_name"]
        full_desc = md["description"]
        desc = full_desc[:100] + "..." if len(full_desc) > 100 else full_desc

        # Nodes finally selected by the mount decision, shown as tags.
        node_tags = [
            f"{n.get('name', '')}({n.get('source_type', '')})"
            for n in decision_map.get(dn, [])
            if isinstance(n, dict)
        ]
        tags_html = " | ".join(node_tags) if node_tags else "无挂载节点"

        nodes_js.append({
            "id": demand_id,
            "label": f"{dn}\n\n[挂载] {tags_html}",
            "title": f"{dn}\n\n{desc}\n\n挂载节点: {tags_html}",
            "group": "demand",
            "level": 1,
            "shape": "box",
            "font": {"size": 12},
        })

        # One independent case node per source case of this demand.
        for cid in md.get("source_case_ids", []):
            case_node_id = node_id
            node_id += 1

            case_info = case_map.get(str(cid)) or {"title": f"Case {cid}", "images": [], "link": ""}
            title = case_info["title"]
            title_short = title[:50] + "..." if len(title) > 50 else title

            # The tooltip is injected via innerHTML in the page, so every
            # interpolated value is HTML-escaped first.
            img_html = ""
            if case_info["images"]:
                img_url = escape(str(case_info["images"][0]))
                img_html = f'<img src="{img_url}" style="max-width:300px; max-height:200px; display:block; margin:10px 0;"/>'
            case_title = (
                f'<div style="max-width:350px;"><b>Case {escape(str(cid))}: {escape(title)}</b>'
                f'{img_html}<br/><i>点击节点查看原帖</i></div>'
            )

            nodes_js.append({
                "id": case_node_id,
                "label": f"Case {cid}\n{title_short}",
                "title": case_title,
                "url": case_info["link"],  # read by the click handler
                "group": "case",
                "level": 0,
                "shape": "box",
            })
            edges_js.append({"from": case_node_id, "to": demand_id})

    def to_script_json(value):
        # "</" can only occur inside JSON strings; writing it as "<\/" stops
        # a literal "</script>" in the data from closing the script element.
        return json.dumps(value, ensure_ascii=False).replace("</", "<\\/")

    nodes_json = to_script_json(nodes_js)
    edges_json = to_script_json(edges_js)

    page = f"""<!DOCTYPE html>
<html><head>
<meta charset="utf-8">
<title>Case → Demand 树状图</title>
<script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
<style>
body {{ margin: 0; font-family: sans-serif; }}
#graph {{ width: 100%; height: 90vh; border: 1px solid #ddd; }}
#legend {{ padding: 10px; display: flex; gap: 20px; align-items: center; background: #f5f5f5; }}
.legend-item {{ display: flex; align-items: center; gap: 5px; }}
.dot {{ width: 14px; height: 14px; border-radius: 3px; }}
</style>
</head><body>
<div id="legend">
<b>树状图：</b>
<span class="legend-item"><span class="dot" style="background:#97C2FC"></span> 帖子 (Case)</span>
<span class="legend-item"><span class="dot" style="background:#FFB366"></span> 需求 (Demand + 匹配节点)</span>
</div>
<div id="graph"></div>
<script>
var nodesData = {nodes_json};
var edgesData = {edges_json};
// 将 title 字符串转为 DOM 元素,让 vis.js 渲染 HTML
nodesData.forEach(function(node) {{
  if (node.title && typeof node.title === 'string' && node.title.includes('<')) {{
    var container = document.createElement('div');
    container.innerHTML = node.title;
    node.title = container;
  }}
}});
var nodes = new vis.DataSet(nodesData);
var edges = new vis.DataSet(edgesData);
var container = document.getElementById("graph");
var data = {{ nodes: nodes, edges: edges }};
var options = {{
  layout: {{
    hierarchical: {{
      direction: "UD",
      sortMethod: "directed",
      levelSeparation: 180,
      nodeSpacing: 100
    }}
  }},
  groups: {{
    case: {{
      shape: "box",
      color: {{ background: "#97C2FC", border: "#2B7CE9" }},
      font: {{ size: 11 }},
      widthConstraint: {{ minimum: 200, maximum: 350 }}
    }},
    demand: {{
      shape: "box",
      color: {{ background: "#FFB366", border: "#FF8C00" }},
      font: {{ size: 12 }},
      widthConstraint: {{ minimum: 200, maximum: 400 }}
    }}
  }},
  edges: {{ arrows: "to", smooth: {{ type: "cubicBezier" }} }},
  physics: {{ enabled: false }},
  interaction: {{ hover: true, tooltipDelay: 100 }}
}};
var network = new vis.Network(container, data, options);
// 点击 case 节点跳转到原帖
network.on("click", function(params) {{
  if (params.nodes.length > 0) {{
    var nodeId = params.nodes[0];
    var nodeData = nodes.get(nodeId);
    if (nodeData.url) {{
      window.open(nodeData.url, '_blank');
    }}
  }}
}});
</script>
</body></html>"""
    return page
- # ===== 主流程 =====
async def run(cases_path: str):
    """Run the whole pipeline for one cases file.

    Steps: load the cases, keep only those with images, extract demands per
    case (concurrently), merge similar demands, search the category tree,
    take a mount decision per demand (concurrently), then write
    ``match_nodes_result.json`` and ``match_nodes_graph.html`` next to the
    input file. Progress is printed along the way. Returns early, writing
    nothing, when no case has images.

    Args:
        cases_path: Path to the cases JSON file (``"cases"`` list and
            optional ``"topic"``).
    """
    with open(cases_path, "r", encoding="utf-8") as fh:
        cases_data = json.load(fh)
    cases = cases_data.get("cases", [])
    topic = cases_data.get("topic", "未知主题")

    # Keep only cases that carry images (either field name is accepted).
    cases_with_images = [
        c for c in cases
        if len(c.get("images") or c.get("effect_images") or []) > 0
    ]
    skipped = len(cases) - len(cases_with_images)

    separator = "=" * 60
    print(separator)
    print(f"输入: {cases_path}")
    print(f"主题: {topic} (总计 {len(cases)} 个案例)")
    if skipped > 0:
        print(f"过滤: 跳过 {skipped} 个纯视频案例，保留 {len(cases_with_images)} 个图文案例")
    print(separator)

    if not cases_with_images:
        print("\n错误: 没有图文案例可处理")
        return

    # Step 1a: per-case demand extraction, all cases concurrently.
    print(f"\n第 1a 步：逐帖子提取需求（{len(cases_with_images)} 个并发）...")
    all_case_demands = await asyncio.gather(
        *[extract_demands_for_case(case) for case in cases_with_images]
    )
    total_raw = sum(len(cd["demands"]) for cd in all_case_demands)
    for cd in all_case_demands:
        count = len(cd["demands"])
        preview = ", ".join(d["demand_name"] for d in cd["demands"][:3])
        ellipsis = "..." if count > 3 else ""
        print(f" Case {cd['case_id']}: {count} 个需求 [{preview}{ellipsis}]")
    print(f" 共提取 {total_raw} 个原始需求")

    # Step 1b: merge and de-duplicate.
    print("\n第 1b 步：合并去重...")
    merged_demands = await merge_demands(all_case_demands)
    print(f" 合并后 {len(merged_demands)} 个需求：")
    for index, md in enumerate(merged_demands, 1):
        print(f" {index}. {md['demand_name']} (来自 case {md.get('source_case_ids', [])})")

    # Step 2: node search.
    print("\n第 2 步：语义搜索内容树 + 获取父子节点...")
    search_results = await step2_search_and_expand(merged_demands)
    for dn, dims in search_results.items():
        hit_count = sum(len(ns) for ns in dims.values())
        print(f" [{dn}] {hit_count} 个节点")

    # Step 3: mount decisions, all demands concurrently.
    print(f"\n第 3 步：LLM 挂载决策（{len(merged_demands)} 个需求并发）...")
    decisions = await step3_mount_decisions(merged_demands, search_results)
    rule = "─" * 40
    for d in decisions:
        print(f"\n{rule}")
        print(f"【{d['demand_name']}】（来自 case {d['source_case_ids']}）")
        print(d["decision"])

    relation_table = build_relation_table(all_case_demands, merged_demands, decisions)

    # Case info saved with the results; keys are strings to avoid type
    # mismatches with source_case_ids.
    case_info_map = {
        str(c["case_id"]): {
            "title": c.get("title") or c.get("video_title") or c.get("post_title", ""),
            "images": c.get("images") or c.get("effect_images", []),
            "link": c.get("source_link") or c.get("video_url") or c.get("post_url", ""),
        }
        for c in cases_with_images
    }

    per_case_demands = [
        {
            "case_id": cd["case_id"],
            "title": cd["title"],
            "demands": [d["demand_name"] for d in cd["demands"]],
        }
        for cd in all_case_demands
    ]

    # Persist the results next to the input file.
    output_dir = Path(cases_path).parent
    output_file = output_dir / "match_nodes_result.json"
    output_data = {
        "topic": topic,
        "case_info": case_info_map,
        "per_case_demands": per_case_demands,
        "merged_demands": merged_demands,
        "search_results": search_results,
        "mount_decisions": decisions,
        "relation_table": relation_table,
    }
    with open(output_file, "w", encoding="utf-8") as fh:
        json.dump(output_data, fh, ensure_ascii=False, indent=2)
    print(f"\n结果已保存到: {output_file}")

    # Render and save the visualisation.
    html_file = output_dir / "match_nodes_graph.html"
    html_file.write_text(
        generate_html_visualization(cases_data, all_case_demands, merged_demands, decisions, search_results),
        encoding="utf-8",
    )
    print(f"可视化已保存到: {html_file}")
if __name__ == "__main__":
    # Script entry point: expects the cases JSON path as the first argument.
    if len(sys.argv) >= 2:
        # Set no_proxy to "*" unless the environment already defines it.
        os.environ.setdefault("no_proxy", "*")
        asyncio.run(run(sys.argv[1]))
    else:
        print("用法: python match_nodes.py <cases.json 路径>")
        print("示例: python match_nodes.py outputs/midjourney_0/02_cases.json")
        sys.exit(1)
|