|
|
@@ -31,6 +31,9 @@ formMatch;无命中为 None。
|
|
|
python stages/category_match.py # 默认 0.0.0.0:8780
|
|
|
python stages/category_match.py 8090 # 指定端口
|
|
|
CATEGORY_MATCH_API=http://host:8300 python stages/category_match.py
|
|
|
+
|
|
|
+可调环境变量(前缀均 CATEGORY_MATCH_):TOP_K / MIN_SCORE / TIMEOUT(默认 60s) /
|
|
|
+ CONCURRENCY(默认 8) / RETRIES(默认 3,仅超时/网络/5xx 重试,指数退避) / BACKOFF(默认 1s)。
|
|
|
"""
|
|
|
from __future__ import annotations
|
|
|
|
|
|
@@ -54,8 +57,11 @@ MATCH_ENDPOINT = "/api/v1/category-match"
|
|
|
TOP_K = int(os.environ.get("CATEGORY_MATCH_TOP_K", "10")) # 后续再调
|
|
|
MIN_SCORE = float(os.environ.get("CATEGORY_MATCH_MIN_SCORE", "0.8")) # 后续再调
|
|
|
RECORD = True
|
|
|
-API_TIMEOUT = float(os.environ.get("CATEGORY_MATCH_TIMEOUT", "30"))
|
|
|
+API_TIMEOUT = float(os.environ.get("CATEGORY_MATCH_TIMEOUT", "60")) # 下游单帖可达 30s+,默认放宽到 60
|
|
|
BATCH_CONCURRENCY = int(os.environ.get("CATEGORY_MATCH_CONCURRENCY", "8")) # 批量并发上限
|
|
|
+# 下游调用失败重试:仅对「可重试」错误(超时/网络/5xx),指数退避;业务错误(4xx)立即失败不重试
|
|
|
+MAX_RETRIES = int(os.environ.get("CATEGORY_MATCH_RETRIES", "3")) # 额外重试次数(总尝试 = 1 + N)
|
|
|
+RETRY_BACKOFF = float(os.environ.get("CATEGORY_MATCH_BACKOFF", "1.0")) # 退避基数(秒):第 k 次重试前睡 BACKOFF*2^(k-1)
|
|
|
|
|
|
# 维度 → source_type(外部接口约定的中文标签)
|
|
|
ST_SUBSTANCE = "实质"
|
|
|
@@ -174,9 +180,29 @@ def enrich_steps(procedures: List[dict], resp: dict) -> List[dict]:
|
|
|
return procedures
|
|
|
|
|
|
|
|
|
-# ── 调外部接口 ────────────────────────────────────────────────────────────────────
|
|
|
+# ── 调外部接口(带重试)──────────────────────────────────────────────────────────────
|
|
|
+class _RetryExhausted(Exception):
|
|
|
+ """重试耗尽(或遇不可重试错误)时抛出,携带 last_exc 与实际尝试次数 attempts,
|
|
|
+ 使失败结果也能报告 attempts(否则 except 分支拿不到次数)。"""
|
|
|
+ def __init__(self, last_exc: Exception, attempts: int):
|
|
|
+ self.last_exc = last_exc
|
|
|
+ self.attempts = attempts
|
|
|
+ super().__init__(str(last_exc))
|
|
|
+
|
|
|
+
|
|
|
+def _is_retryable(exc: Exception) -> bool:
|
|
|
+ """判定异常是否值得重试:超时/连接/读写等传输错误,或 5xx 服务端错误。
|
|
|
+ 4xx(如 400 请求格式错)是确定性失败,重试无意义 → 不重试。"""
|
|
|
+ if isinstance(exc, httpx.HTTPStatusError):
|
|
|
+ return exc.response.status_code >= 500
|
|
|
+ return isinstance(exc, httpx.TransportError) # 含 ReadTimeout/ConnectError 等
|
|
|
+
|
|
|
+
|
|
|
async def _post_category_match(client: httpx.AsyncClient, post_id: str, knowledge_id: str,
|
|
|
- items: List[dict]) -> dict:
|
|
|
+ items: List[dict]) -> tuple:
|
|
|
+ """POST 到下游 category-match。可重试错误(超时/网络/5xx)按指数退避重试 MAX_RETRIES 次;
|
|
|
+ 重试耗尽或遇不可重试错误时抛出 _RetryExhausted(携带最后一次异常 last_exc 与尝试次数 attempts,由 process_one 兜成 ok:False)。
|
|
|
+ 返回 (resp_json, attempts):attempts=实际尝试次数(1=一次成功,>1=重试过)。"""
|
|
|
body = {
|
|
|
"top_k": TOP_K,
|
|
|
"min_score": MIN_SCORE,
|
|
|
@@ -185,19 +211,31 @@ async def _post_category_match(client: httpx.AsyncClient, post_id: str, knowledg
|
|
|
"knowledge_id": knowledge_id,
|
|
|
"items": items,
|
|
|
}
|
|
|
- r = await client.post(CATEGORY_MATCH_API + MATCH_ENDPOINT, json=body)
|
|
|
- r.raise_for_status()
|
|
|
- return r.json()
|
|
|
+ last_exc: Optional[Exception] = None
|
|
|
+ for attempt in range(MAX_RETRIES + 1): # 第 0 次为首发,其后为重试
|
|
|
+ try:
|
|
|
+ r = await client.post(CATEGORY_MATCH_API + MATCH_ENDPOINT, json=body)
|
|
|
+ r.raise_for_status()
|
|
|
+ return r.json(), attempt + 1
|
|
|
+ except httpx.HTTPError as e:
|
|
|
+ last_exc = e
|
|
|
+ if attempt < MAX_RETRIES and _is_retryable(e):
|
|
|
+ await asyncio.sleep(RETRY_BACKOFF * (2 ** attempt)) # 1,2,4,… 秒
|
|
|
+ continue
|
|
|
+ raise _RetryExhausted(e, attempt + 1) from e # 携带尝试次数,供失败结果报告
|
|
|
+ raise _RetryExhausted(last_exc, MAX_RETRIES + 1) # 仅当 MAX_RETRIES < 0(循环体一次都不执行,last_exc 为 None)时可达
|
|
|
|
|
|
|
|
|
# ── 单帖全流程(取数 → 调接口 → 回写 → 落库)─────────────────────────────────────────
|
|
|
async def process_one(client: httpx.AsyncClient, query_id: str, case_id: str,
|
|
|
*, include_response: bool = False) -> dict:
|
|
|
- """对一帖 (query_id=post_id, case_id=knowledge_id) 跑完整流程。绝不抛异常,
|
|
|
- 错误以 {"ok": False, "error": ...} 返回,便于批量聚合。"""
|
|
|
+ """对一帖跑完整流程:query_id=post_id(给下游记录),case_id=knowledge_id。绝不抛异常,
|
|
|
+ 错误以 {"ok": False, "error": ...} 返回,便于批量聚合。
|
|
|
+ 取数/回写按 case 的「最新真实版」(fetch_process,与前端 /api/extract 同口径),
|
|
|
+ 保证回写的版本即前端展示的版本——否则 link_ 复制帖会写错版本、前端看不到 tag。"""
|
|
|
base = {"query_id": query_id, "case_id": case_id}
|
|
|
try:
|
|
|
- payload = await asyncio.to_thread(db.fetch_process_by_query, query_id, case_id)
|
|
|
+ payload = await asyncio.to_thread(db.fetch_process, case_id) # 最新真实版,对齐前端展示
|
|
|
if not payload:
|
|
|
return {**base, "ok": False, "error": "无工序解构记录"}
|
|
|
procedures = payload["procedures"]
|
|
|
@@ -208,28 +246,50 @@ async def process_one(client: httpx.AsyncClient, query_id: str, case_id: str,
|
|
|
return {**base, "ok": True, "version": version, "items_sent": 0,
|
|
|
"rows_updated": 0, "note": "无可匹配维度词,跳过接口调用"}
|
|
|
|
|
|
- resp = await _post_category_match(client, query_id, case_id, items)
|
|
|
+ resp, attempts = await _post_category_match(client, query_id, case_id, items)
|
|
|
if not resp.get("success"):
|
|
|
- return {**base, "ok": False, "version": version,
|
|
|
+ return {**base, "ok": False, "version": version, "attempts": attempts,
|
|
|
"error": "category-match 返回 success=false", "response": resp}
|
|
|
|
|
|
enrich_steps(procedures, resp)
|
|
|
rows_updated = await asyncio.to_thread(
|
|
|
- db.update_process_steps_by_query, query_id, case_id, version,
|
|
|
+ db.update_process_steps, case_id, version,
|
|
|
[p.get("steps") or [] for p in procedures])
|
|
|
|
|
|
out = {**base, "ok": True, "version": version, "items_sent": len(items),
|
|
|
- "rows_updated": rows_updated, "recorded": resp.get("recorded")}
|
|
|
+ "rows_updated": rows_updated, "recorded": resp.get("recorded"),
|
|
|
+ "attempts": attempts}
|
|
|
if include_response:
|
|
|
out["response"] = resp
|
|
|
out["procedures"] = procedures
|
|
|
return out
|
|
|
- except httpx.HTTPError as e:
|
|
|
- return {**base, "ok": False, "error": f"调用 category-match 失败: {type(e).__name__}: {e}"}
|
|
|
+ except _RetryExhausted as e:
|
|
|
+ return {**base, "ok": False, "attempts": e.attempts,
|
|
|
+ "error": f"调用 category-match 失败(尝试 {e.attempts} 次): "
|
|
|
+ f"{type(e.last_exc).__name__}: {e.last_exc}"}
|
|
|
except Exception as e:
|
|
|
return {**base, "ok": False, "error": f"{type(e).__name__}: {e}"}
|
|
|
|
|
|
|
|
|
+async def gather_pairs(pairs, *, on_each=None, include_response=False) -> list:
|
|
|
+ """对一批 (query_id, case_id) 并发跑 process_one(受 BATCH_CONCURRENCY 限流,
|
|
|
+ 共用一个 keep-alive 连接池)。on_each(index, result):每帖完成时回调(用于打印进度)。
|
|
|
+ 返回结果列表(顺序同 pairs)。供 FastAPI batch 与 CLI --run 共用。"""
|
|
|
+ sem = asyncio.Semaphore(BATCH_CONCURRENCY)
|
|
|
+ async with httpx.AsyncClient(
|
|
|
+ timeout=API_TIMEOUT,
|
|
|
+ limits=httpx.Limits(max_connections=BATCH_CONCURRENCY,
|
|
|
+ max_keepalive_connections=BATCH_CONCURRENCY),
|
|
|
+ ) as client:
|
|
|
+ async def _one(i, q, c):
|
|
|
+ async with sem:
|
|
|
+ r = await process_one(client, q, c, include_response=include_response)
|
|
|
+ if on_each:
|
|
|
+ on_each(i, r)
|
|
|
+ return r
|
|
|
+ return await asyncio.gather(*[_one(i, q, c) for i, (q, c) in enumerate(pairs)])
|
|
|
+
|
|
|
+
|
|
|
# ── FastAPI ──────────────────────────────────────────────────────────────────────
|
|
|
app = FastAPI(title="mode_workflow · category-match", version="1.0")
|
|
|
|
|
|
@@ -260,26 +320,51 @@ async def category_match_batch(req: BatchReq):
|
|
|
pairs = list(zip(req.query_id, req.case_id))
|
|
|
if not pairs:
|
|
|
return {"total": 0, "ok": 0, "failed": 0, "results": []}
|
|
|
+ results = await gather_pairs(pairs)
|
|
|
+ ok = sum(1 for r in results if r.get("ok"))
|
|
|
+ return {"total": len(results), "ok": ok, "failed": len(results) - ok, "results": results}
|
|
|
|
|
|
- sem = asyncio.Semaphore(BATCH_CONCURRENCY)
|
|
|
- # 整个批次共用一个连接池(keep-alive),并发受信号量约束,避免压垮下游接口
|
|
|
- async with httpx.AsyncClient(
|
|
|
- timeout=API_TIMEOUT,
|
|
|
- limits=httpx.Limits(max_connections=BATCH_CONCURRENCY,
|
|
|
- max_keepalive_connections=BATCH_CONCURRENCY),
|
|
|
- ) as client:
|
|
|
- async def _one(q, c):
|
|
|
- async with sem:
|
|
|
- return await process_one(client, q, c, include_response=False)
|
|
|
-
|
|
|
- results = await asyncio.gather(*[_one(q, c) for q, c in pairs])
|
|
|
|
|
|
+def _cli_run(query_id: str, case_ids: List[str]) -> int:
|
|
|
+ """CLI 归类:对一个 query 下的若干 case 跑归类,实时打印进度(供 server.py 起子进程、
|
|
|
+ 前端轮询日志)。返回退出码:全成功=0,有失败=1(便于任务状态判定)。"""
|
|
|
+ pairs = [(query_id, c) for c in case_ids]
|
|
|
+ if not pairs:
|
|
|
+ print("无 case 可归类"); return 0
|
|
|
+ print(f"开始归类:query_id={query_id} {len(pairs)} 帖 (下游 {CATEGORY_MATCH_API}{MATCH_ENDPOINT})", flush=True)
|
|
|
+ n = len(pairs)
|
|
|
+ cnt = {"i": 0}
|
|
|
+
|
|
|
+ def _progress(_idx, r):
|
|
|
+ cnt["i"] += 1
|
|
|
+ tag = "✓ OK " if r.get("ok") else "✗ FAIL"
|
|
|
+ extra = (f"items={r.get('items_sent')} rows_updated={r.get('rows_updated')} attempts={r.get('attempts')}"
|
|
|
+ if r.get("ok") else (r.get("error") or r.get("note") or ""))
|
|
|
+ print(f"[{cnt['i']}/{n}] {tag} {r.get('case_id')} {extra}", flush=True)
|
|
|
+
|
|
|
+ results = asyncio.run(gather_pairs(pairs, on_each=_progress))
|
|
|
ok = sum(1 for r in results if r.get("ok"))
|
|
|
- return {"total": len(results), "ok": ok, "failed": len(results) - ok, "results": results}
|
|
|
+ print(f"\n归类完成:{ok}/{n} 成功,{n - ok} 失败", flush=True)
|
|
|
+ return 0 if ok == n else 1
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
+ import argparse
|
|
|
+ ap = argparse.ArgumentParser(description="category-match:FastAPI 服务 或 CLI 归类(--run)")
|
|
|
+ ap.add_argument("port", nargs="?", type=int,
|
|
|
+ default=int(os.environ.get("CATEGORY_MATCH_PORT", "8780")),
|
|
|
+ help="服务端口(不带 --run 时生效)")
|
|
|
+ ap.add_argument("--run", action="store_true", help="CLI 归类模式:跑完即退出,不起服务")
|
|
|
+ ap.add_argument("--query-id", help="--run 用:post_id")
|
|
|
+ ap.add_argument("--case-ids", help="--run 用:逗号分隔的 case_id(knowledge_id)")
|
|
|
+ args = ap.parse_args()
|
|
|
+
|
|
|
+ if args.run:
|
|
|
+ cids = [c.strip() for c in (args.case_ids or "").split(",") if c.strip()]
|
|
|
+ if not args.query_id or not cids:
|
|
|
+ print("--run 需提供 --query-id 与 --case-ids"); sys.exit(2)
|
|
|
+ sys.exit(_cli_run(args.query_id, cids))
|
|
|
+
|
|
|
import uvicorn
|
|
|
- port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("CATEGORY_MATCH_PORT", "8780"))
|
|
|
- print(f"🚀 category-match 服务 → http://0.0.0.0:{port} (下游 {CATEGORY_MATCH_API}{MATCH_ENDPOINT})")
|
|
|
- uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
+ print(f"🚀 category-match 服务 → http://0.0.0.0:{args.port} (下游 {CATEGORY_MATCH_API}{MATCH_ENDPOINT})")
|
|
|
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
|