cluster_strategies.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. #!/usr/bin/env python3
  2. """
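Phase 3: cluster the concrete strategies (193 at the time of writing) to prepare them for abstraction.

Features (one record per requirement-strategy link):
- cap_signature (`caps`): set of canonical capability_ids, read from the strategy_capability table (aliases already resolved)
- phase_count (`phases`): number of entries in body.workflow_outline
- tool_set (`tools`): tool IDs from workflow_outline[*].capabilities[*].implements (dict keys, or the string items of a list)
- name_bigrams: character 2-grams of the strategy name (whitespace removed)

Clustering (multi-signal combination):
1. Build a pairwise composite similarity:
   composite = 0.55 * cap_jaccard (IDF-weighted, so caps shared by most strategies count less)
             + 0.25 * tool_jaccard
             + 0.15 * name_bigram_jaccard
             + 0.05 * phase_count_proximity
   (weights: CAP_WEIGHT / TOOL_WEIGHT / NAME_WEIGHT / PHASE_WEIGHT)
2. Candidate edges: composite >= 0.40 (COMPOSITE_THRESHOLD; coarse, recall-oriented, kept for reference only).
3. Strong edges: composite >= 0.55 (STRONG_LINK_THRESHOLD). Candidate clusters are the connected
   components (transitive closure) over strong edges only, so chains of weak links do not merge
   unrelated strategies.

Writes /tmp/strategy_clusters.md containing:
- for each cluster, its N strategies (req_id / name / is_selected / coverage_score / phases)
- majority caps of the cluster (held by at least 50% of the members, and by at least 2)
- common tools in the cluster
- common name 2-grams in the cluster
- suggested abstract name + judgment hints (not generated by this script yet)
- the singleton strategies that found no strong partner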
  23. """
  24. import json
  25. import math
  26. import re
  27. import sys
  28. from collections import defaultdict, Counter
  29. from itertools import combinations
  30. from pathlib import Path
  31. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  32. from knowhub.knowhub_db.pg_capability_store import PostgreSQLCapabilityStore
  33. CAP_WEIGHT = 0.55
  34. TOOL_WEIGHT = 0.25
  35. NAME_WEIGHT = 0.15
  36. PHASE_WEIGHT = 0.05
  37. COMPOSITE_THRESHOLD = 0.40 # 粗筛
  38. STRONG_LINK_THRESHOLD = 0.55 # 传递闭包只走强边
  39. def norm(s):
  40. return re.sub(r'\s+', '', (s or '').strip().lower())
  41. def ch_bigrams(s):
  42. s = re.sub(r'\s+', '', s or '')
  43. return set(s[i:i+2] for i in range(len(s) - 1))
  44. def jaccard(a, b):
  45. if not a or not b: return 0
  46. return len(a & b) / len(a | b)
  47. def phase_proximity(a, b):
  48. """Phase 数量差越小越相似"""
  49. if a == 0 and b == 0: return 0 # both empty = no signal
  50. d = abs(a - b)
  51. if d == 0: return 1.0
  52. if d == 1: return 0.7
  53. if d == 2: return 0.4
  54. return 0.1
  55. def extract_strategy_features(cur):
  56. """Return list of {id, req_id, name, is_selected, coverage_score, caps, tools, phases}.
  57. caps 从 strategy_capability 表读取(已经过 alias 解析,准确);
  58. tools + phases 从 body.workflow_outline 提取。
  59. """
  60. db_caps = {}
  61. cur.execute('SELECT id, name FROM capability')
  62. for r in cur.fetchall(): db_caps[r['id']] = r['name']
  63. # strategy_id → set of cap_ids
  64. strat_caps_map = defaultdict(set)
  65. cur.execute('SELECT strategy_id, capability_id FROM strategy_capability')
  66. for r in cur.fetchall():
  67. strat_caps_map[r['strategy_id']].add(r['capability_id'])
  68. cur.execute("""
  69. SELECT s.id, s.name, s.body,
  70. rs.requirement_id, rs.is_selected, rs.coverage_score
  71. FROM strategy s
  72. JOIN requirement_strategy rs ON rs.strategy_id=s.id
  73. ORDER BY rs.requirement_id, rs.is_selected DESC""")
  74. rows = cur.fetchall()
  75. features = []
  76. for row in rows:
  77. body = row['body'] if isinstance(row['body'], dict) else json.loads(row['body'] or '{}')
  78. wo = body.get('workflow_outline') or []
  79. tools = set(); phases = 0
  80. if isinstance(wo, list):
  81. phases = len(wo)
  82. for ph in wo:
  83. if not isinstance(ph, dict): continue
  84. for c in ph.get('capabilities', []) or []:
  85. if isinstance(c, dict):
  86. impl = c.get('implements')
  87. if isinstance(impl, dict):
  88. for t in impl.keys(): tools.add(t)
  89. elif isinstance(impl, list):
  90. for t in impl:
  91. if isinstance(t, str): tools.add(t)
  92. features.append({
  93. 'id': row['id'],
  94. 'req_id': row['requirement_id'],
  95. 'name': row['name'],
  96. 'is_selected': row['is_selected'],
  97. 'coverage_score': row['coverage_score'],
  98. 'caps': strat_caps_map[row['id']], # 从 junction 读,准确
  99. 'tools': tools,
  100. 'phases': phases,
  101. 'name_bigrams': ch_bigrams(row['name']),
  102. })
  103. return features, db_caps
  104. def compute_cap_idf(features):
  105. """log(N / df(cap)) — rare caps get higher weight."""
  106. N = len(features)
  107. df = Counter()
  108. for s in features:
  109. for c in s['caps']: df[c] += 1
  110. idf = {c: math.log(N / n) for c, n in df.items()}
  111. return idf, df
  112. def weighted_jaccard(a, b, idf):
  113. """Sum of IDF weights over intersection / union."""
  114. inter = a & b; union = a | b
  115. if not union: return 0
  116. w_inter = sum(idf.get(c, 0) for c in inter)
  117. w_union = sum(idf.get(c, 0) for c in union)
  118. return w_inter / w_union if w_union > 0 else 0
  119. def composite_sim(a, b, cap_idf):
  120. """Weighted cap similarity with IDF + other signals."""
  121. cap_j = weighted_jaccard(a['caps'], b['caps'], cap_idf)
  122. tool_j = jaccard(a['tools'], b['tools'])
  123. name_j = jaccard(a['name_bigrams'], b['name_bigrams'])
  124. phase_p = phase_proximity(a['phases'], b['phases'])
  125. sim = (CAP_WEIGHT * cap_j + TOOL_WEIGHT * tool_j +
  126. NAME_WEIGHT * name_j + PHASE_WEIGHT * phase_p)
  127. return sim, cap_j, tool_j, name_j, phase_p
  128. def cluster_with_strong_links(features, candidate_threshold, strong_threshold, cap_idf):
  129. """
  130. 两层:候选边(做参考)+ 强边(做簇)。
  131. 只用强边做传递闭包,避免弱边链条化。
  132. """
  133. n = len(features)
  134. strong_adj = defaultdict(set)
  135. all_edges = []
  136. for i, j in combinations(range(n), 2):
  137. sim, cap_j, tool_j, name_j, phase_p = composite_sim(features[i], features[j], cap_idf)
  138. if sim >= candidate_threshold:
  139. all_edges.append((i, j, sim, cap_j, tool_j, name_j, phase_p))
  140. if sim >= strong_threshold:
  141. strong_adj[i].add(j); strong_adj[j].add(i)
  142. visited = set()
  143. clusters = []
  144. for i in range(n):
  145. if i in visited: continue
  146. if i not in strong_adj:
  147. clusters.append([i]); visited.add(i); continue
  148. comp = []; stack = [i]
  149. while stack:
  150. x = stack.pop()
  151. if x in visited: continue
  152. visited.add(x); comp.append(x)
  153. for y in strong_adj[x]:
  154. if y not in visited: stack.append(y)
  155. clusters.append(comp)
  156. return clusters, all_edges
  157. def summarize_cluster(idx, members, features, db_caps, out):
  158. """Write cluster summary to markdown file out."""
  159. strats = [features[i] for i in members]
  160. out.write(f'\n## 簇 {idx}({len(strats)} 条 strategies)\n\n')
  161. # Stats
  162. cap_count = Counter()
  163. for s in strats:
  164. for c in s['caps']: cap_count[c] += 1
  165. tool_count = Counter()
  166. for s in strats:
  167. for t in s['tools']: tool_count[t] += 1
  168. name_bigram_count = Counter()
  169. for s in strats:
  170. for g in s['name_bigrams']: name_bigram_count[g] += 1
  171. total = len(strats)
  172. threshold_majority = max(2, total * 0.5)
  173. majority_caps = [(c, n) for c, n in cap_count.most_common() if n >= threshold_majority]
  174. common_tools = [(t, n) for t, n in tool_count.most_common(8)]
  175. common_bigrams = [(g, n) for g, n in name_bigram_count.most_common(6) if n >= 2]
  176. # Members table
  177. out.write('| req | is_sel | cov | phases | name |\n')
  178. out.write('|---|---|---|---|---|\n')
  179. for s in strats:
  180. cov = f'{s["coverage_score"]:.2f}' if s['coverage_score'] is not None else '—'
  181. is_sel = '✓' if s['is_selected'] else ' '
  182. out.write(f'| {s["req_id"]} | {is_sel} | {cov} | {s["phases"]} | {s["name"][:60]} |\n')
  183. out.write(f'\n**Majority caps**(≥{threshold_majority:.0f}/{total}成员持有):\n')
  184. for cid, n in majority_caps[:15]:
  185. out.write(f'- {n}/{total}: `{cid}` {db_caps.get(cid, "?")}\n')
  186. if common_tools:
  187. out.write(f'\n**Common tools**:\n')
  188. for t, n in common_tools:
  189. out.write(f'- {n}/{total}: `{t}`\n')
  190. if common_bigrams:
  191. out.write(f'\n**Common name bigrams** (≥2): ')
  192. out.write(' '.join(f'`{g}`({n})' for g, n in common_bigrams) + '\n')
  193. def main():
  194. s = PostgreSQLCapabilityStore()
  195. cur = s._get_cursor()
  196. try:
  197. features, db_caps = extract_strategy_features(cur)
  198. print(f'Total strategies: {len(features)}', flush=True)
  199. cap_idf, cap_df = compute_cap_idf(features)
  200. # Print top non-distinctive caps (high df)
  201. print('\nTop 10 most common caps (low IDF):')
  202. for c, df in sorted(cap_df.items(), key=lambda x: -x[1])[:10]:
  203. print(f' df={df:3d} idf={cap_idf[c]:.2f} {c} {db_caps.get(c,"?")[:40]}')
  204. clusters, edges = cluster_with_strong_links(
  205. features, COMPOSITE_THRESHOLD, STRONG_LINK_THRESHOLD, cap_idf)
  206. # Sort clusters by size desc
  207. clusters.sort(key=lambda c: -len(c))
  208. multi = [c for c in clusters if len(c) >= 2]
  209. singleton = [c for c in clusters if len(c) == 1]
  210. print(f'Clusters: {len(multi)} multi-strategy + {len(singleton)} singletons', flush=True)
  211. out_path = Path('/tmp/strategy_clusters.md')
  212. with out_path.open('w') as out:
  213. out.write(f'# Strategy Clusters (composite >= {COMPOSITE_THRESHOLD})\n\n')
  214. out.write(f'Total strategies: {len(features)}\n')
  215. out.write(f'Multi-member clusters: {len(multi)}(覆盖 {sum(len(c) for c in multi)} strategies)\n')
  216. out.write(f'Singletons: {len(singleton)}\n\n')
  217. out.write(f'Weights: cap={CAP_WEIGHT}, tool={TOOL_WEIGHT}, name={NAME_WEIGHT}, phase={PHASE_WEIGHT}\n\n')
  218. out.write('---\n\n# Multi-member clusters\n')
  219. for idx, members in enumerate(multi, 1):
  220. summarize_cluster(idx, members, features, db_caps, out)
  221. out.write('\n\n---\n\n# Singletons(独立 strategy)\n\n')
  222. out.write(f'{len(singleton)} strategies 没找到相似伙伴。\n\n')
  223. for c in singleton:
  224. s_ = features[c[0]]
  225. cov = f'{s_["coverage_score"]:.2f}' if s_['coverage_score'] is not None else '—'
  226. out.write(f'- [{s_["req_id"]}] caps={len(s_["caps"])}, phases={s_["phases"]}, cov={cov}: {s_["name"][:60]}\n')
  227. print(f'Written: {out_path}', flush=True)
  228. # Print summary
  229. print('\n== Multi-member clusters ==')
  230. for idx, members in enumerate(multi[:20], 1):
  231. req_ids = [features[i]['req_id'] for i in members]
  232. print(f' [{idx}] {len(members)} members: {", ".join(sorted(set(req_ids)))}')
  233. if len(multi) > 20: print(f' ... ({len(multi)-20} more)')
  234. finally:
  235. cur.close()
  236. s.close()
  237. if __name__ == '__main__':
  238. main()