"""
制作案例 → 制作需求 → 内容树节点匹配
流程:
1a. 逐帖子分析(带图片),提炼每个帖子的制作需求(并发)
1b. 合并去重相似需求,保留 case_id 溯源
2. 按合并后的需求语义搜索内容树 + 取父子节点
3. 每个需求独立过 LLM 挂载决策(并发)
4. 输出 case → demand → node 关系表 + 可视化
"""
import asyncio
import json
import os
import re
import sys
from pathlib import Path
import httpx
# 添加项目根目录
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from dotenv import load_dotenv
load_dotenv()
from agent.llm.qwen import qwen_llm_call
# ===== Configuration =====
# All paths are resolved relative to this script's directory.
BASE_DIR = Path(__file__).parent
# Base URL of the remote content-tree service (used by the async API helpers).
CONTENT_TREE_BASE = "http://8.147.104.190:8001"
# Local snapshot of the category tree used by search_local_tree().
CATEGORY_TREE_PATH = BASE_DIR / "prompts" / "category_tree.json"
DEFAULT_MODEL = "qwen-plus"  # text-only model (cheaper; used when a case has no images)
VISION_MODEL = "qwen-vl-max"  # multimodal model for cases that include images
# The three matching dimensions every demand is searched under.
SOURCE_TYPES = ["实质", "形式", "意图"]
# Lazily-filled module-level cache for the parsed category tree.
_CATEGORY_TREE_CACHE = None
def load_category_tree() -> dict:
    """Load the local category tree JSON, parsing it at most once per process."""
    global _CATEGORY_TREE_CACHE
    if _CATEGORY_TREE_CACHE is None:
        raw = CATEGORY_TREE_PATH.read_text(encoding="utf-8")
        _CATEGORY_TREE_CACHE = json.loads(raw)
    return _CATEGORY_TREE_CACHE
def collect_all_nodes(node: dict, nodes: list, parent_path: str = ""):
"""递归收集所有节点,展平树结构"""
if "id" in node:
node_copy = {
"entity_id": node["id"],
"name": node["name"],
"path": node.get("path", parent_path),
"source_type": node.get("source_type"),
"description": node.get("description") or "",
"level": node.get("level"),
"parent_id": node.get("parent_id"),
"element_count": node.get("element_count", 0),
}
nodes.append(node_copy)
if "children" in node:
current_path = node.get("path", parent_path)
for child in node["children"]:
collect_all_nodes(child, nodes, current_path)
def search_local_tree(query: str, source_type: str, top_k: int = 5) -> list:
"""在本地分类树中搜索节点"""
tree = load_category_tree()
all_nodes = []
collect_all_nodes(tree, all_nodes)
# 过滤:只保留指定维度
filtered = [n for n in all_nodes if n.get("source_type") == source_type]
# 简单文本匹配打分
query_lower = query.lower()
scored = []
for node in filtered:
name = node["name"].lower()
desc = node["description"].lower()
score = 0.0
# 名称完全匹配
if query_lower == name:
score = 1.0
# 名称包含
elif query_lower in name:
score = 0.8
# 描述包含
elif query_lower in desc:
score = 0.5
# 名称被包含(反向)
elif name in query_lower:
score = 0.6
if score > 0:
node["score"] = score
node["entity_type"] = "category" # 本地树都是 category
scored.append(node)
# 按分数排序,取 top_k
scored.sort(key=lambda x: x["score"], reverse=True)
return scored[:top_k]
# ===== Prompt 加载 =====
def load_prompt(filename: str) -> dict:
    """Parse a .prompt file: optional ``---`` frontmatter plus ``$role$`` sections.

    Returns {"config": {key: value}, "messages": [{"role", "content"}, ...]}.
    Numeric-looking frontmatter values are coerced to int/float.
    """
    text = (BASE_DIR / "prompts" / filename).read_text(encoding="utf-8")
    config: dict = {}
    if text.startswith("---"):
        _, frontmatter, text = text.split("---", 2)
        for raw in frontmatter.strip().splitlines():
            if ":" not in raw:
                continue
            key, _, value = raw.partition(":")
            key, value = key.strip(), value.strip()
            # Coerce e.g. "0.3" -> 0.3 and "4096" -> 4096.
            if value.replace(".", "", 1).isdigit():
                value = float(value) if "." in value else int(value)
            config[key] = value
    # Split the body on lines that are exactly "$role$"; odd indices are
    # role names, the following element is that role's content.
    sections = re.split(r'^\$(\w+)\$\s*$', text.strip(), flags=re.MULTILINE)
    messages = []
    for idx in range(1, len(sections), 2):
        body = sections[idx + 1].strip() if idx + 1 < len(sections) else ""
        messages.append({"role": sections[idx].strip(), "content": body})
    return {"config": config, "messages": messages}
