pg_capability_store.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. """
  2. PostgreSQL capability 存储封装
  3. 用于存储和检索原子能力数据,支持向量检索。
  4. 表名:capability(从 atomic_capability 迁移)
  5. """
import json
import os
from contextlib import contextmanager
from typing import List, Dict, Optional

import psycopg2
from dotenv import load_dotenv
from psycopg2.extras import RealDictCursor

from knowhub.knowhub_db.cascade import cascade_delete
# Load variables from a .env file into the environment so the os.getenv()
# calls made when connecting can see the KNOWHUB_* settings.
load_dotenv()

# Correlated sub-selects that attach relation data to every row selected from
# `capability`:
#   requirement_ids - JSON array of requirement_capability.requirement_id
#   tool_ids        - JSON array of capability_tool.tool_id
#   implements      - JSON object mapping tool_id -> description, built only
#                     from capability_tool rows whose description != ''
#   knowledge_ids   - JSON array of capability_knowledge.knowledge_id
# COALESCE turns "no related rows" into '[]' / '{}' instead of NULL.
# The sub-selects reference `capability.id`, so they only work in a query whose
# FROM clause exposes the table under the name `capability`.
_REL_SUBQUERIES = """
(SELECT COALESCE(json_agg(rc.requirement_id), '[]'::json)
FROM requirement_capability rc WHERE rc.capability_id = capability.id) AS requirement_ids,
(SELECT COALESCE(json_agg(ct.tool_id), '[]'::json)
FROM capability_tool ct WHERE ct.capability_id = capability.id) AS tool_ids,
(SELECT COALESCE(
json_object_agg(ct2.tool_id, ct2.description), '{}'::json)
FROM capability_tool ct2 WHERE ct2.capability_id = capability.id AND ct2.description != '') AS implements,
(SELECT COALESCE(json_agg(ck.knowledge_id), '[]'::json)
FROM capability_knowledge ck WHERE ck.capability_id = capability.id) AS knowledge_ids
"""

# Scalar columns returned by every read query (the embedding is not selected).
_BASE_FIELDS = "id, name, criterion, description"

