""" 制作案例 → 制作需求 → 内容树节点匹配 流程: 1a. 逐帖子分析(带图片),提炼每个帖子的制作需求(并发) 1b. 合并去重相似需求,保留 case_id 溯源 2. 按合并后的需求语义搜索内容树 + 取父子节点 3. 每个需求独立过 LLM 挂载决策(并发) 4. 输出 case → demand → node 关系表 + 可视化 """ import asyncio import json import os import re import sys from pathlib import Path import httpx # 添加项目根目录 sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from dotenv import load_dotenv load_dotenv() from agent.llm.qwen import qwen_llm_call # ===== 配置 ===== BASE_DIR = Path(__file__).parent CONTENT_TREE_BASE = "http://8.147.104.190:8001" CATEGORY_TREE_PATH = BASE_DIR / "prompts" / "category_tree.json" DEFAULT_MODEL = "qwen-plus" VISION_MODEL = "qwen-vl-max" SOURCE_TYPES = ["实质", "形式", "意图"] # 加载本地分类树 _CATEGORY_TREE_CACHE = None def load_category_tree() -> dict: """加载本地分类树(缓存)""" global _CATEGORY_TREE_CACHE if _CATEGORY_TREE_CACHE is None: with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f: _CATEGORY_TREE_CACHE = json.load(f) return _CATEGORY_TREE_CACHE def collect_all_nodes(node: dict, nodes: list, parent_path: str = ""): """递归收集所有节点,展平树结构""" if "id" in node: node_copy = { "entity_id": node["id"], "name": node["name"], "path": node.get("path", parent_path), "source_type": node.get("source_type"), "description": node.get("description") or "", "level": node.get("level"), "parent_id": node.get("parent_id"), "element_count": node.get("element_count", 0), } nodes.append(node_copy) if "children" in node: current_path = node.get("path", parent_path) for child in node["children"]: collect_all_nodes(child, nodes, current_path) def search_local_tree(query: str, source_type: str, top_k: int = 5) -> list: """在本地分类树中搜索节点""" tree = load_category_tree() all_nodes = [] collect_all_nodes(tree, all_nodes) # 过滤:只保留指定维度 filtered = [n for n in all_nodes if n.get("source_type") == source_type] # 简单文本匹配打分 query_lower = query.lower() scored = [] for node in filtered: name = node["name"].lower() desc = node["description"].lower() score = 0.0 # 名称完全匹配 if query_lower == name: score = 1.0 # 名称包含 elif query_lower in name: score = 0.8 # 描述包含 elif query_lower in desc: score = 0.5 # 名称被包含(反向) elif name in query_lower: score = 0.6 if score > 0: node["score"] = score node["entity_type"] = "category" # 本地树都是 category scored.append(node) # 按分数排序,取 top_k scored.sort(key=lambda x: x["score"], reverse=True) return scored[:top_k] # ===== Prompt 加载 ===== def load_prompt(filename: str) -> dict: """加载 .prompt 文件,解析 frontmatter 和 $role$ 分段""" path = BASE_DIR / "prompts" / filename text = path.read_text(encoding="utf-8") config = {} if text.startswith("---"): _, fm, text = text.split("---", 2) for line in fm.strip().splitlines(): if ":" in line: k, v = line.split(":", 1) k, v = k.strip(), v.strip() if v.replace(".", "", 1).isdigit(): v = float(v) if "." in v else int(v) config[k] = v messages = [] parts = re.split(r'^\$(\w+)\$\s*$', text.strip(), flags=re.MULTILINE) for i in range(1, len(parts), 2): role = parts[i].strip() content = parts[i + 1].strip() if i + 1 < len(parts) else "" messages.append({"role": role, "content": content}) return {"config": config, "messages": messages} def render_messages(prompt_data: dict, variables: dict) -> list[dict]: """用变量替换 prompt 模板中的 {var} 占位符""" rendered = [] for msg in prompt_data["messages"]: content = msg["content"] for k, v in variables.items(): content = content.replace(f"{{{k}}}", str(v)) rendered.append({"role": msg["role"], "content": content}) return rendered def parse_json_response(content: str) -> list | dict: """清理 LLM 输出中的 markdown 包裹并解析 JSON""" content = content.strip() if content.startswith("```"): content = content.split("\n", 1)[1] content = content.rsplit("```", 1)[0] try: return json.loads(content) except json.JSONDecodeError as e: # 尝试自动修复:替换字符串中的未转义引号 print(f"\n[JSON 解析错误] 尝试自动修复...") # 简单修复:把字符串值中的单独双引号替换为单引号 import re # 匹配 "key": "value with "quote" inside" # 这个正则会找到字符串值中间的引号并替换 fixed = re.sub(r'("(?:evidence|description|demand_name)":\s*")([^"]*)"([^"]*)"([^"]*")', r'\1\2\'\3\4', content) try: return json.loads(fixed) except json.JSONDecodeError: # 修复失败,打印错误位置 start = max(0, e.pos - 100) end = min(len(content), e.pos + 100) print(f"\n[JSON 解析错误] 位置 {e.pos}:") print(f"...{content[start:end]}...") print(f"\n完整内容已保存到 json_error.txt") with open("json_error.txt", "w", encoding="utf-8") as f: f.write(content) raise # ===== 内容树 API ===== async def search_content_tree(client: httpx.AsyncClient, query: str, source_type: str, top_k: int = 5) -> list: params = { "q": query, "source_type": source_type, "entity_type": "all", "top_k": top_k, "use_description": "true", "include_ancestors": "false", "descendant_depth": 0, } resp = await client.get(f"{CONTENT_TREE_BASE}/api/agent/search", params=params) resp.raise_for_status() return resp.json().get("results", []) async def get_category_tree(client: httpx.AsyncClient, entity_id: int, source_type: str) -> dict: params = { "source_type": source_type, "include_ancestors": "true", "descendant_depth": 1, } resp = await client.get(f"{CONTENT_TREE_BASE}/api/agent/search/category/{entity_id}", params=params) resp.raise_for_status() return resp.json() # ===== Step 1a:逐帖子提需求(支持图片) ===== def build_case_content(case: dict) -> list[dict]: """ 构造单个帖子的多模态消息内容。 返回 OpenAI vision 格式的 content 数组。 """ parts = [] # 文本部分(兼容多种字段名) title = case.get("title") or case.get("video_title") or case.get("post_title", "未知") text_lines = [f"## 案例:{title}"] text_lines.append(f"来源:{case.get('source', '')}") if case.get("user_input"): text_lines.append(f"用户输入:{json.dumps(case['user_input'], ensure_ascii=False)}") if case.get("output_description"): text_lines.append(f"输出效果:{case['output_description']}") if case.get("key_findings"): text_lines.append(f"关键发现:{case['key_findings']}") parts.append({"type": "text", "text": "\n".join(text_lines)}) # 图片部分(兼容多种字段名) images = case.get("images") or case.get("effect_images", []) for img_url in images[:4]: # 最多4张图,控制 token parts.append({ "type": "image_url", "image_url": {"url": img_url}, }) return parts async def extract_demands_for_case(case: dict) -> dict: """对单个帖子提取需求""" prompt_data = load_prompt("step1_extract_demands.prompt") model = prompt_data["config"].get("model", VISION_MODEL) temperature = prompt_data["config"].get("temperature", 0.3) case_content = build_case_content(case) has_images = any(p["type"] == "image_url" for p in case_content) # 如果没有图片,降级用文本模型省成本 if not has_images: model = DEFAULT_MODEL # 纯文本模式:content 用字符串即可 text_only = "\n".join(p["text"] for p in case_content if p["type"] == "text") messages = render_messages(prompt_data, {"case_content": text_only}) else: # 多模态模式:user message 的 content 用数组 messages = [] for msg in prompt_data["messages"]: if msg["role"] == "user": # 把 prompt 模板文本放在图片前面 template_text = msg["content"].replace("{case_content}", "") content = case_content.copy() if template_text.strip(): content.insert(0, {"type": "text", "text": template_text}) messages.append({"role": "user", "content": content}) else: messages.append(msg) try: result = await qwen_llm_call(messages, model=model, temperature=temperature) demands = parse_json_response(result["content"]) case_id = case.get("case_id", 0) # 给每个需求打上 case_id for d in demands: d["source_case_id"] = case_id return {"case_id": case_id, "title": case.get("title", ""), "demands": demands} except Exception as e: case_id = case.get("case_id", 0) print(f" [!] case {case_id} 提取失败: {e}") return {"case_id": case_id, "title": case.get("title", ""), "demands": [], "error": str(e)} # ===== Step 1b:合并去重 ===== async def merge_demands(all_case_demands: list[dict]) -> list[dict]: """用 LLM 合并相似需求,超过30个时分批处理""" # 构造输入:展平所有 case 的 demands,标注来源 flat = [] for cd in all_case_demands: for d in cd["demands"]: flat.append({ "case_id": cd["case_id"], "case_title": cd["title"], "demand_name": d["demand_name"], "description": d["description"], "search_keywords": d["search_keywords"], "evidence": d.get("evidence", ""), }) if not flat: return [] prompt_data = load_prompt("step1b_merge_demands.prompt") model = prompt_data["config"].get("model", DEFAULT_MODEL) temperature = prompt_data["config"].get("temperature", 0.3) # 分批处理:每批最多30个需求 BATCH_SIZE = 30 if len(flat) <= BATCH_SIZE: # 不需要分批 all_demands_json = json.dumps(flat, ensure_ascii=False, indent=2) messages = render_messages(prompt_data, {"all_demands_json": all_demands_json}) result = await qwen_llm_call(messages, model=model, temperature=temperature) merged = parse_json_response(result["content"]) # 统一转换 source_case_ids 为字符串 for m in merged: m["source_case_ids"] = [str(cid) for cid in m.get("source_case_ids", [])] return merged else: # 分批合并 print(f" 需求数量 {len(flat)} 超过 {BATCH_SIZE},分批合并...") batches = [flat[i:i + BATCH_SIZE] for i in range(0, len(flat), BATCH_SIZE)] batch_results = [] for i, batch in enumerate(batches, 1): print(f" 批次 {i}/{len(batches)}: {len(batch)} 个需求") batch_json = json.dumps(batch, ensure_ascii=False, indent=2) messages = render_messages(prompt_data, {"all_demands_json": batch_json}) result = await qwen_llm_call(messages, model=model, temperature=temperature) batch_merged = parse_json_response(result["content"]) batch_results.extend(batch_merged) # 如果分批后的结果仍然很多,再做一次最终合并 if len(batch_results) > BATCH_SIZE: print(f" 批次合并后仍有 {len(batch_results)} 个需求,进行最终合并...") final_json = json.dumps(batch_results, ensure_ascii=False, indent=2) messages = render_messages(prompt_data, {"all_demands_json": final_json}) result = await qwen_llm_call(messages, model=model, temperature=temperature) final_merged = parse_json_response(result["content"]) # 统一转换 source_case_ids 为字符串 for m in final_merged: m["source_case_ids"] = [str(cid) for cid in m.get("source_case_ids", [])] return final_merged else: # 统一转换 source_case_ids 为字符串 for m in batch_results: m["source_case_ids"] = [str(cid) for cid in m.get("source_case_ids", [])] return batch_results # ===== Step 2:语义搜索 + 取父子节点 ===== async def step2_search_and_expand(demands: list[dict]) -> dict: """语义搜索本地分类树 + 取父子节点""" all_results = {} tree = load_category_tree() for demand in demands: name = demand["demand_name"] keywords = demand["search_keywords"] all_results[name] = {} for source_type in SOURCE_TYPES: nodes = [] seen = set() for kw in keywords: # 使用本地搜索替代 API results = search_local_tree(kw, source_type, top_k=3) for r in results: eid = r.get("entity_id") if eid in seen: continue seen.add(eid) node = { "entity_id": eid, "name": r.get("name", ""), "entity_type": "category", "score": r.get("score", 0), "description": r.get("description", ""), "path": r.get("path", ""), } # 本地树暂不支持动态取父子,先留空 node["ancestors"] = [] node["children"] = [] nodes.append(node) if nodes: all_results[name][source_type] = nodes return all_results # ===== Step 3:逐需求挂载决策(并发) ===== def build_single_demand_context(demand: dict, nodes_by_dim: dict) -> str: parts = [] parts.append(f"## 需求:{demand['demand_name']}") parts.append(f"描述:{demand['description']}") parts.append(f"来源帖子:case_id={demand.get('source_case_ids', [])}") if demand.get("evidence"): parts.append(f"数据依据:{demand['evidence']}") if not nodes_by_dim: parts.append("(未搜索到相关节点)") return "\n".join(parts) for source_type, nodes in nodes_by_dim.items(): parts.append(f"\n### {source_type}维度匹配节点:") for n in nodes: line = f"- [{n['entity_type']}] entity_id={n['entity_id']} \"{n['name']}\" score={n['score']:.2f}" if n.get("description"): line += f" 描述: {n['description']}" parts.append(line) if n.get("ancestors"): path = " > ".join(a["name"] for a in n["ancestors"]) parts.append(f" 父链: {path}") if n.get("children"): kids = ", ".join(c["name"] for c in n["children"][:10]) parts.append(f" 子节点: {kids}") return "\n".join(parts) async def mount_single_demand(demand: dict, nodes_by_dim: dict) -> dict: prompt_data = load_prompt("step3_mount_decision.prompt") model = prompt_data["config"].get("model", DEFAULT_MODEL) temperature = prompt_data["config"].get("temperature", 0.3) node_context = build_single_demand_context(demand, nodes_by_dim) messages = render_messages(prompt_data, {"node_context": node_context}) result = await qwen_llm_call(messages, model=model, temperature=temperature, max_tokens=4096) # 解析结构化 JSON 输出 try: decision = parse_json_response(result["content"]) except Exception as e: print(f" [!] 挂载决策解析失败 ({demand['demand_name']}): {e}") decision = {"demand_name": demand["demand_name"], "mounted_nodes": [], "notes": f"解析失败: {e}"} return { "demand_name": demand["demand_name"], "source_case_ids": demand.get("source_case_ids", []), "decision": decision, } async def step3_mount_decisions(demands: list[dict], search_results: dict) -> list[dict]: tasks = [] for demand in demands: nodes_by_dim = search_results.get(demand["demand_name"], {}) tasks.append(mount_single_demand(demand, nodes_by_dim)) return await asyncio.gather(*tasks) # ===== 可视化:生成关系表 + HTML ===== def build_relation_table(all_case_demands: list[dict], merged_demands: list[dict], decisions: list[dict]) -> list[dict]: """构建 case → demand → node 关系表""" # 建立 demand_name → decision 的映射 decision_map = {d["demand_name"]: d for d in decisions} rows = [] for md in merged_demands: dn = md["demand_name"] case_ids = md.get("source_case_ids", []) dec = decision_map.get(dn, {}) rows.append({ "demand_name": dn, "description": md["description"], "source_case_ids": case_ids, "mount_decision": dec.get("decision", ""), }) return rows def generate_html_visualization(cases_data: dict, all_case_demands: list[dict], merged_demands: list[dict], decisions: list[dict], search_results: dict) -> str: """生成简洁的 case→demand 树状图,节点作为标签显示在 demand 框里""" # 构建完整的 case_map,兼容多种字段名 case_map = {} for c in cases_data.get("cases", []): case_map[c["case_id"]] = { "title": c.get("title") or c.get("video_title") or c.get("post_title", ""), "images": c.get("images") or c.get("effect_images", []), "link": c.get("source_link") or c.get("video_url") or c.get("post_url", "") } # 构建 decision_map:demand_name → mounted_nodes decision_map = {} for d in decisions: dec = d.get("decision", {}) if isinstance(dec, dict): decision_map[d["demand_name"]] = dec.get("mounted_nodes", []) else: decision_map[d["demand_name"]] = [] nodes_js = [] edges_js = [] node_id = 0 # 为每个 demand 创建节点,并为其来源的每个 case 创建独立的 case 节点 for md in merged_demands: demand_id = node_id node_id += 1 # 构建 demand 节点的显示内容 dn = md["demand_name"] desc = md["description"][:100] + "..." if len(md["description"]) > 100 else md["description"] # 从挂载决策中获取最终选择的节点 mounted = decision_map.get(dn, []) node_tags = [f"{n['name']}({n.get('source_type', '')})" for n in mounted] tags_html = " | ".join(node_tags) if node_tags else "无挂载节点" # demand 节点的 HTML 标签(包含挂载节点) demand_label = f"{dn}\n\n[挂载] {tags_html}" demand_title = f"{dn}\n\n{desc}\n\n挂载节点: {tags_html}" nodes_js.append({ "id": demand_id, "label": demand_label, "title": demand_title, "group": "demand", "level": 1, "shape": "box", "font": {"size": 12} }) # 为每个来源 case 创建独立的 case 节点 for cid in md.get("source_case_ids", []): case_id = node_id node_id += 1 # 类型转换:case_map 的 key 可能是字符串或整数 cid_str = str(cid) case_info = case_map.get(cid_str) or case_map.get(cid, {"title": f"Case {cid}", "images": [], "link": ""}) # label 只显示标题 title_short = case_info['title'][:50] + "..." if len(case_info['title']) > 50 else case_info['title'] case_label = f"Case {cid}\n{title_short}" # 构建带图片的 HTML tooltip(不含链接) img_html = "" if case_info['images']: img_url = case_info['images'][0] img_html = f'' case_title = f'
Case {cid}: {case_info["title"]}{img_html}
点击节点查看原帖
' nodes_js.append({ "id": case_id, "label": case_label, "title": case_title, "url": case_info['link'], # 存储链接,用于点击事件 "group": "case", "level": 0, "shape": "box" }) # case → demand 边 edges_js.append({"from": case_id, "to": demand_id}) html = f""" Case → Demand 树状图
树状图: 帖子 (Case) 需求 (Demand + 匹配节点)
""" return html # ===== 主流程 ===== async def run(cases_path: str): with open(cases_path, "r", encoding="utf-8") as f: cases_data = json.load(f) cases = cases_data.get("cases", []) topic = cases_data.get("topic", "未知主题") # 过滤:只保留有图片的 case(兼容多种字段名) def has_images(case): imgs = case.get("images") or case.get("effect_images") or [] return len(imgs) > 0 cases_with_images = [c for c in cases if has_images(c)] skipped = len(cases) - len(cases_with_images) print("=" * 60) print(f"输入: {cases_path}") print(f"主题: {topic} (总计 {len(cases)} 个案例)") if skipped > 0: print(f"过滤: 跳过 {skipped} 个纯视频案例,保留 {len(cases_with_images)} 个图文案例") print("=" * 60) if not cases_with_images: print("\n错误: 没有图文案例可处理") return # Step 1a: 逐帖子提需求(并发) print(f"\n第 1a 步:逐帖子提取需求({len(cases_with_images)} 个并发)...") tasks = [extract_demands_for_case(case) for case in cases_with_images] all_case_demands = await asyncio.gather(*tasks) total_raw = sum(len(cd["demands"]) for cd in all_case_demands) for cd in all_case_demands: n = len(cd["demands"]) names = ", ".join(d["demand_name"] for d in cd["demands"][:3]) suffix = "..." if n > 3 else "" print(f" Case {cd['case_id']}: {n} 个需求 [{names}{suffix}]") print(f" 共提取 {total_raw} 个原始需求") # Step 1b: 合并去重 print(f"\n第 1b 步:合并去重...") merged_demands = await merge_demands(all_case_demands) print(f" 合并后 {len(merged_demands)} 个需求:") for i, md in enumerate(merged_demands, 1): print(f" {i}. {md['demand_name']} (来自 case {md.get('source_case_ids', [])})") # Step 2: 搜索节点 print(f"\n第 2 步:语义搜索内容树 + 获取父子节点...") search_results = await step2_search_and_expand(merged_demands) for dn, dims in search_results.items(): total = sum(len(ns) for ns in dims.values()) print(f" [{dn}] {total} 个节点") # Step 3: 挂载决策(并发) print(f"\n第 3 步:LLM 挂载决策({len(merged_demands)} 个需求并发)...") decisions = await step3_mount_decisions(merged_demands, search_results) for d in decisions: print(f"\n{'─' * 40}") print(f"【{d['demand_name']}】(来自 case {d['source_case_ids']})") print(d["decision"]) # 构建关系表 relation_table = build_relation_table(all_case_demands, merged_demands, decisions) # 构建 case 信息映射(用于可视化) # 统一使用字符串作为 key,避免类型不匹配 case_info_map = {} for c in cases_with_images: case_info_map[str(c["case_id"])] = { "title": c.get("title") or c.get("video_title") or c.get("post_title", ""), "images": c.get("images") or c.get("effect_images", []), "link": c.get("source_link") or c.get("video_url") or c.get("post_url", "") } # 保存结果 output_dir = Path(cases_path).parent output_file = output_dir / "match_nodes_result.json" output_data = { "topic": topic, "case_info": case_info_map, # 新增:保存 case 信息 "per_case_demands": [ {"case_id": cd["case_id"], "title": cd["title"], "demands": [d["demand_name"] for d in cd["demands"]]} for cd in all_case_demands ], "merged_demands": merged_demands, "search_results": search_results, "mount_decisions": decisions, "relation_table": relation_table, } with open(output_file, "w", encoding="utf-8") as f: json.dump(output_data, f, ensure_ascii=False, indent=2) print(f"\n结果已保存到: {output_file}") # 生成可视化 html = generate_html_visualization(cases_data, all_case_demands, merged_demands, decisions, search_results) html_file = output_dir / "match_nodes_graph.html" with open(html_file, "w", encoding="utf-8") as f: f.write(html) print(f"可视化已保存到: {html_file}") if __name__ == "__main__": if len(sys.argv) < 2: print("用法: python match_nodes.py ") print("示例: python match_nodes.py outputs/midjourney_0/02_cases.json") sys.exit(1) os.environ.setdefault("no_proxy", "*") asyncio.run(run(sys.argv[1]))