def render_messages(prompt_data: dict, variables: dict) -> list[dict]:
"""用变量替换 prompt 模板中的 {var} 占位符"""
rendered = []
for msg in prompt_data["messages"]:
content = msg["content"]
for k, v in variables.items():
content = content.replace(f"{{{k}}}", str(v))
rendered.append({"role": msg["role"], "content": content})
return rendered
def parse_json_response(content: str) -> list | dict:
"""清理 LLM 输出中的 markdown 包裹并解析 JSON"""
content = content.strip()
if content.startswith("```"):
content = content.split("\n", 1)[1]
content = content.rsplit("```", 1)[0]
try:
return json.loads(content)
except json.JSONDecodeError as e:
# 尝试自动修复:替换字符串中的未转义引号
print(f"\n[JSON 解析错误] 尝试自动修复...")
# 简单修复:把字符串值中的单独双引号替换为单引号
import re
# 匹配 "key": "value with "quote" inside"
# 这个正则会找到字符串值中间的引号并替换
fixed = re.sub(r'("(?:evidence|description|demand_name)":\s*")([^"]*)"([^"]*)"([^"]*")',
r'\1\2\'\3\4', content)
try:
return json.loads(fixed)
except json.JSONDecodeError:
# 修复失败,打印错误位置
start = max(0, e.pos - 100)
end = min(len(content), e.pos + 100)
print(f"\n[JSON 解析错误] 位置 {e.pos}:")
print(f"...{content[start:end]}...")
print(f"\n完整内容已保存到 json_error.txt")
with open("json_error.txt", "w", encoding="utf-8") as f:
f.write(content)
raise
# ===== 内容树 API =====
async def search_content_tree(client: httpx.AsyncClient, query: str, source_type: str, top_k: int = 5) -> list:
    """Semantic search against the remote content-tree API; returns its result list."""
    resp = await client.get(
        f"{CONTENT_TREE_BASE}/api/agent/search",
        params={
            "q": query,
            "source_type": source_type,
            "entity_type": "all",
            "top_k": top_k,
            "use_description": "true",
            "include_ancestors": "false",
            "descendant_depth": 0,
        },
    )
    resp.raise_for_status()
    payload = resp.json()
    return payload.get("results", [])
async def get_category_tree(client: httpx.AsyncClient, entity_id: int, source_type: str) -> dict:
    """Fetch one category node from the remote API with ancestors and direct children."""
    url = f"{CONTENT_TREE_BASE}/api/agent/search/category/{entity_id}"
    query = {
        "source_type": source_type,
        "include_ancestors": "true",
        "descendant_depth": 1,
    }
    resp = await client.get(url, params=query)
    resp.raise_for_status()
    return resp.json()
# ===== Step 1a:逐帖子提需求(支持图片) =====
def build_case_content(case: dict) -> list[dict]:
    """
    Build the multimodal message content for a single post.

    Returns an OpenAI-vision-style content array: one text part first,
    followed by at most four image_url parts. Field names are looked up
    with fallbacks to tolerate differing case schemas.
    """
    # Title fallbacks: title -> video_title -> post_title -> "未知".
    title = case.get("title") or case.get("video_title") or case.get("post_title", "未知")
    lines = [f"## 案例:{title}", f"来源:{case.get('source', '')}"]
    if case.get("user_input"):
        lines.append(f"用户输入:{json.dumps(case['user_input'], ensure_ascii=False)}")
    if case.get("output_description"):
        lines.append(f"输出效果:{case['output_description']}")
    if case.get("key_findings"):
        lines.append(f"关键发现:{case['key_findings']}")
    parts: list[dict] = [{"type": "text", "text": "\n".join(lines)}]
    # Image fallbacks, capped at 4 to keep token usage bounded.
    images = case.get("images") or case.get("effect_images", [])
    parts.extend({"type": "image_url", "image_url": {"url": u}} for u in images[:4])
    return parts
