pg_capability_store.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. """
  2. PostgreSQL capability 存储封装
  3. 用于存储和检索原子能力数据,支持向量检索。
  4. 表名:capability(从 atomic_capability 迁移)
  5. """
  6. import os
  7. import json
  8. import psycopg2
  9. from psycopg2.extras import RealDictCursor
  10. from typing import List, Dict, Optional
  11. from dotenv import load_dotenv
  12. from knowhub.knowhub_db.cascade import cascade_delete
  13. from knowhub.knowhub_db.version_context import version_where
# Load variables through python-dotenv at import time; the store below reads
# its KNOWHUB_* connection settings with os.getenv.
load_dotenv()

# Correlated subqueries for the relation fields of a capability row. Each one
# is keyed on capability.id and yields one aliased JSON column:
#   requirement_ids - array of requirement_capability.requirement_id
#   tool_ids        - array of capability_tool.tool_id
#   implements      - object mapping tool_id -> description, restricted to
#                     capability_tool rows whose description != ''
#   knowledge_ids   - array of capability_knowledge.knowledge_id
#   knowledge_links - array of {"id": knowledge_id, "relation_type": ...}
#   resource_ids    - array of capability_resource.resource_id
# COALESCE turns "no matching rows" into '[]' (or '{}' for implements).
_REL_SUBQUERIES = """
(SELECT COALESCE(json_agg(rc.requirement_id), '[]'::json)
FROM requirement_capability rc WHERE rc.capability_id = capability.id) AS requirement_ids,
(SELECT COALESCE(json_agg(ct.tool_id), '[]'::json)
FROM capability_tool ct WHERE ct.capability_id = capability.id) AS tool_ids,
(SELECT COALESCE(
json_object_agg(ct2.tool_id, ct2.description), '{}'::json)
FROM capability_tool ct2 WHERE ct2.capability_id = capability.id AND ct2.description != '') AS implements,
(SELECT COALESCE(json_agg(ck.knowledge_id), '[]'::json)
FROM capability_knowledge ck WHERE ck.capability_id = capability.id) AS knowledge_ids,
(SELECT COALESCE(json_agg(json_build_object(
'id', ck2.knowledge_id, 'relation_type', ck2.relation_type
)), '[]'::json)
FROM capability_knowledge ck2 WHERE ck2.capability_id = capability.id) AS knowledge_links,
(SELECT COALESCE(json_agg(cr.resource_id), '[]'::json)
FROM capability_resource cr WHERE cr.capability_id = capability.id) AS resource_ids
"""
# Plain columns of the capability table returned by every read query
# (the embedding column is deliberately not part of this list).
_BASE_FIELDS = "id, name, criterion, description,version, effects"
# Full SELECT list: plain columns followed by the relation subqueries.
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
  35. def _normalize_links(data: Dict, links_key: str, ids_key: str, default_type: str):
  36. """两种输入格式统一:{links_key: [{id, relation_type}]} 或 {ids_key: [id]}"""
  37. if links_key in data and data[links_key] is not None:
  38. out = []
  39. for item in data[links_key]:
  40. if isinstance(item, dict):
  41. out.append((item['id'], item.get('relation_type', default_type)))
  42. else:
  43. out.append((item, default_type))
  44. return out
  45. if ids_key in data and data[ids_key] is not None:
  46. return [(i, default_type) for i in data[ids_key]]
  47. return None
  48. class PostgreSQLCapabilityStore:
  49. def __init__(self):
  50. """初始化 PostgreSQL 连接"""
  51. self.conn = psycopg2.connect(
  52. host=os.getenv('KNOWHUB_DB'),
  53. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  54. user=os.getenv('KNOWHUB_USER'),
  55. password=os.getenv('KNOWHUB_PASSWORD'),
  56. database=os.getenv('KNOWHUB_DB_NAME')
  57. )
  58. self.conn.autocommit = True
  59. print(f"[PostgreSQL Capability] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  60. def _reconnect(self):
  61. self.conn = psycopg2.connect(
  62. host=os.getenv('KNOWHUB_DB'),
  63. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  64. user=os.getenv('KNOWHUB_USER'),
  65. password=os.getenv('KNOWHUB_PASSWORD'),
  66. database=os.getenv('KNOWHUB_DB_NAME')
  67. )
  68. self.conn.autocommit = True
  69. def _ensure_connection(self):
  70. if self.conn.closed != 0:
  71. self._reconnect()
  72. else:
  73. try:
  74. c = self.conn.cursor()
  75. c.execute("SELECT 1")
  76. c.close()
  77. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  78. self._reconnect()
  79. def _get_cursor(self):
  80. self._ensure_connection()
  81. return self.conn.cursor(cursor_factory=RealDictCursor)
  82. def _save_relations(self, cursor, cap_id: str, data: Dict):
  83. """保存 capability 的关联表数据"""
  84. if 'requirement_ids' in data:
  85. cursor.execute("DELETE FROM requirement_capability WHERE capability_id = %s", (cap_id,))
  86. for req_id in data['requirement_ids']:
  87. cursor.execute(
  88. "INSERT INTO requirement_capability (requirement_id, capability_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  89. (req_id, cap_id))
  90. # tool_ids + implements 合并写入 capability_tool
  91. if 'tool_ids' in data or 'implements' in data:
  92. cursor.execute("DELETE FROM capability_tool WHERE capability_id = %s", (cap_id,))
  93. implements = data.get('implements', {})
  94. tool_ids = set(data.get('tool_ids', []))
  95. # 先写 tool_ids 列表中的(附带 implements 的 description)
  96. for tool_id in tool_ids:
  97. desc = implements.get(tool_id, '')
  98. cursor.execute(
  99. "INSERT INTO capability_tool (capability_id, tool_id, description) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  100. (cap_id, tool_id, desc))
  101. # 再写 implements 中有但 tool_ids 列表没有的
  102. for tool_id, desc in implements.items():
  103. if tool_id not in tool_ids:
  104. cursor.execute(
  105. "INSERT INTO capability_tool (capability_id, tool_id, description) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  106. (cap_id, tool_id, desc))
  107. k_links = _normalize_links(data, 'knowledge_links', 'knowledge_ids', 'related')
  108. if k_links is not None:
  109. cursor.execute("DELETE FROM capability_knowledge WHERE capability_id = %s", (cap_id,))
  110. for kid, rtype in k_links:
  111. cursor.execute(
  112. "INSERT INTO capability_knowledge (capability_id, knowledge_id, relation_type) "
  113. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  114. (cap_id, kid, rtype))
  115. if 'resource_ids' in data and data['resource_ids'] is not None:
  116. cursor.execute("DELETE FROM capability_resource WHERE capability_id = %s", (cap_id,))
  117. for rid in data['resource_ids']:
  118. cursor.execute(
  119. "INSERT INTO capability_resource (capability_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  120. (cap_id, rid))
  121. def insert_or_update(self, cap: Dict):
  122. """插入或更新原子能力。AnalyticDB beam 表不支持 ON CONFLICT UPDATE 当含 ALTER 新增列,改用 DELETE+INSERT。"""
  123. cursor = self._get_cursor()
  124. try:
  125. cursor.execute("DELETE FROM capability WHERE id = %s", (cap['id'],))
  126. cursor.execute("""
  127. INSERT INTO capability (
  128. id, name, criterion, description, effects, embedding, version
  129. ) VALUES (%s, %s, %s, %s, %s, %s, %s)
  130. ON CONFLICT (id) DO UPDATE SET
  131. name = EXCLUDED.name,
  132. criterion = EXCLUDED.criterion,
  133. description = EXCLUDED.description,
  134. effects = EXCLUDED.effects,
  135. embedding = EXCLUDED.embedding,
  136. version = EXCLUDED.version
  137. """, (
  138. cap['id'],
  139. cap.get('name', ''),
  140. cap.get('criterion', ''),
  141. cap.get('description', ''),
  142. json.dumps(cap.get('effects', [])),
  143. cap.get('embedding'),
  144. cap.get('version', 'v0'),
  145. ))
  146. self._save_relations(cursor, cap['id'], cap)
  147. self.conn.commit()
  148. finally:
  149. cursor.close()
  150. def get_by_id(self, cap_id: str) -> Optional[Dict]:
  151. """根据 ID 获取原子能力"""
  152. cursor = self._get_cursor()
  153. try:
  154. vf, vp = version_where()
  155. cursor.execute(f"""
  156. SELECT {_SELECT_FIELDS}
  157. FROM capability WHERE id = %s AND {vf}
  158. """, (cap_id, *vp))
  159. result = cursor.fetchone()
  160. return self._format_result(result) if result else None
  161. finally:
  162. cursor.close()
  163. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  164. """向量检索原子能力"""
  165. cursor = self._get_cursor()
  166. try:
  167. vf, vp = version_where()
  168. cursor.execute(f"""
  169. SELECT {_SELECT_FIELDS},
  170. 1 - (embedding <=> %s::real[]) as score
  171. FROM capability
  172. WHERE embedding IS NOT NULL AND {vf}
  173. ORDER BY embedding <=> %s::real[]
  174. LIMIT %s
  175. """, (query_embedding, *vp, query_embedding, limit))
  176. results = cursor.fetchall()
  177. return [self._format_result(r) for r in results]
  178. finally:
  179. cursor.close()
  180. def list_all(self, limit: int = 100, offset: int = 0) -> List[Dict]:
  181. """列出原子能力"""
  182. cursor = self._get_cursor()
  183. try:
  184. vf, vp = version_where()
  185. cursor.execute(f"""
  186. SELECT {_SELECT_FIELDS}
  187. FROM capability
  188. WHERE {vf}
  189. ORDER BY id
  190. LIMIT %s OFFSET %s
  191. """, (*vp, limit, offset))
  192. results = cursor.fetchall()
  193. return [self._format_result(r) for r in results]
  194. finally:
  195. cursor.close()
  196. def update(self, cap_id: str, updates: Dict):
  197. """更新原子能力字段"""
  198. cursor = self._get_cursor()
  199. try:
  200. # 分离关联字段
  201. rel_fields = {}
  202. for key in ('requirement_ids', 'implements', 'tool_ids',
  203. 'knowledge_ids', 'knowledge_links', 'resource_ids'):
  204. if key in updates:
  205. rel_fields[key] = updates.pop(key)
  206. if updates:
  207. set_parts = []
  208. params = []
  209. for key, value in updates.items():
  210. set_parts.append(f"{key} = %s")
  211. params.append(value)
  212. params.append(cap_id)
  213. cursor.execute(
  214. f"UPDATE capability SET {', '.join(set_parts)} WHERE id = %s",
  215. params
  216. )
  217. if rel_fields:
  218. self._save_relations(cursor, cap_id, rel_fields)
  219. self.conn.commit()
  220. finally:
  221. cursor.close()
  222. def delete(self, cap_id: str):
  223. """删除原子能力及其关联表记录"""
  224. cursor = self._get_cursor()
  225. try:
  226. cascade_delete(cursor, 'capability', cap_id)
  227. self.conn.commit()
  228. finally:
  229. cursor.close()
  230. def count(self) -> int:
  231. """统计原子能力总数"""
  232. cursor = self._get_cursor()
  233. try:
  234. vf, vp = version_where()
  235. cursor.execute(f"SELECT COUNT(*) as count FROM capability WHERE {vf}", vp)
  236. return cursor.fetchone()['count']
  237. finally:
  238. cursor.close()
  239. def _format_result(self, row: Dict) -> Dict:
  240. """格式化查询结果"""
  241. if not row:
  242. return None
  243. result = dict(row)
  244. for field in ('requirement_ids', 'tool_ids', 'knowledge_ids', 'effects'):
  245. if field in result and isinstance(result[field], str):
  246. result[field] = json.loads(result[field])
  247. elif field in result and result[field] is None:
  248. result[field] = []
  249. if 'implements' in result:
  250. if isinstance(result['implements'], str):
  251. result['implements'] = json.loads(result['implements'])
  252. elif result['implements'] is None:
  253. result['implements'] = {}
  254. return result
  255. def add_knowledge(self, cap_id: str, knowledge_id: str, relation_type: str = 'related'):
  256. """增量挂接 capability-knowledge 边"""
  257. cursor = self._get_cursor()
  258. try:
  259. cursor.execute(
  260. "INSERT INTO capability_knowledge (capability_id, knowledge_id, relation_type) "
  261. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  262. (cap_id, knowledge_id, relation_type))
  263. self.conn.commit()
  264. finally:
  265. cursor.close()
  266. def add_resource(self, cap_id: str, resource_id: str):
  267. """增量挂接 capability-resource 边"""
  268. cursor = self._get_cursor()
  269. try:
  270. cursor.execute(
  271. "INSERT INTO capability_resource (capability_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  272. (cap_id, resource_id))
  273. self.conn.commit()
  274. finally:
  275. cursor.close()
  276. def close(self):
  277. if self.conn:
  278. self.conn.close()