reclassify_tool_knowledge.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
#!/usr/bin/env python3
"""
Repair script: move knowledge ids stored in ``tool_knowledge`` into the correct
field according to each knowledge item's ``types`` tags.

Mapping rules:
- types contains "plan"    -> process_knowledge (process knowledge)
- types contains "usecase" -> case_knowledge (use-case knowledge)
- anything else            -> stays in tool_knowledge
If an item carries both "plan" and "usecase", "plan" takes precedence.

Run with ``--dry-run`` to preview the changes without writing them.
"""
  9. import os
  10. import sys
  11. import json
  12. import psycopg2
  13. from psycopg2.extras import RealDictCursor
  14. from dotenv import load_dotenv
  15. _script_dir = os.path.dirname(os.path.abspath(__file__))
  16. _project_root = os.path.normpath(os.path.join(_script_dir, '..', '..'))
  17. load_dotenv(os.path.join(_project_root, '.env'))
  18. def get_connection():
  19. host = os.getenv('KNOWHUB_DB')
  20. port = int(os.getenv('KNOWHUB_PORT', 5432))
  21. user = os.getenv('KNOWHUB_USER')
  22. password = os.getenv('KNOWHUB_PASSWORD')
  23. dbname = os.getenv('KNOWHUB_DB_NAME')
  24. print(f"连接到 {host}:{port}/{dbname} as {user} ...")
  25. conn = psycopg2.connect(
  26. host=host, port=port, user=user,
  27. password=password, database=dbname, connect_timeout=10
  28. )
  29. conn.autocommit = False
  30. print("连接成功。\n")
  31. return conn
  32. def parse_jsonb(val):
  33. """安全解析 JSONB 字段"""
  34. if val is None:
  35. return []
  36. if isinstance(val, list):
  37. return val
  38. if isinstance(val, str):
  39. try:
  40. return json.loads(val)
  41. except json.JSONDecodeError:
  42. return []
  43. return []
  44. def main():
  45. dry_run = '--dry-run' in sys.argv
  46. if dry_run:
  47. print("=== DRY RUN 模式(不会实际修改数据)===\n")
  48. conn = get_connection()
  49. cursor = conn.cursor(cursor_factory=RealDictCursor)
  50. # 1. 获取所有有 tool_knowledge 的工具
  51. cursor.execute("""
  52. SELECT id, tool_knowledge, case_knowledge, process_knowledge
  53. FROM tool_table
  54. WHERE tool_knowledge IS NOT NULL AND tool_knowledge != '[]'::jsonb
  55. """)
  56. tools = cursor.fetchall()
  57. print(f"找到 {len(tools)} 个有关联知识的工具\n")
  58. if not tools:
  59. print("没有需要处理的工具,退出。")
  60. cursor.close()
  61. conn.close()
  62. return
  63. # 2. 收集所有涉及的知识 ID
  64. all_knowledge_ids = set()
  65. for tool in tools:
  66. all_knowledge_ids.update(parse_jsonb(tool['tool_knowledge']))
  67. print(f"涉及 {len(all_knowledge_ids)} 条知识,正在查询类型...\n")
  68. # 3. 批量查询知识的 types
  69. knowledge_types = {}
  70. if all_knowledge_ids:
  71. id_list = list(all_knowledge_ids)
  72. # 分批查询,避免参数过多
  73. batch_size = 100
  74. for i in range(0, len(id_list), batch_size):
  75. batch = id_list[i:i + batch_size]
  76. placeholders = ','.join(['%s'] * len(batch))
  77. cursor.execute(f"""
  78. SELECT id, types FROM knowledge
  79. WHERE id IN ({placeholders})
  80. """, batch)
  81. for row in cursor.fetchall():
  82. types_val = row['types']
  83. if isinstance(types_val, str):
  84. try:
  85. types_val = json.loads(types_val)
  86. except json.JSONDecodeError:
  87. types_val = []
  88. knowledge_types[row['id']] = types_val or []
  89. print(f"成功查询到 {len(knowledge_types)} 条知识的类型信息\n")
  90. # 统计
  91. plan_count = sum(1 for t in knowledge_types.values() if 'plan' in t)
  92. usecase_count = sum(1 for t in knowledge_types.values() if 'usecase' in t)
  93. print(f"类型分布:plan={plan_count}, usecase={usecase_count}, "
  94. f"其他={len(knowledge_types) - plan_count - usecase_count}\n")
  95. # 4. 重新分类每个工具的知识
  96. updated_count = 0
  97. for tool in tools:
  98. tool_id = tool['id']
  99. old_tool_knowledge = parse_jsonb(tool['tool_knowledge'])
  100. old_case_knowledge = parse_jsonb(tool['case_knowledge'])
  101. old_process_knowledge = parse_jsonb(tool['process_knowledge'])
  102. new_tool_knowledge = []
  103. new_case_knowledge = list(old_case_knowledge) # 保留已有的
  104. new_process_knowledge = list(old_process_knowledge) # 保留已有的
  105. for kid in old_tool_knowledge:
  106. types = knowledge_types.get(kid, [])
  107. if 'plan' in types:
  108. if kid not in new_process_knowledge:
  109. new_process_knowledge.append(kid)
  110. elif 'usecase' in types:
  111. if kid not in new_case_knowledge:
  112. new_case_knowledge.append(kid)
  113. else:
  114. new_tool_knowledge.append(kid)
  115. # 检查是否有变化
  116. changed = (
  117. set(new_tool_knowledge) != set(old_tool_knowledge) or
  118. set(new_case_knowledge) != set(old_case_knowledge) or
  119. set(new_process_knowledge) != set(old_process_knowledge)
  120. )
  121. if changed:
  122. moved_to_case = [k for k in old_tool_knowledge if k in new_case_knowledge and k not in old_case_knowledge]
  123. moved_to_process = [k for k in old_tool_knowledge if k in new_process_knowledge and k not in old_process_knowledge]
  124. print(f"工具: {tool_id}")
  125. if moved_to_process:
  126. print(f" → process_knowledge (plan): +{len(moved_to_process)} 条")
  127. for kid in moved_to_process:
  128. print(f" {kid}")
  129. if moved_to_case:
  130. print(f" → case_knowledge (usecase): +{len(moved_to_case)} 条")
  131. for kid in moved_to_case:
  132. print(f" {kid}")
  133. print(f" tool_knowledge: {len(old_tool_knowledge)} → {len(new_tool_knowledge)}")
  134. print()
  135. if not dry_run:
  136. cursor.execute("""
  137. UPDATE tool_table
  138. SET tool_knowledge = %s,
  139. case_knowledge = %s,
  140. process_knowledge = %s
  141. WHERE id = %s
  142. """, (
  143. json.dumps(new_tool_knowledge),
  144. json.dumps(new_case_knowledge),
  145. json.dumps(new_process_knowledge),
  146. tool_id
  147. ))
  148. updated_count += 1
  149. if not dry_run and updated_count > 0:
  150. conn.commit()
  151. print("=" * 60)
  152. print(f"处理完成:共 {len(tools)} 个工具,{updated_count} 个需要更新")
  153. if dry_run:
  154. print("(DRY RUN 模式,未实际修改数据)")
  155. print("确认无误后,去掉 --dry-run 参数重新运行即可。")
  156. else:
  157. print(f"已成功更新 {updated_count} 个工具的知识分类。")
  158. cursor.close()
  159. conn.close()
  160. if __name__ == '__main__':
  161. print("知识分类修复工具")
  162. print("用法:")
  163. print(" python reclassify_tool_knowledge.py --dry-run # 预览变更")
  164. print(" python reclassify_tool_knowledge.py # 执行变更")
  165. print()
  166. main()