async def extract_demands_for_case(case: dict) -> dict:
    """Extract production demands from one post via the LLM.

    Uses the vision model when the case carries images; otherwise falls
    back to the cheaper text model with a plain-string prompt. Failures
    are caught and reported as an empty demand list with an "error" key.
    """
    prompt_data = load_prompt("step1_extract_demands.prompt")
    cfg = prompt_data["config"]
    model = cfg.get("model", VISION_MODEL)
    temperature = cfg.get("temperature", 0.3)
    case_content = build_case_content(case)
    case_id = case.get("case_id", 0)

    if not any(p["type"] == "image_url" for p in case_content):
        # Text-only case: downgrade to the text model to save cost and
        # pass the content as a plain string.
        model = DEFAULT_MODEL
        text_only = "\n".join(p["text"] for p in case_content if p["type"] == "text")
        messages = render_messages(prompt_data, {"case_content": text_only})
    else:
        # Multimodal case: the user message content becomes a parts array,
        # with the prompt template's text placed before the images.
        messages = []
        for msg in prompt_data["messages"]:
            if msg["role"] != "user":
                messages.append(msg)
                continue
            template_text = msg["content"].replace("{case_content}", "")
            content = case_content.copy()
            if template_text.strip():
                content.insert(0, {"type": "text", "text": template_text})
            messages.append({"role": "user", "content": content})

    try:
        result = await qwen_llm_call(messages, model=model, temperature=temperature)
        demands = parse_json_response(result["content"])
        # Tag every extracted demand with its originating case for traceability.
        for d in demands:
            d["source_case_id"] = case_id
        return {"case_id": case_id, "title": case.get("title", ""), "demands": demands}
    except Exception as e:
        print(f" [!] case {case_id} 提取失败: {e}")
        return {"case_id": case_id, "title": case.get("title", ""), "demands": [], "error": str(e)}
# ===== Step 1b:合并去重 =====
def _stringify_case_ids(merged: list[dict]) -> list[dict]:
    """Normalize every merged demand's source_case_ids to strings, in place."""
    for m in merged:
        m["source_case_ids"] = [str(cid) for cid in m.get("source_case_ids", [])]
    return merged


async def _run_merge_pass(prompt_data: dict, model, temperature, demands: list[dict]) -> list[dict]:
    """One LLM merge pass: serialize *demands*, call the model, parse the result."""
    payload = json.dumps(demands, ensure_ascii=False, indent=2)
    messages = render_messages(prompt_data, {"all_demands_json": payload})
    result = await qwen_llm_call(messages, model=model, temperature=temperature)
    return parse_json_response(result["content"])


async def merge_demands(all_case_demands: list[dict]) -> list[dict]:
    """Merge similar demands across cases with the LLM, batching above 30 inputs.

    Flattens every case's demands (annotated with case provenance), runs one
    merge pass when the total fits in a batch, otherwise merges per batch and
    then runs a final consolidation pass if the batch results are still large.
    Returns merged demands with source_case_ids normalized to strings.
    """
    # Flatten all demands, tagging each with its originating case.
    flat = []
    for cd in all_case_demands:
        for d in cd["demands"]:
            flat.append({
                "case_id": cd["case_id"],
                "case_title": cd["title"],
                "demand_name": d["demand_name"],
                "description": d["description"],
                "search_keywords": d["search_keywords"],
                "evidence": d.get("evidence", ""),
            })
    if not flat:
        return []

    prompt_data = load_prompt("step1b_merge_demands.prompt")
    model = prompt_data["config"].get("model", DEFAULT_MODEL)
    temperature = prompt_data["config"].get("temperature", 0.3)

    # At most 30 demands per LLM call to keep prompts bounded.
    BATCH_SIZE = 30
    if len(flat) <= BATCH_SIZE:
        merged = await _run_merge_pass(prompt_data, model, temperature, flat)
        return _stringify_case_ids(merged)

    print(f" 需求数量 {len(flat)} 超过 {BATCH_SIZE},分批合并...")
    batches = [flat[i:i + BATCH_SIZE] for i in range(0, len(flat), BATCH_SIZE)]
    batch_results = []
    for i, batch in enumerate(batches, 1):
        print(f" 批次 {i}/{len(batches)}: {len(batch)} 个需求")
        batch_results.extend(await _run_merge_pass(prompt_data, model, temperature, batch))

    # If per-batch merging still left too many demands, consolidate once more.
    if len(batch_results) > BATCH_SIZE:
        print(f" 批次合并后仍有 {len(batch_results)} 个需求,进行最终合并...")
        batch_results = await _run_merge_pass(prompt_data, model, temperature, batch_results)
    return _stringify_case_ids(batch_results)
