|
|
@@ -0,0 +1,285 @@
|
|
|
+"""
|
|
|
+工序维度词 → 分类树匹配(category-match)。
|
|
|
+
|
|
|
+把 mode_process 某帖的工序解构里五个维度的词(实质/形式/类型/作用/动作)送到
|
|
|
+分类匹配接口,拿到每个词命中的分类节点,再把命中的「实质/形式」分类名回写到
|
|
|
+对应 step 的 substanceMatch / formMatch 字段(并落库 mode_process)。
|
|
|
+
|
|
|
+外部接口(入参/返回见 README 或本文件 _post_category_match):
|
|
|
+ POST {CATEGORY_MATCH_API}/api/v1/category-match
|
|
|
+
|
|
|
+维度 → source_type 映射(取自 mode_process.steps[]):
|
|
|
+ 实质 ← steps[].substance
|
|
|
+ 形式 ← steps[].form
|
|
|
+ 类型 ← steps[].inputs[].type
|
|
|
+ 作用 ← steps[].effect
|
|
|
+ 动作 ← steps[].action
|
|
|
+
|
|
|
+term 拆分:每个值按「、」拆(括号内的「、」不拆,见 _split_values),空格保留;
|
|
|
+同一 source_type 下的 term 去重(跨 step / 跨工序)。
|
|
|
+
|
|
|
+回写规则(1.2):对每个 step,把 substance/form 同样按「、」拆成子项,逐子项查它在
|
|
|
+返回里命中的分类 name(按 source_type 过滤;择优:优先「分类名==子项」的精确同名,
|
|
|
+无同名再退回分数最高的一条),多个子项的 name 用「、」拼接 → substanceMatch /
|
|
|
+formMatch;无命中为 None。
|
|
|
+
|
|
|
+对外 FastAPI(供前端/批处理调用):
|
|
|
+ POST /category-match/single body={"query_id":..,"case_id":..}
|
|
|
+ POST /category-match/batch body={"query_id":[..],"case_id":[..]} (并发,见 BATCH_CONCURRENCY)
|
|
|
+
|
|
|
+启动:
|
|
|
+ python stages/category_match.py # 默认 0.0.0.0:8780
|
|
|
+ python stages/category_match.py 8090 # 指定端口
|
|
|
+ CATEGORY_MATCH_API=http://host:8300 python stages/category_match.py
|
|
|
+"""
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import os
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+from typing import List, Optional
|
|
|
+
|
|
|
+import httpx
|
|
|
+from fastapi import FastAPI, HTTPException
|
|
|
+from pydantic import BaseModel
|
|
|
+
|
|
|
+# 复用项目根的 db(stages 脚本统一从上级目录导入)
|
|
|
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
+import db # noqa: E402
|
|
|
+
|
|
|
+# ── 配置(后续按需调整)──────────────────────────────────────────────────────────
|
|
|
+CATEGORY_MATCH_API = os.environ.get("CATEGORY_MATCH_API", "http://47.236.83.130:8300").rstrip("/")
|
|
|
+MATCH_ENDPOINT = "/api/v1/category-match"
|
|
|
+TOP_K = int(os.environ.get("CATEGORY_MATCH_TOP_K", "10")) # 后续再调
|
|
|
+MIN_SCORE = float(os.environ.get("CATEGORY_MATCH_MIN_SCORE", "0.8")) # 后续再调
|
|
|
+RECORD = True
|
|
|
+API_TIMEOUT = float(os.environ.get("CATEGORY_MATCH_TIMEOUT", "30"))
|
|
|
+BATCH_CONCURRENCY = int(os.environ.get("CATEGORY_MATCH_CONCURRENCY", "8")) # 批量并发上限
|
|
|
+
|
|
|
+# 维度 → source_type(外部接口约定的中文标签)
|
|
|
+ST_SUBSTANCE = "实质"
|
|
|
+ST_FORM = "形式"
|
|
|
+ST_TYPE = "类型"
|
|
|
+ST_EFFECT = "作用"
|
|
|
+ST_ACTION = "动作"
|
|
|
+
|
|
|
+
|
|
|
+# ── term 拆分(按「、」拆,括号内不拆,空格保留;去重保序)─────────────────────────────
|
|
|
+def _s(v) -> str:
|
|
|
+ """取值规整为字符串:None→"",其余 str().strip()。"""
|
|
|
+ return "" if v is None else str(v).strip()
|
|
|
+
|
|
|
+
|
|
|
+def _split_values(raw: str) -> List[str]:
|
|
|
+ """按顿号分割,括号内的顿号不作为分隔符,结果去重保序(口径同
|
|
|
+ import_process_knowledge._split_values)。空格保留,不按空格拆。
|
|
|
+ "高保真线框图、UI设计稿" → ["高保真线框图", "UI设计稿"]
|
|
|
+ "修改后的照片(发型、服装)、二次元服装" → ["修改后的照片(发型、服装)", "二次元服装"]
|
|
|
+ "3D 毛绒玩具风" → ["3D 毛绒玩具风"] (含空格的整体词不拆碎)
|
|
|
+ """
|
|
|
+ parts, current, depth = [], [], 0
|
|
|
+ for ch in raw:
|
|
|
+ if ch in ("(", "("):
|
|
|
+ depth += 1
|
|
|
+ current.append(ch)
|
|
|
+ elif ch in (")", ")"):
|
|
|
+ depth -= 1
|
|
|
+ current.append(ch)
|
|
|
+ elif ch == "、" and depth == 0:
|
|
|
+ part = "".join(current).strip()
|
|
|
+ if part:
|
|
|
+ parts.append(part)
|
|
|
+ current = []
|
|
|
+ else:
|
|
|
+ current.append(ch)
|
|
|
+ part = "".join(current).strip()
|
|
|
+ if part:
|
|
|
+ parts.append(part)
|
|
|
+ seen, result = set(), []
|
|
|
+ for p in parts:
|
|
|
+ if p not in seen:
|
|
|
+ seen.add(p)
|
|
|
+ result.append(p)
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+# ── 组装 items(五维度 + 同 source_type 去重)──────────────────────────────────────
|
|
|
+def _iter_step_terms(step: dict):
|
|
|
+ """产出该 step 的 (source_type, term) 拆分对(未跨 step 去重)。"""
|
|
|
+ for v in _split_values(_s(step.get("substance"))):
|
|
|
+ yield (ST_SUBSTANCE, v)
|
|
|
+ for v in _split_values(_s(step.get("form"))):
|
|
|
+ yield (ST_FORM, v)
|
|
|
+ for v in _split_values(_s(step.get("effect"))):
|
|
|
+ yield (ST_EFFECT, v)
|
|
|
+ for v in _split_values(_s(step.get("action"))):
|
|
|
+ yield (ST_ACTION, v)
|
|
|
+ for inp in (step.get("inputs") or []):
|
|
|
+ for v in _split_values(_s(inp.get("type"))):
|
|
|
+ yield (ST_TYPE, v)
|
|
|
+
|
|
|
+
|
|
|
+def build_items(procedures: List[dict]) -> List[dict]:
|
|
|
+ """遍历所有工序的 steps,收集五维度词,按 (source_type, term) 全局去重。
|
|
|
+ 返回 [{"term":..,"source_type":..}]。"""
|
|
|
+ seen, items = set(), []
|
|
|
+ for proc in (procedures or []):
|
|
|
+ for step in (proc.get("steps") or []):
|
|
|
+ for st, term in _iter_step_terms(step):
|
|
|
+ key = (st, term)
|
|
|
+ if key not in seen:
|
|
|
+ seen.add(key)
|
|
|
+ items.append({"term": term, "source_type": st})
|
|
|
+ return items
|
|
|
+
|
|
|
+
|
|
|
+# ── 回写 substanceMatch / formMatch ──────────────────────────────────────────────
|
|
|
+def _build_match_lookup(resp: dict) -> dict:
|
|
|
+ """返回 {(term, source_type): name}。择优规则:优先「分类名 == term」的精确同名候选,
|
|
|
+ 无同名再退回分数最高的;分数也并列时取先出现的(即下游返回顺序的第一个)。
|
|
|
+ 用 (is_exact, score) 元组排序:精确同名(True>False)永远压过非同名,同档再比分数。"""
|
|
|
+ best = {} # (term, st) -> (rank, name);rank=(is_exact, score),元组比大小
|
|
|
+ for item in (resp.get("items") or []):
|
|
|
+ term = item.get("term")
|
|
|
+ for m in (item.get("matches") or []):
|
|
|
+ st, name, score = m.get("source_type"), m.get("name"), m.get("score") or 0
|
|
|
+ if term is None or st is None or name is None:
|
|
|
+ continue
|
|
|
+ key = (term, st)
|
|
|
+ rank = (name == term, score) # 精确同名优先,其次分数高
|
|
|
+ if key not in best or rank > best[key][0]:
|
|
|
+ best[key] = (rank, name)
|
|
|
+ return {k: v[1] for k, v in best.items()}
|
|
|
+
|
|
|
+
|
|
|
+def enrich_steps(procedures: List[dict], resp: dict) -> List[dict]:
|
|
|
+ """逐子项匹配后拼接:对每个 step,把 substance/form 按「、」拆,逐子项查命中的
|
|
|
+ 分类 name(按 source_type 过滤),多个用「、」拼接写入 substanceMatch/formMatch;
|
|
|
+ 无命中写 None。原地修改 procedures 并返回。"""
|
|
|
+ lookup = _build_match_lookup(resp)
|
|
|
+
|
|
|
+ def match_for(raw, st):
|
|
|
+ names = []
|
|
|
+ for part in _split_values(_s(raw)):
|
|
|
+ name = lookup.get((part, st))
|
|
|
+ if name and name not in names:
|
|
|
+ names.append(name)
|
|
|
+ return "、".join(names) if names else None
|
|
|
+
|
|
|
+ for proc in (procedures or []):
|
|
|
+ for step in (proc.get("steps") or []):
|
|
|
+ step["substanceMatch"] = match_for(step.get("substance"), ST_SUBSTANCE)
|
|
|
+ step["formMatch"] = match_for(step.get("form"), ST_FORM)
|
|
|
+ return procedures
|
|
|
+
|
|
|
+
|
|
|
+# ── 调外部接口 ────────────────────────────────────────────────────────────────────
|
|
|
+async def _post_category_match(client: httpx.AsyncClient, post_id: str, knowledge_id: str,
|
|
|
+ items: List[dict]) -> dict:
|
|
|
+ body = {
|
|
|
+ "top_k": TOP_K,
|
|
|
+ "min_score": MIN_SCORE,
|
|
|
+ "record": RECORD,
|
|
|
+ "post_id": post_id,
|
|
|
+ "knowledge_id": knowledge_id,
|
|
|
+ "items": items,
|
|
|
+ }
|
|
|
+ r = await client.post(CATEGORY_MATCH_API + MATCH_ENDPOINT, json=body)
|
|
|
+ r.raise_for_status()
|
|
|
+ return r.json()
|
|
|
+
|
|
|
+
|
|
|
+# ── 单帖全流程(取数 → 调接口 → 回写 → 落库)─────────────────────────────────────────
|
|
|
+async def process_one(client: httpx.AsyncClient, query_id: str, case_id: str,
|
|
|
+ *, include_response: bool = False) -> dict:
|
|
|
+ """对一帖 (query_id=post_id, case_id=knowledge_id) 跑完整流程。绝不抛异常,
|
|
|
+ 错误以 {"ok": False, "error": ...} 返回,便于批量聚合。"""
|
|
|
+ base = {"query_id": query_id, "case_id": case_id}
|
|
|
+ try:
|
|
|
+ payload = await asyncio.to_thread(db.fetch_process_by_query, query_id, case_id)
|
|
|
+ if not payload:
|
|
|
+ return {**base, "ok": False, "error": "无工序解构记录"}
|
|
|
+ procedures = payload["procedures"]
|
|
|
+ version = payload["version"]
|
|
|
+
|
|
|
+ items = build_items(procedures)
|
|
|
+ if not items:
|
|
|
+ return {**base, "ok": True, "version": version, "items_sent": 0,
|
|
|
+ "rows_updated": 0, "note": "无可匹配维度词,跳过接口调用"}
|
|
|
+
|
|
|
+ resp = await _post_category_match(client, query_id, case_id, items)
|
|
|
+ if not resp.get("success"):
|
|
|
+ return {**base, "ok": False, "version": version,
|
|
|
+ "error": "category-match 返回 success=false", "response": resp}
|
|
|
+
|
|
|
+ enrich_steps(procedures, resp)
|
|
|
+ rows_updated = await asyncio.to_thread(
|
|
|
+ db.update_process_steps_by_query, query_id, case_id, version,
|
|
|
+ [p.get("steps") or [] for p in procedures])
|
|
|
+
|
|
|
+ out = {**base, "ok": True, "version": version, "items_sent": len(items),
|
|
|
+ "rows_updated": rows_updated, "recorded": resp.get("recorded")}
|
|
|
+ if include_response:
|
|
|
+ out["response"] = resp
|
|
|
+ out["procedures"] = procedures
|
|
|
+ return out
|
|
|
+ except httpx.HTTPError as e:
|
|
|
+ return {**base, "ok": False, "error": f"调用 category-match 失败: {type(e).__name__}: {e}"}
|
|
|
+ except Exception as e:
|
|
|
+ return {**base, "ok": False, "error": f"{type(e).__name__}: {e}"}
|
|
|
+
|
|
|
+
|
|
|
+# ── FastAPI ──────────────────────────────────────────────────────────────────────
|
|
|
+app = FastAPI(title="mode_workflow · category-match", version="1.0")
|
|
|
+
|
|
|
+
|
|
|
+class SingleReq(BaseModel):
|
|
|
+ query_id: str
|
|
|
+ case_id: str
|
|
|
+
|
|
|
+
|
|
|
+class BatchReq(BaseModel):
|
|
|
+ query_id: List[str]
|
|
|
+ case_id: List[str]
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/category-match/single")
|
|
|
+async def category_match_single(req: SingleReq):
|
|
|
+ """单帖:query_id→post_id,case_id→knowledge_id。返回命中详情 + 回写结果。"""
|
|
|
+ async with httpx.AsyncClient(timeout=API_TIMEOUT) as client:
|
|
|
+ return await process_one(client, req.query_id, req.case_id, include_response=True)
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/category-match/batch")
|
|
|
+async def category_match_batch(req: BatchReq):
|
|
|
+ """批量:query_id[i] 与 case_id[i] 一一配对,受 BATCH_CONCURRENCY 限流并发。
|
|
|
+ 每帖结果含 ok/error;返回总体统计 + 明细(明细不含完整 response,避免响应过大)。"""
|
|
|
+ if len(req.query_id) != len(req.case_id):
|
|
|
+ raise HTTPException(400, f"query_id({len(req.query_id)}) 与 case_id({len(req.case_id)}) 长度不一致")
|
|
|
+ pairs = list(zip(req.query_id, req.case_id))
|
|
|
+ if not pairs:
|
|
|
+ return {"total": 0, "ok": 0, "failed": 0, "results": []}
|
|
|
+
|
|
|
+ sem = asyncio.Semaphore(BATCH_CONCURRENCY)
|
|
|
+ # 整个批次共用一个连接池(keep-alive),并发受信号量约束,避免压垮下游接口
|
|
|
+ async with httpx.AsyncClient(
|
|
|
+ timeout=API_TIMEOUT,
|
|
|
+ limits=httpx.Limits(max_connections=BATCH_CONCURRENCY,
|
|
|
+ max_keepalive_connections=BATCH_CONCURRENCY),
|
|
|
+ ) as client:
|
|
|
+ async def _one(q, c):
|
|
|
+ async with sem:
|
|
|
+ return await process_one(client, q, c, include_response=False)
|
|
|
+
|
|
|
+ results = await asyncio.gather(*[_one(q, c) for q, c in pairs])
|
|
|
+
|
|
|
+ ok = sum(1 for r in results if r.get("ok"))
|
|
|
+ return {"total": len(results), "ok": ok, "failed": len(results) - ok, "results": results}
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import uvicorn
|
|
|
+ port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("CATEGORY_MATCH_PORT", "8780"))
|
|
|
+ print(f"🚀 category-match 服务 → http://0.0.0.0:{port} (下游 {CATEGORY_MATCH_API}{MATCH_ENDPOINT})")
|
|
|
+ uvicorn.run(app, host="0.0.0.0", port=port)
|