build_workflows.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """把一个 run 目录(如 runs_full/q0000)里**每个帖子**的 workflow.json,
  4. 与它对应的 post 信息(含 llm_evaluation)以及 query 词合并成一个 JSON。
  5. 以帖子为单位:一个 procedure 输出一个文件,文件落在 search_eval/workflows/ 下。
  6. (例:q0000 有 3 个 procedure -> 写出 3 个 json)
  7. 映射逻辑:
  8. procedure 文件夹名形如 {FORM}_{platform}_{hash前缀} 例: A_gzh_8f5fbfb0
  9. -> 读 form_{FORM}.json,在 results[] 里找 case_id 以 "{platform}_{hash前缀}" 开头的那条
  10. -> 该 result 即对应的 post(post / comments / llm_evaluation / source_url ...)
  11. -> query / original_q 取自 form_{FORM}.json 顶层
  12. 输出文件名: {run_id}_{folder}.json 例: q0000_A_gzh_8f5fbfb0.json
  13. 本模块既是 build 脚本,也是一个 HTTP 接口:
  14. * build 函数(build_run / write_run)保留,供外部 import 调用或经 POST /build 触发;
  15. * 接口本身实时扫描 workflows/ 目录,把里面所有 json 以数组形式返回。
  16. 用法(build):
  17. python build_workflows.py # 默认处理 q0000
  18. python build_workflows.py q0003 # 处理指定 run
  19. python build_workflows.py --all # 处理 runs_full 下所有 q* 目录
  20. 用法(接口):
  21. python build_workflows.py serve [port] # 默认 8771
  22. GET /workflows -> 实时扫描 workflows/*.json,返回数组
  23. GET / -> 同上(方便直接访问)
  24. POST /build -> body {"q":"q0003"} 或 {"all":true},触发 build 后返回结果
  25. """
  26. import json
  27. import os
  28. import re
  29. import sys
  30. import glob
  31. from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
  32. HERE = os.path.dirname(os.path.abspath(__file__))
  33. RUNS_DIR = os.path.join(HERE, "runs_full")
  34. OUT_DIR = os.path.join(HERE, "workflows")
  35. DEFAULT_PORT = 8771
  36. # 文件夹名: 表单字母 _ 平台 _ case_id 哈希前缀
  37. FOLDER_RE = re.compile(r"^([A-Za-z])_([a-z0-9]+)_([0-9a-fA-F]+)$")
  38. def load_json(path):
  39. with open(path, encoding="utf-8") as f:
  40. return json.load(f)
  41. def build_run(run_id, runs_dir=None):
  42. """为单个 run 目录构建合并结果。
  43. 以帖子为单位:返回一个 list,每个元素是 (folder, merged_dict),
  44. merged_dict 即单个帖子的合并 JSON(query + post + llm_evaluation + workflow)。
  45. 找不到 procedures 时返回空 list。
  46. runs_dir 缺省用模块的 RUNS_DIR;外部脚本(如 batch_extract_procedures.py 用了
  47. --output-dir)可传入自己的 runs_full,避免两边路径不一致。"""
  48. run_dir = os.path.join(runs_dir or RUNS_DIR, run_id)
  49. proc_root = os.path.join(run_dir, "procedures")
  50. if not os.path.isdir(proc_root):
  51. print(f"[skip] {run_id}: 没有 procedures/ 目录")
  52. return []
  53. # 缓存已加载的 form_{X}.json,并记录 query(取第一个见到的)
  54. forms = {}
  55. query = original_q = None
  56. platforms = None
  57. def get_form(letter):
  58. nonlocal query, original_q, platforms
  59. if letter not in forms:
  60. forms[letter] = load_json(os.path.join(run_dir, f"form_{letter}.json"))
  61. if query is None:
  62. query = forms[letter].get("query")
  63. original_q = forms[letter].get("original_q")
  64. platforms = forms[letter].get("platforms")
  65. return forms[letter]
  66. out = []
  67. for folder in sorted(os.listdir(proc_root)):
  68. folder_path = os.path.join(proc_root, folder)
  69. if not os.path.isdir(folder_path):
  70. continue
  71. m = FOLDER_RE.match(folder)
  72. if not m:
  73. print(f"[warn] {run_id}/{folder}: 文件夹名不符合命名规则,跳过")
  74. continue
  75. form_letter, platform, hash_prefix = m.groups()
  76. wf_path = os.path.join(folder_path, "workflow.json")
  77. if not os.path.isfile(wf_path):
  78. print(f"[warn] {run_id}/{folder}: 没有 workflow.json,跳过")
  79. continue
  80. workflow = load_json(wf_path)
  81. # 在对应 form 里按 case_id 前缀找匹配的 post
  82. form_data = get_form(form_letter)
  83. want_prefix = f"{platform}_{hash_prefix}"
  84. hits = [r for r in form_data.get("results", [])
  85. if r.get("case_id", "").startswith(want_prefix)]
  86. if len(hits) != 1:
  87. print(f"[warn] {run_id}/{folder}: 匹配到 {len(hits)} 条 result(期望 1),跳过")
  88. continue
  89. result = hits[0]
  90. # 可选:用 _source.json 的 link 校验映射没串台
  91. src_path = os.path.join(folder_path, "_source.json")
  92. if os.path.isfile(src_path):
  93. src = load_json(src_path)
  94. if src.get("link") and src["link"] != result.get("source_url"):
  95. print(f"[warn] {run_id}/{folder}: _source.link 与 result.source_url 不一致")
  96. # 以帖子为单位合并,只保留 5 个字段
  97. merged = {
  98. "query_id": run_id,
  99. "query": query,
  100. "platform": result.get("platform", platform),
  101. "post": result.get("post"),
  102. "llm_evaluation": result.get("llm_evaluation"),
  103. "workflow": workflow,
  104. }
  105. out.append((folder, merged))
  106. if not out:
  107. print(f"[skip] {run_id}: 没有可合并的 procedure")
  108. return out
  109. def write_run(run_id, runs_dir=None):
  110. entries = build_run(run_id, runs_dir=runs_dir)
  111. if not entries:
  112. return 0
  113. os.makedirs(OUT_DIR, exist_ok=True)
  114. for folder, merged in entries:
  115. out_path = os.path.join(OUT_DIR, f"{run_id}_{folder}.json")
  116. with open(out_path, "w", encoding="utf-8") as f:
  117. json.dump(merged, f, ensure_ascii=False, indent=2)
  118. print(f"[ok] {run_id}/{folder} -> {os.path.basename(out_path)}")
  119. return len(entries)
  120. def build_runs(run_ids):
  121. """对一组 run 执行 write_run,返回写出的帖子 json 总数。"""
  122. total = 0
  123. for run_id in run_ids:
  124. total += write_run(run_id)
  125. return total
  126. def all_run_ids():
  127. """runs_full 下所有 q* 目录。"""
  128. return sorted(d for d in os.listdir(RUNS_DIR)
  129. if re.match(r"^q\d+$", d)
  130. and os.path.isdir(os.path.join(RUNS_DIR, d)))
  131. # ---------- 接口:实时扫描 workflows/ 并以数组返回 ----------
  132. def scan_workflows():
  133. """实时扫描 workflows/*.json,把每个文件读成 dict,按文件名排序返回数组。
  134. 每次调用都重新读盘,所以 build 新写入的文件会立刻在接口里出现(无缓存)。"""
  135. items = []
  136. for fp in sorted(glob.glob(os.path.join(OUT_DIR, "*.json"))):
  137. try:
  138. items.append(load_json(fp))
  139. except Exception as e:
  140. print(f"[warn] 读取 {os.path.basename(fp)} 失败:{e}")
  141. return items
  142. class Handler(BaseHTTPRequestHandler):
  143. def _send(self, code, obj):
  144. body = json.dumps(obj, ensure_ascii=False).encode("utf-8")
  145. self.send_response(code)
  146. self.send_header("Content-Type", "application/json; charset=utf-8")
  147. self.send_header("Access-Control-Allow-Origin", "*")
  148. self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
  149. self.send_header("Access-Control-Allow-Headers", "Content-Type")
  150. self.send_header("Content-Length", str(len(body)))
  151. self.end_headers()
  152. self.wfile.write(body)
  153. def do_OPTIONS(self): # CORS 预检
  154. self._send(204, {})
  155. def do_GET(self):
  156. path = self.path.split("?")[0]
  157. if path in ("/", "/workflows", "/api/workflows"):
  158. self._send(200, scan_workflows())
  159. else:
  160. self._send(404, {"error": "not found"})
  161. def do_POST(self):
  162. if self.path.split("?")[0] != "/build":
  163. self._send(404, {"error": "not found"}); return
  164. length = int(self.headers.get("Content-Length") or 0)
  165. raw = self.rfile.read(length).decode("utf-8") if length > 0 else "{}"
  166. try:
  167. payload = json.loads(raw)
  168. except Exception as e:
  169. self._send(400, {"error": f"bad json: {e}"}); return
  170. if payload.get("all"):
  171. run_ids = all_run_ids()
  172. else:
  173. q = (payload.get("q") or "").strip()
  174. if not re.match(r"^q\d+$", q): # 限定 qNN 形式,避免路径注入
  175. self._send(400, {"error": f"bad q (expect 'qNN' or all=true): {q!r}"}); return
  176. run_ids = [q]
  177. try:
  178. n = build_runs(run_ids)
  179. self._send(200, {"status": "ok", "runs": run_ids, "written": n})
  180. except Exception as e:
  181. self._send(500, {"error": f"build failed: {e}"})
  182. def log_message(self, *a):
  183. pass
  184. def serve(port):
  185. n = len(scan_workflows())
  186. print(f"workflows 接口:http://0.0.0.0:{port}/workflows "
  187. f"(workflows/ 下当前 {n} 个 json,实时扫描)")
  188. ThreadingHTTPServer(("0.0.0.0", port), Handler).serve_forever()
  189. def main(argv):
  190. args = argv[1:]
  191. if args and args[0] == "serve":
  192. port = int(args[1]) if len(args) > 1 else DEFAULT_PORT
  193. serve(port)
  194. return
  195. if "--all" in args:
  196. run_ids = all_run_ids()
  197. elif args:
  198. run_ids = args
  199. else:
  200. run_ids = ["q0000"]
  201. total_files = build_runs(run_ids)
  202. print(f"\n完成:处理 {len(run_ids)} 个 run,共写出 {total_files} 个帖子 json")
  203. if __name__ == "__main__":
  204. try: # Windows 控制台默认 cp1252,中文 print 会崩,统一切 utf-8
  205. sys.stdout.reconfigure(encoding="utf-8")
  206. except Exception:
  207. pass
  208. main(sys.argv)