# ===== Step 2:语义搜索 + 取父子节点 =====
async def step2_search_and_expand(demands: list[dict]) -> dict:
    """Search the local category tree for each demand's keywords, per dimension.

    Returns ``{demand_name: {source_type: [node, ...]}}``; dimensions with no
    hits are omitted. Nodes found by multiple keywords are de-duplicated by
    entity_id. (Fix: dropped a dead ``load_category_tree()`` call whose result
    was never used — search_local_tree loads the cached tree itself.)
    """
    all_results: dict = {}
    for demand in demands:
        name = demand["demand_name"]
        all_results[name] = {}
        for source_type in SOURCE_TYPES:
            nodes = []
            seen = set()
            for kw in demand["search_keywords"]:
                # Local lexical search stands in for the remote semantic API.
                for r in search_local_tree(kw, source_type, top_k=3):
                    eid = r.get("entity_id")
                    if eid in seen:
                        continue  # same node matched by an earlier keyword
                    seen.add(eid)
                    nodes.append({
                        "entity_id": eid,
                        "name": r.get("name", ""),
                        "entity_type": "category",
                        "score": r.get("score", 0),
                        "description": r.get("description", ""),
                        "path": r.get("path", ""),
                        # The local tree does not support dynamic parent/child
                        # expansion yet, so these stay empty.
                        "ancestors": [],
                        "children": [],
                    })
            if nodes:
                all_results[name][source_type] = nodes
    return all_results
# ===== Step 3:逐需求挂载决策(并发) =====
def build_single_demand_context(demand: dict, nodes_by_dim: dict) -> str:
    """Render one demand plus its candidate nodes as plain text for the LLM prompt."""
    lines = [
        f"## 需求:{demand['demand_name']}",
        f"描述:{demand['description']}",
        f"来源帖子:case_id={demand.get('source_case_ids', [])}",
    ]
    if demand.get("evidence"):
        lines.append(f"数据依据:{demand['evidence']}")
    # No candidates at all: short-circuit with an explicit marker.
    if not nodes_by_dim:
        lines.append("(未搜索到相关节点)")
        return "\n".join(lines)
    for source_type, nodes in nodes_by_dim.items():
        lines.append(f"\n### {source_type}维度匹配节点:")
        for n in nodes:
            entry = f"- [{n['entity_type']}] entity_id={n['entity_id']} \"{n['name']}\" score={n['score']:.2f}"
            if n.get("description"):
                entry += f" 描述: {n['description']}"
            lines.append(entry)
            if n.get("ancestors"):
                lines.append(" 父链: " + " > ".join(a["name"] for a in n["ancestors"]))
            if n.get("children"):
                # Show at most 10 children to keep the context compact.
                lines.append(" 子节点: " + ", ".join(c["name"] for c in n["children"][:10]))
    return "\n".join(lines)
async def mount_single_demand(demand: dict, nodes_by_dim: dict) -> dict:
    """Ask the LLM whether and where to mount one demand onto its candidate nodes."""
    prompt_data = load_prompt("step3_mount_decision.prompt")
    cfg = prompt_data["config"]
    context = build_single_demand_context(demand, nodes_by_dim)
    messages = render_messages(prompt_data, {"node_context": context})
    result = await qwen_llm_call(
        messages,
        model=cfg.get("model", DEFAULT_MODEL),
        temperature=cfg.get("temperature", 0.3),
        max_tokens=4096,
    )
    # The model must answer with structured JSON; fall back to an empty
    # decision (with the parse error recorded) if it doesn't.
    try:
        decision = parse_json_response(result["content"])
    except Exception as e:
        print(f" [!] 挂载决策解析失败 ({demand['demand_name']}): {e}")
        decision = {"demand_name": demand["demand_name"], "mounted_nodes": [], "notes": f"解析失败: {e}"}
    return {
        "demand_name": demand["demand_name"],
        "source_case_ids": demand.get("source_case_ids", []),
        "decision": decision,
    }
