Преглед изворни кода

feat(mode_workflow): 添加query_id过滤功能与分类匹配服务

为/api/queries和/api/dashboard接口新增query_list参数,支持逗号分隔、重复参数、URL编码JSON数组三种写法
新增_parse_query_list工具函数统一解析不同格式的query_list参数,去重保序
扩展db.fetch_all_posts函数,新增query_ids参数支持按query_id过滤数据,空列表直接返回空结果
新增fetch_process_by_query和update_process_steps_by_query数据库操作函数
新增category_match.py分类匹配服务脚本,提供单帖和批量处理的FastAPI接口
更新工序接口文档,补充参数说明与使用示例
刘文武 пре 2 дана
родитељ
комит
cb6aae030f

+ 62 - 4
examples/mode_workflow/db.py

@@ -582,10 +582,12 @@ def fetch_post(query_id, case_id, table="search_process"):
     return row
 
 
-def fetch_all_posts(mode="process", *, adopted_only=False, distinct=False,
+def fetch_all_posts(mode="process", *, query_ids=None, adopted_only=False, distinct=False,
                     limit=None, offset=0):
     """某方向「全部帖子」:跨所有 query 的列表(瘦身列,口径同 fetch_posts,不拉
-    body/videos/llm_evaluation 大字段)。fetch_posts 限定单 query,本函数取全表。
+    body/videos/llm_evaluation 大字段)。fetch_posts 限定单 query,本函数默认取全表。
+      - query_ids:选填 query_id 列表,传了就 WHERE query_id IN(...) 只取这些 query
+        的帖子(SQL 层过滤,不拉全表);None=全部,[]=空结果。
       - adopted_only=True:只返回采纳帖(is_adopted_rel 口径,rel/repro 由
         _REL_SQL/_REPRO_SQL 直取标量算,不拉整表 blob)。
       - distinct=True:按 case_id 去重(同一帖被多个 query 搜到时,只保留
@@ -593,6 +595,12 @@ def fetch_all_posts(mode="process", *, adopted_only=False, distinct=False,
       - limit/offset:分页(limit=None 不分页)。
     返回 (total, rows):total 为过滤(+去重)后的总条数,rows 为本页切片。"""
     table = _search_table(mode)
+    where, params = "", []
+    if query_ids is not None:
+        if not query_ids:
+            return 0, []   # 显式空列表:直接空结果,不必查库
+        where = " WHERE query_id IN (" + ",".join(["%s"] * len(query_ids)) + ")"
+        params = list(query_ids)
     conn = _conn()
     try:
         with conn.cursor() as cur:
@@ -600,8 +608,8 @@ def fetch_all_posts(mode="process", *, adopted_only=False, distinct=False,
                                    title, url, content_type, images, like_count, publish_time,
                                    quality_score, quality_grade, found_by, knowledge_type, overall_score,
                                    {_REL_SQL} AS rel, {_REPRO_SQL} AS repro
-                            FROM {table}
-                            ORDER BY overall_score DESC, id""")
+                            FROM {table}{where}
+                            ORDER BY overall_score DESC, id""", params)
             rows = cur.fetchall()
             # has_process/has_tools 全局判定:跨 query 的「该帖是否已解构」,两张解构表各取一次
             cur.execute("SELECT DISTINCT case_id FROM mode_process")
@@ -726,6 +734,56 @@ def fetch_process(case_id, version=None):
     return _proc_payload(case_id, version, rows)
 
 
+def fetch_process_by_query(query_id, case_id, version=None):
+    """同 fetch_process,但用 (query_id, case_id) 精确定位某 query 下该帖的工序
+    (category-match 用:post_id=query_id / knowledge_id=case_id)。
+    version=None 取该 (query_id, case_id) 下最新真实版(link_ 排后)。无行返回 None。"""
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            if version is None:
+                cur.execute("""SELECT version FROM mode_process WHERE query_id=%s AND case_id=%s
+                               ORDER BY (LEFT(version,5)='link_') ASC, id DESC LIMIT 1""",
+                            (query_id, case_id))
+                row = cur.fetchone()
+                if not row:
+                    return None
+                version = row["version"]
+            cur.execute("""SELECT * FROM mode_process WHERE query_id=%s AND case_id=%s AND version=%s
+                           ORDER BY seq, id""", (query_id, case_id, version))
+            rows = cur.fetchall()
+    finally:
+        conn.close()
+    return _proc_payload(case_id, version, rows)
+
+
+def update_process_steps_by_query(query_id, case_id, version, steps_in_order):
+    """按工序顺序覆盖某 (query_id, case_id, version) 各行的 steps JSON 列。
+    steps_in_order 必须与 fetch_process_by_query 返回的 procedures 同序(均按 seq, id 升序);
+    按行 id 一一对应更新,稳健于 seq 不连续。行数与工序数不符则报错回滚。返回更新行数。"""
+    conn = _conn()
+    try:
+        conn.begin()
+        with conn.cursor() as cur:
+            cur.execute("""SELECT id FROM mode_process
+                           WHERE query_id=%s AND case_id=%s AND version=%s
+                           ORDER BY seq, id""", (query_id, case_id, version))
+            ids = [r["id"] for r in cur.fetchall()]
+            if len(ids) != len(steps_in_order):
+                raise ValueError(f"行数({len(ids)})与工序数({len(steps_in_order)})不一致")
+            n = 0
+            for row_id, steps in zip(ids, steps_in_order):
+                cur.execute("UPDATE mode_process SET steps=%s WHERE id=%s", (_j(steps), row_id))
+                n += cur.rowcount
+        conn.commit()
+        return n
+    except Exception:
+        conn.rollback()
+        raise
+    finally:
+        conn.close()
+
+
 def _proc_payload(case_id, version, rows):
     """mode_process 行集 → {case_id, version, …, procedures:[...]}。无行返回 None。"""
     if not rows:

+ 35 - 1
examples/mode_workflow/server.py

@@ -275,6 +275,33 @@ def _queries_cached(mode):
     return data
 
 
+def _parse_query_list(raw_query):
+    """从 query string 解析 query_list(选填);三种写法统一成 list[str](去重保序):
+      重复参数  ?query_list=q1&query_list=q2
+      逗号分隔  ?query_list=q1,q2
+      JSON 数组 ?query_list=["q1","q2"]
+    未提供返回 None(=查全部);提供了(哪怕显式空数组)返回 list(可能为空,=过滤后空结果)。"""
+    vals = parse_qs(raw_query).get("query_list")
+    if not vals:
+        return None
+    out = []
+    for v in vals:
+        v = v.strip()
+        if v.startswith("[") and v.endswith("]"):
+            try:
+                out.extend(str(x).strip() for x in json.loads(v) if str(x).strip())
+                continue
+            except Exception:
+                pass
+        out.extend(p.strip() for p in v.split(",") if p.strip())
+    seen, res = set(), []
+    for q in out:
+        if q not in seen:
+            seen.add(q)
+            res.append(q)
+    return res
+
+
 def _dashboard():
     posts, procs, tools = db.fetch_dashboard_rows()
 
@@ -611,7 +638,13 @@ class Handler(BaseHTTPRequestHandler):
             elif u.path == "/api/dashboard":
                 self._json_etag(_dashboard_cached())
             elif u.path == "/api/queries":
-                self._json_etag(_queries_cached(qs.get("mode", "process")))
+                # 选填 query_list:传了就按 query_id 过滤(顺序保持服务端默认排序),不传查全部
+                data = _queries_cached(qs.get("mode", "process"))
+                ql = _parse_query_list(u.query)
+                if ql is not None:
+                    wanted = set(ql)
+                    data = [q for q in data if q["query_id"] in wanted]
+                self._json_etag(data)
             elif u.path == "/api/posts":
                 self._json(db.fetch_posts(qs.get("query_id", ""), qs.get("mode", "process")))
             elif u.path == "/api/all_posts":
@@ -623,6 +656,7 @@ class Handler(BaseHTTPRequestHandler):
                     return self._err("page/page_size 须为整数", 400)
                 total, posts = db.fetch_all_posts(
                     qs.get("mode", "process"),
+                    query_ids=_parse_query_list(u.query),   # 选填:只查这些 query 的帖子,不传查全部
                     adopted_only=qs.get("adopted") in ("1", "true"),
                     distinct=qs.get("distinct") in ("1", "true"),
                     limit=page_size, offset=(page - 1) * page_size)

+ 285 - 0
examples/mode_workflow/stages/category_match.py

@@ -0,0 +1,285 @@
+"""
+工序维度词 → 分类树匹配(category-match)。
+
+把 mode_process 某帖的工序解构里五个维度的词(实质/形式/类型/作用/动作)送到
+分类匹配接口,拿到每个词命中的分类节点,再把命中的「实质/形式」分类名回写到
+对应 step 的 substanceMatch / formMatch 字段(并落库 mode_process)。
+
+外部接口(入参/返回见 README 或本文件 _post_category_match):
+    POST {CATEGORY_MATCH_API}/api/v1/category-match
+
+维度 → source_type 映射(取自 mode_process.steps[]):
+    实质 ← steps[].substance
+    形式 ← steps[].form
+    类型 ← steps[].inputs[].type
+    作用 ← steps[].effect
+    动作 ← steps[].action
+
+term 拆分:每个值按「、」拆(括号内的「、」不拆,见 _split_values),空格保留;
+同一 source_type 下的 term 去重(跨 step / 跨工序)。
+
+回写规则(1.2):对每个 step,把 substance/form 同样按「、」拆成子项,逐子项查它在
+返回里命中的分类 name(按 source_type 过滤;择优:优先「分类名==子项」的精确同名,
+无同名再退回分数最高的一条),多个子项的 name 用「、」拼接 → substanceMatch /
+formMatch;无命中为 None。
+
+对外 FastAPI(供前端/批处理调用):
+    POST /category-match/single  body={"query_id":..,"case_id":..}
+    POST /category-match/batch   body={"query_id":[..],"case_id":[..]}  (并发,见 BATCH_CONCURRENCY)
+
+启动:
+    python stages/category_match.py            # 默认 0.0.0.0:8780
+    python stages/category_match.py 8090        # 指定端口
+    CATEGORY_MATCH_API=http://host:8300 python stages/category_match.py
+"""
+from __future__ import annotations
+
+import asyncio
+import os
+import sys
+from pathlib import Path
+from typing import List, Optional
+
+import httpx
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+# 复用项目根的 db(stages 脚本统一从上级目录导入)
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+import db  # noqa: E402
+
+# ── 配置(后续按需调整)──────────────────────────────────────────────────────────
+CATEGORY_MATCH_API = os.environ.get("CATEGORY_MATCH_API", "http://47.236.83.130:8300").rstrip("/")
+MATCH_ENDPOINT = "/api/v1/category-match"
+TOP_K = int(os.environ.get("CATEGORY_MATCH_TOP_K", "10"))          # 后续再调
+MIN_SCORE = float(os.environ.get("CATEGORY_MATCH_MIN_SCORE", "0.8"))  # 后续再调
+RECORD = True
+API_TIMEOUT = float(os.environ.get("CATEGORY_MATCH_TIMEOUT", "30"))
+BATCH_CONCURRENCY = int(os.environ.get("CATEGORY_MATCH_CONCURRENCY", "8"))  # 批量并发上限
+
+# 维度 → source_type(外部接口约定的中文标签)
+ST_SUBSTANCE = "实质"
+ST_FORM = "形式"
+ST_TYPE = "类型"
+ST_EFFECT = "作用"
+ST_ACTION = "动作"
+
+
+# ── term 拆分(按「、」拆,括号内不拆,空格保留;去重保序)─────────────────────────────
+def _s(v) -> str:
+    """取值规整为字符串:None→"",其余 str().strip()。"""
+    return "" if v is None else str(v).strip()
+
+
+def _split_values(raw: str) -> List[str]:
+    """按顿号分割,括号内的顿号不作为分隔符,结果去重保序(口径同
+    import_process_knowledge._split_values)。空格保留,不按空格拆。
+      "高保真线框图、UI设计稿"            → ["高保真线框图", "UI设计稿"]
+      "修改后的照片(发型、服装)、二次元服装" → ["修改后的照片(发型、服装)", "二次元服装"]
+      "3D 毛绒玩具风"                     → ["3D 毛绒玩具风"]   (含空格的整体词不拆碎)
+    """
+    parts, current, depth = [], [], 0
+    for ch in raw:
+        if ch in ("(", "("):
+            depth += 1
+            current.append(ch)
+        elif ch in (")", ")"):
+            depth -= 1
+            current.append(ch)
+        elif ch == "、" and depth == 0:
+            part = "".join(current).strip()
+            if part:
+                parts.append(part)
+            current = []
+        else:
+            current.append(ch)
+    part = "".join(current).strip()
+    if part:
+        parts.append(part)
+    seen, result = set(), []
+    for p in parts:
+        if p not in seen:
+            seen.add(p)
+            result.append(p)
+    return result
+
+
+# ── 组装 items(五维度 + 同 source_type 去重)──────────────────────────────────────
+def _iter_step_terms(step: dict):
+    """产出该 step 的 (source_type, term) 拆分对(未跨 step 去重)。"""
+    for v in _split_values(_s(step.get("substance"))):
+        yield (ST_SUBSTANCE, v)
+    for v in _split_values(_s(step.get("form"))):
+        yield (ST_FORM, v)
+    for v in _split_values(_s(step.get("effect"))):
+        yield (ST_EFFECT, v)
+    for v in _split_values(_s(step.get("action"))):
+        yield (ST_ACTION, v)
+    for inp in (step.get("inputs") or []):
+        for v in _split_values(_s(inp.get("type"))):
+            yield (ST_TYPE, v)
+
+
+def build_items(procedures: List[dict]) -> List[dict]:
+    """遍历所有工序的 steps,收集五维度词,按 (source_type, term) 全局去重。
+    返回 [{"term":..,"source_type":..}]。"""
+    seen, items = set(), []
+    for proc in (procedures or []):
+        for step in (proc.get("steps") or []):
+            for st, term in _iter_step_terms(step):
+                key = (st, term)
+                if key not in seen:
+                    seen.add(key)
+                    items.append({"term": term, "source_type": st})
+    return items
+
+
+# ── 回写 substanceMatch / formMatch ──────────────────────────────────────────────
+def _build_match_lookup(resp: dict) -> dict:
+    """返回 {(term, source_type): name}。择优规则:优先「分类名 == term」的精确同名候选,
+    无同名再退回分数最高的;分数也并列时取先出现的(即下游返回顺序的第一个)。
+    用 (is_exact, score) 元组排序:精确同名(True>False)永远压过非同名,同档再比分数。"""
+    best = {}  # (term, st) -> (rank, name);rank=(is_exact, score),元组比大小
+    for item in (resp.get("items") or []):
+        term = item.get("term")
+        for m in (item.get("matches") or []):
+            st, name, score = m.get("source_type"), m.get("name"), m.get("score") or 0
+            if term is None or st is None or name is None:
+                continue
+            key = (term, st)
+            rank = (name == term, score)   # 精确同名优先,其次分数高
+            if key not in best or rank > best[key][0]:
+                best[key] = (rank, name)
+    return {k: v[1] for k, v in best.items()}
+
+
+def enrich_steps(procedures: List[dict], resp: dict) -> List[dict]:
+    """逐子项匹配后拼接:对每个 step,把 substance/form 按「、」拆,逐子项查命中的
+    分类 name(按 source_type 过滤),多个用「、」拼接写入 substanceMatch/formMatch;
+    无命中写 None。原地修改 procedures 并返回。"""
+    lookup = _build_match_lookup(resp)
+
+    def match_for(raw, st):
+        names = []
+        for part in _split_values(_s(raw)):
+            name = lookup.get((part, st))
+            if name and name not in names:
+                names.append(name)
+        return "、".join(names) if names else None
+
+    for proc in (procedures or []):
+        for step in (proc.get("steps") or []):
+            step["substanceMatch"] = match_for(step.get("substance"), ST_SUBSTANCE)
+            step["formMatch"] = match_for(step.get("form"), ST_FORM)
+    return procedures
+
+
+# ── 调外部接口 ────────────────────────────────────────────────────────────────────
+async def _post_category_match(client: httpx.AsyncClient, post_id: str, knowledge_id: str,
+                               items: List[dict]) -> dict:
+    body = {
+        "top_k": TOP_K,
+        "min_score": MIN_SCORE,
+        "record": RECORD,
+        "post_id": post_id,
+        "knowledge_id": knowledge_id,
+        "items": items,
+    }
+    r = await client.post(CATEGORY_MATCH_API + MATCH_ENDPOINT, json=body)
+    r.raise_for_status()
+    return r.json()
+
+
+# ── 单帖全流程(取数 → 调接口 → 回写 → 落库)─────────────────────────────────────────
+async def process_one(client: httpx.AsyncClient, query_id: str, case_id: str,
+                      *, include_response: bool = False) -> dict:
+    """对一帖 (query_id=post_id, case_id=knowledge_id) 跑完整流程。绝不抛异常,
+    错误以 {"ok": False, "error": ...} 返回,便于批量聚合。"""
+    base = {"query_id": query_id, "case_id": case_id}
+    try:
+        payload = await asyncio.to_thread(db.fetch_process_by_query, query_id, case_id)
+        if not payload:
+            return {**base, "ok": False, "error": "无工序解构记录"}
+        procedures = payload["procedures"]
+        version = payload["version"]
+
+        items = build_items(procedures)
+        if not items:
+            return {**base, "ok": True, "version": version, "items_sent": 0,
+                    "rows_updated": 0, "note": "无可匹配维度词,跳过接口调用"}
+
+        resp = await _post_category_match(client, query_id, case_id, items)
+        if not resp.get("success"):
+            return {**base, "ok": False, "version": version,
+                    "error": "category-match 返回 success=false", "response": resp}
+
+        enrich_steps(procedures, resp)
+        rows_updated = await asyncio.to_thread(
+            db.update_process_steps_by_query, query_id, case_id, version,
+            [p.get("steps") or [] for p in procedures])
+
+        out = {**base, "ok": True, "version": version, "items_sent": len(items),
+               "rows_updated": rows_updated, "recorded": resp.get("recorded")}
+        if include_response:
+            out["response"] = resp
+            out["procedures"] = procedures
+        return out
+    except httpx.HTTPError as e:
+        return {**base, "ok": False, "error": f"调用 category-match 失败: {type(e).__name__}: {e}"}
+    except Exception as e:
+        return {**base, "ok": False, "error": f"{type(e).__name__}: {e}"}
+
+
+# ── FastAPI ──────────────────────────────────────────────────────────────────────
+app = FastAPI(title="mode_workflow · category-match", version="1.0")
+
+
+class SingleReq(BaseModel):
+    query_id: str
+    case_id: str
+
+
+class BatchReq(BaseModel):
+    query_id: List[str]
+    case_id: List[str]
+
+
+@app.post("/category-match/single")
+async def category_match_single(req: SingleReq):
+    """单帖:query_id→post_id,case_id→knowledge_id。返回命中详情 + 回写结果。"""
+    async with httpx.AsyncClient(timeout=API_TIMEOUT) as client:
+        return await process_one(client, req.query_id, req.case_id, include_response=True)
+
+
+@app.post("/category-match/batch")
+async def category_match_batch(req: BatchReq):
+    """批量:query_id[i] 与 case_id[i] 一一配对,受 BATCH_CONCURRENCY 限流并发。
+    每帖结果含 ok/error;返回总体统计 + 明细(明细不含完整 response,避免响应过大)。"""
+    if len(req.query_id) != len(req.case_id):
+        raise HTTPException(400, f"query_id({len(req.query_id)}) 与 case_id({len(req.case_id)}) 长度不一致")
+    pairs = list(zip(req.query_id, req.case_id))
+    if not pairs:
+        return {"total": 0, "ok": 0, "failed": 0, "results": []}
+
+    sem = asyncio.Semaphore(BATCH_CONCURRENCY)
+    # 整个批次共用一个连接池(keep-alive),并发受信号量约束,避免压垮下游接口
+    async with httpx.AsyncClient(
+        timeout=API_TIMEOUT,
+        limits=httpx.Limits(max_connections=BATCH_CONCURRENCY,
+                            max_keepalive_connections=BATCH_CONCURRENCY),
+    ) as client:
+        async def _one(q, c):
+            async with sem:
+                return await process_one(client, q, c, include_response=False)
+
+        results = await asyncio.gather(*[_one(q, c) for q, c in pairs])
+
+    ok = sum(1 for r in results if r.get("ok"))
+    return {"total": len(results), "ok": ok, "failed": len(results) - ok, "results": results}
+
+
+if __name__ == "__main__":
+    import uvicorn
+    port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("CATEGORY_MATCH_PORT", "8780"))
+    print(f"🚀 category-match 服务 → http://0.0.0.0:{port}  (下游 {CATEGORY_MATCH_API}{MATCH_ENDPOINT})")
+    uvicorn.run(app, host="0.0.0.0", port=port)