# Full SELECT list: scalar columns followed by the relation sub-selects.
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
  28. class PostgreSQLCapabilityStore:
  29. def __init__(self):
  30. """初始化 PostgreSQL 连接"""
  31. self.conn = psycopg2.connect(
  32. host=os.getenv('KNOWHUB_DB'),
  33. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  34. user=os.getenv('KNOWHUB_USER'),
  35. password=os.getenv('KNOWHUB_PASSWORD'),
  36. database=os.getenv('KNOWHUB_DB_NAME')
  37. )
  38. self.conn.autocommit = False
  39. print(f"[PostgreSQL Capability] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  40. def _reconnect(self):
  41. self.conn = psycopg2.connect(
  42. host=os.getenv('KNOWHUB_DB'),
  43. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  44. user=os.getenv('KNOWHUB_USER'),
  45. password=os.getenv('KNOWHUB_PASSWORD'),
  46. database=os.getenv('KNOWHUB_DB_NAME')
  47. )
  48. self.conn.autocommit = False
  49. def _ensure_connection(self):
  50. if self.conn.closed != 0:
  51. self._reconnect()
  52. else:
  53. try:
  54. c = self.conn.cursor()
  55. c.execute("SELECT 1")
  56. c.close()
  57. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  58. self._reconnect()
  59. def _get_cursor(self):
  60. self._ensure_connection()
  61. return self.conn.cursor(cursor_factory=RealDictCursor)
  62. def _save_relations(self, cursor, cap_id: str, data: Dict):
  63. """保存 capability 的关联表数据"""
  64. if 'requirement_ids' in data:
  65. cursor.execute("DELETE FROM requirement_capability WHERE capability_id = %s", (cap_id,))
  66. for req_id in data['requirement_ids']:
  67. cursor.execute(
  68. "INSERT INTO requirement_capability (requirement_id, capability_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  69. (req_id, cap_id))
  70. # tool_ids + implements 合并写入 capability_tool
  71. if 'tool_ids' in data or 'implements' in data:
  72. cursor.execute("DELETE FROM capability_tool WHERE capability_id = %s", (cap_id,))
  73. implements = data.get('implements', {})
  74. tool_ids = set(data.get('tool_ids', []))
  75. # 先写 tool_ids 列表中的(附带 implements 的 description)
  76. for tool_id in tool_ids:
  77. desc = implements.get(tool_id, '')
  78. cursor.execute(
  79. "INSERT INTO capability_tool (capability_id, tool_id, description) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  80. (cap_id, tool_id, desc))
  81. # 再写 implements 中有但 tool_ids 列表没有的
  82. for tool_id, desc in implements.items():
  83. if tool_id not in tool_ids:
  84. cursor.execute(
  85. "INSERT INTO capability_tool (capability_id, tool_id, description) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  86. (cap_id, tool_id, desc))
  87. if 'knowledge_ids' in data:
  88. cursor.execute("DELETE FROM capability_knowledge WHERE capability_id = %s", (cap_id,))
  89. for kid in data['knowledge_ids']:
  90. cursor.execute(
  91. "INSERT INTO capability_knowledge (capability_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  92. (cap_id, kid))
  93. def insert_or_update(self, cap: Dict):
  94. """插入或更新原子能力"""
  95. cursor = self._get_cursor()
  96. try:
  97. cursor.execute("""
  98. INSERT INTO capability (
  99. id, name, criterion, description, embedding
  100. ) VALUES (%s, %s, %s, %s, %s)
  101. ON CONFLICT (id) DO UPDATE SET
  102. name = EXCLUDED.name,
  103. criterion = EXCLUDED.criterion,
  104. description = EXCLUDED.description,
  105. embedding = EXCLUDED.embedding
  106. """, (
  107. cap['id'],
  108. cap.get('name', ''),
  109. cap.get('criterion', ''),
  110. cap.get('description', ''),
  111. cap.get('embedding'),
  112. ))
  113. self._save_relations(cursor, cap['id'], cap)
  114. self.conn.commit()
  115. finally:
  116. cursor.close()
  117. def get_by_id(self, cap_id: str) -> Optional[Dict]:
  118. """根据 ID 获取原子能力"""
  119. cursor = self._get_cursor()
  120. try:
  121. cursor.execute(f"""
  122. SELECT {_SELECT_FIELDS}
  123. FROM capability WHERE id = %s
  124. """, (cap_id,))
  125. result = cursor.fetchone()
  126. return self._format_result(result) if result else None
  127. finally:
  128. cursor.close()
  129. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  130. """向量检索原子能力"""
  131. cursor = self._get_cursor()
  132. try:
  133. cursor.execute(f"""
  134. SELECT {_SELECT_FIELDS},
  135. 1 - (embedding <=> %s::real[]) as score
  136. FROM capability
  137. WHERE embedding IS NOT NULL
  138. ORDER BY embedding <=> %s::real[]
  139. LIMIT %s
  140. """, (query_embedding, query_embedding, limit))
  141. results = cursor.fetchall()
  142. return [self._format_result(r) for r in results]
  143. finally:
  144. cursor.close()
  145. def list_all(self, limit: int = 100, offset: int = 0) -> List[Dict]:
  146. """列出原子能力"""
  147. cursor = self._get_cursor()
  148. try:
  149. cursor.execute(f"""
  150. SELECT {_SELECT_FIELDS}
  151. FROM capability
  152. ORDER BY id
  153. LIMIT %s OFFSET %s
  154. """, (limit, offset))
  155. results = cursor.fetchall()
  156. return [self._format_result(r) for r in results]
  157. finally:
  158. cursor.close()
  159. def update(self, cap_id: str, updates: Dict):
  160. """更新原子能力字段"""
  161. cursor = self._get_cursor()
  162. try:
  163. # 分离关联字段
  164. rel_fields = {}
  165. for key in ('requirement_ids', 'implements', 'tool_ids', 'knowledge_ids'):
  166. if key in updates:
  167. rel_fields[key] = updates.pop(key)
  168. if updates:
  169. set_parts = []
  170. params = []
  171. for key, value in updates.items():
  172. set_parts.append(f"{key} = %s")
  173. params.append(value)
  174. params.append(cap_id)
  175. cursor.execute(
  176. f"UPDATE capability SET {', '.join(set_parts)} WHERE id = %s",
  177. params
  178. )
  179. if rel_fields:
  180. self._save_relations(cursor, cap_id, rel_fields)
  181. self.conn.commit()
  182. finally:
  183. cursor.close()
  184. def delete(self, cap_id: str):
  185. """删除原子能力及其关联表记录"""
  186. cursor = self._get_cursor()
  187. try:
  188. cascade_delete(cursor, 'capability', cap_id)
  189. self.conn.commit()
  190. finally:
  191. cursor.close()
  192. def count(self) -> int:
  193. """统计原子能力总数"""
  194. cursor = self._get_cursor()
  195. try:
  196. cursor.execute("SELECT COUNT(*) as count FROM capability")
  197. return cursor.fetchone()['count']
  198. finally:
  199. cursor.close()
  200. def _format_result(self, row: Dict) -> Dict:
  201. """格式化查询结果"""
  202. if not row:
  203. return None
  204. result = dict(row)
  205. for field in ('requirement_ids', 'tool_ids', 'knowledge_ids'):
  206. if field in result and isinstance(result[field], str):
  207. result[field] = json.loads(result[field])
  208. elif field in result and result[field] is None:
  209. result[field] = []
  210. if 'implements' in result:
  211. if isinstance(result['implements'], str):
  212. result['implements'] = json.loads(result['implements'])
  213. elif result['implements'] is None:
  214. result['implements'] = {}
  215. return result
  216. def close(self):
  217. if self.conn:
  218. self.conn.close()