async def step3_mount_decisions(demands: list[dict], search_results: dict) -> list[dict]:
    """Run the mount decision for every demand concurrently; results keep input order."""
    return await asyncio.gather(*(
        mount_single_demand(d, search_results.get(d["demand_name"], {}))
        for d in demands
    ))
# ===== 可视化:生成关系表 + HTML =====
def build_relation_table(all_case_demands: list[dict], merged_demands: list[dict], decisions: list[dict]) -> list[dict]:
    """Join merged demands with their mount decisions into flat relation rows.

    A demand with no matching decision gets an empty mount_decision. The
    all_case_demands argument is kept for interface compatibility.
    """
    # Index decisions by demand name for O(1) lookup per row.
    by_name = {d["demand_name"]: d for d in decisions}
    return [
        {
            "demand_name": md["demand_name"],
            "description": md["description"],
            "source_case_ids": md.get("source_case_ids", []),
            "mount_decision": by_name.get(md["demand_name"], {}).get("decision", ""),
        }
        for md in merged_demands
    ]
def generate_html_visualization(cases_data: dict, all_case_demands: list[dict],
merged_demands: list[dict], decisions: list[dict],
search_results: dict) -> str:
"""生成简洁的 case→demand 树状图,节点作为标签显示在 demand 框里"""
# 构建完整的 case_map,兼容多种字段名
case_map = {}
for c in cases_data.get("cases", []):
case_map[c["case_id"]] = {
"title": c.get("title") or c.get("video_title") or c.get("post_title", ""),
"images": c.get("images") or c.get("effect_images", []),
"link": c.get("source_link") or c.get("video_url") or c.get("post_url", "")
}
# 构建 decision_map:demand_name → mounted_nodes
decision_map = {}
for d in decisions:
dec = d.get("decision", {})
if isinstance(dec, dict):
decision_map[d["demand_name"]] = dec.get("mounted_nodes", [])
else:
decision_map[d["demand_name"]] = []
nodes_js = []
edges_js = []
node_id = 0
# 为每个 demand 创建节点,并为其来源的每个 case 创建独立的 case 节点
for md in merged_demands:
demand_id = node_id
node_id += 1
# 构建 demand 节点的显示内容
dn = md["demand_name"]
desc = md["description"][:100] + "..." if len(md["description"]) > 100 else md["description"]
# 从挂载决策中获取最终选择的节点
mounted = decision_map.get(dn, [])
node_tags = [f"{n['name']}({n.get('source_type', '')})" for n in mounted]
tags_html = " | ".join(node_tags) if node_tags else "无挂载节点"
# demand 节点的 HTML 标签(包含挂载节点)
demand_label = f"{dn}\n\n[挂载] {tags_html}"
demand_title = f"{dn}\n\n{desc}\n\n挂载节点: {tags_html}"
nodes_js.append({
"id": demand_id,
"label": demand_label,
"title": demand_title,
"group": "demand",
"level": 1,
"shape": "box",
"font": {"size": 12}
})
# 为每个来源 case 创建独立的 case 节点
for cid in md.get("source_case_ids", []):
case_id = node_id
node_id += 1
# 类型转换:case_map 的 key 可能是字符串或整数
cid_str = str(cid)
case_info = case_map.get(cid_str) or case_map.get(cid, {"title": f"Case {cid}", "images": [], "link": ""})
# label 只显示标题
title_short = case_info['title'][:50] + "..." if len(case_info['title']) > 50 else case_info['title']
case_label = f"Case {cid}\n{title_short}"
# 构建带图片的 HTML tooltip(不含链接)
img_html = ""
if case_info['images']:
img_url = case_info['images'][0]
img_html = f''
case_title = f'