+ 19 - 0
examples/mode_workflow/工序接口文档.md

@@ -58,6 +58,19 @@
 ### GET /api/queries?mode=process
 某方向搜索表派生的 query 列表(含工序解构进度)。带 ETag。
 
+**Query 参数**
+| 参数 | 必填 | 缺省 | 说明 |
+|---|---|---|---|
+| `mode` | 否 | `process` | 方向(`process` / `tools`) |
+| `query_list` | 否 | 全部 | 只查指定 `query_id`,**不传 = 查全部**。1 个或多个均可,支持三种写法(见下) |
+
+`query_list` 三种等价写法(结果按服务端默认排序,命中不到的 id 自动忽略):
+- 逗号分隔:`?query_list=q0031,q0032`
+- 重复参数:`?query_list=q0031&query_list=q0032`
+- JSON 数组:`?query_list=["q0031","q0032"]`(需 URL 编码)
+
+> 单个就传 `?query_list=q0031`(或 `["q0031"]`)。传了但全不命中 → 返回空数组 `[]`。
+
 **响应**:`Array<Query>`
 ```jsonc
 [{
@@ -101,11 +114,17 @@
 | 参数 | 必填 | 缺省 | 说明 |
 |---|---|---|---|
 | `mode` | 否 | `process` | 方向(`process` / `tools`) |
+| `query_list` | 否 | 全部 | 只查指定 `query_id` 下的帖子,**不传 = 查全部**。1 个或多个,写法同 `/api/queries`(逗号分隔 / 重复参数 / JSON 数组,见下) |
 | `page` | 否 | `1` | 页码(从 1 起) |
 | `page_size` | 否 | `100` | 每页条数(上限 500) |
 | `adopted` | 否 | `0` | `1`/`true` 只返回采纳帖(`is_adopted_rel` 口径) |
 | `distinct` | 否 | `0` | `1`/`true` 按 `case_id` 去重(同帖被多 query 搜到时只保留 `overall_score` 最高的一行) |
 
+`query_list` 三种等价写法(在 SQL 层 `WHERE query_id IN(...)` 过滤,先过滤再去重/分页;命中不到的 id 自动忽略):
+- 逗号分隔:`?query_list=q0031,q0032`
+- 重复参数:`?query_list=q0031&query_list=q0032`
+- JSON 数组:`?query_list=["q0031","q0032"]`(需 URL 编码)
+
 > `page`/`page_size` 非整数返回 `400 {"error":"page/page_size 须为整数"}`。
 
 **响应**