pg_tool_store.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. """
  2. PostgreSQL tool 存储封装
  3. 用于存储和检索工具数据,支持向量检索。
  4. 表名:tool(从 tool_table 迁移)
  5. """
import json
import os
from contextlib import contextmanager
from typing import Dict, List, Optional

import psycopg2
from dotenv import load_dotenv
from psycopg2.extras import RealDictCursor

from knowhub.knowhub_db.cascade import cascade_delete

load_dotenv()
  14. # 关联字段子查询。knowledge 边暴露两种视图:knowledge_ids(扁平)+ knowledge_links(含 type)
  15. _REL_SUBQUERIES = """
  16. (SELECT COALESCE(json_agg(ct.capability_id), '[]'::json)
  17. FROM capability_tool ct WHERE ct.tool_id = tool.id) AS capability_ids,
  18. (SELECT COALESCE(json_agg(tk.knowledge_id), '[]'::json)
  19. FROM tool_knowledge tk WHERE tk.tool_id = tool.id) AS knowledge_ids,
  20. (SELECT COALESCE(json_agg(json_build_object(
  21. 'id', tk2.knowledge_id, 'relation_type', tk2.relation_type
  22. )), '[]'::json)
  23. FROM tool_knowledge tk2 WHERE tk2.tool_id = tool.id) AS knowledge_links,
  24. (SELECT COALESCE(json_agg(tp.provider_id), '[]'::json)
  25. FROM tool_provider tp WHERE tp.tool_id = tool.id) AS provider_ids
  26. """
  27. _BASE_FIELDS = "id, name, version, introduction, tutorial, input, output, updated_time, status"
  28. _SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
  29. def _normalize_links(data: Dict, links_key: str, ids_key: str, default_type: str):
  30. """两种输入格式统一:{links_key: [{id, relation_type}]} 或 {ids_key: [id]}"""
  31. if links_key in data and data[links_key] is not None:
  32. out = []
  33. for item in data[links_key]:
  34. if isinstance(item, dict):
  35. out.append((item['id'], item.get('relation_type', default_type)))
  36. else:
  37. out.append((item, default_type))
  38. return out
  39. if ids_key in data and data[ids_key] is not None:
  40. return [(i, default_type) for i in data[ids_key]]
  41. return None
  42. class PostgreSQLToolStore:
  43. def __init__(self):
  44. """初始化 PostgreSQL 连接"""
  45. self.conn = psycopg2.connect(
  46. host=os.getenv('KNOWHUB_DB'),
  47. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  48. user=os.getenv('KNOWHUB_USER'),
  49. password=os.getenv('KNOWHUB_PASSWORD'),
  50. database=os.getenv('KNOWHUB_DB_NAME')
  51. )
  52. self.conn.autocommit = True
  53. print(f"[PostgreSQL Tool] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  54. def _reconnect(self):
  55. self.conn = psycopg2.connect(
  56. host=os.getenv('KNOWHUB_DB'),
  57. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  58. user=os.getenv('KNOWHUB_USER'),
  59. password=os.getenv('KNOWHUB_PASSWORD'),
  60. database=os.getenv('KNOWHUB_DB_NAME')
  61. )
  62. self.conn.autocommit = True
  63. def _ensure_connection(self):
  64. if self.conn.closed != 0:
  65. self._reconnect()
  66. else:
  67. try:
  68. c = self.conn.cursor()
  69. c.execute("SELECT 1")
  70. c.close()
  71. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  72. self._reconnect()
  73. def _get_cursor(self):
  74. self._ensure_connection()
  75. return self.conn.cursor(cursor_factory=RealDictCursor)
  76. def insert_or_update(self, tool: Dict):
  77. """插入或更新工具"""
  78. cursor = self._get_cursor()
  79. try:
  80. cursor.execute("""
  81. INSERT INTO tool (
  82. id, name, version, introduction, tutorial, input, output,
  83. updated_time, status, embedding
  84. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  85. ON CONFLICT (id) DO UPDATE SET
  86. name = EXCLUDED.name,
  87. version = EXCLUDED.version,
  88. introduction = EXCLUDED.introduction,
  89. tutorial = EXCLUDED.tutorial,
  90. input = EXCLUDED.input,
  91. output = EXCLUDED.output,
  92. updated_time = EXCLUDED.updated_time,
  93. status = EXCLUDED.status,
  94. embedding = EXCLUDED.embedding
  95. """, (
  96. tool['id'],
  97. tool.get('name', ''),
  98. tool.get('version'),
  99. tool.get('introduction', ''),
  100. tool.get('tutorial', ''),
  101. json.dumps(tool.get('input', '')),
  102. json.dumps(tool.get('output', '')),
  103. tool.get('updated_time', 0),
  104. tool.get('status', '未接入'),
  105. tool.get('embedding'),
  106. ))
  107. # 写入关联表
  108. tool_id = tool['id']
  109. if 'capability_ids' in tool:
  110. cursor.execute("DELETE FROM capability_tool WHERE tool_id = %s", (tool_id,))
  111. for cap_id in tool['capability_ids']:
  112. cursor.execute(
  113. "INSERT INTO capability_tool (capability_id, tool_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  114. (cap_id, tool_id))
  115. k_links = _normalize_links(tool, 'knowledge_links', 'knowledge_ids', 'related')
  116. if k_links is not None:
  117. cursor.execute("DELETE FROM tool_knowledge WHERE tool_id = %s", (tool_id,))
  118. for kid, rtype in k_links:
  119. cursor.execute(
  120. "INSERT INTO tool_knowledge (tool_id, knowledge_id, relation_type) "
  121. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  122. (tool_id, kid, rtype))
  123. if 'provider_ids' in tool:
  124. cursor.execute("DELETE FROM tool_provider WHERE tool_id = %s", (tool_id,))
  125. for pid in tool['provider_ids']:
  126. cursor.execute(
  127. "INSERT INTO tool_provider (tool_id, provider_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  128. (tool_id, pid))
  129. self.conn.commit()
  130. finally:
  131. cursor.close()
  132. def get_by_id(self, tool_id: str) -> Optional[Dict]:
  133. """根据 ID 获取工具"""
  134. cursor = self._get_cursor()
  135. try:
  136. cursor.execute(f"""
  137. SELECT {_SELECT_FIELDS}
  138. FROM tool WHERE id = %s
  139. """, (tool_id,))
  140. result = cursor.fetchone()
  141. return self._format_result(result) if result else None
  142. finally:
  143. cursor.close()
  144. def search(self, query_embedding: List[float], limit: int = 10, status: Optional[str] = None) -> List[Dict]:
  145. """向量检索工具"""
  146. cursor = self._get_cursor()
  147. try:
  148. if status:
  149. sql = f"""
  150. SELECT {_SELECT_FIELDS},
  151. 1 - (embedding <=> %s::real[]) as score
  152. FROM tool
  153. WHERE embedding IS NOT NULL AND status = %s
  154. ORDER BY embedding <=> %s::real[]
  155. LIMIT %s
  156. """
  157. params = [query_embedding, status, query_embedding, limit]
  158. else:
  159. sql = f"""
  160. SELECT {_SELECT_FIELDS},
  161. 1 - (embedding <=> %s::real[]) as score
  162. FROM tool
  163. WHERE embedding IS NOT NULL
  164. ORDER BY embedding <=> %s::real[]
  165. LIMIT %s
  166. """
  167. params = [query_embedding, query_embedding, limit]
  168. cursor.execute(sql, params)
  169. results = cursor.fetchall()
  170. return [self._format_result(r) for r in results]
  171. finally:
  172. cursor.close()
  173. def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
  174. """列出工具"""
  175. cursor = self._get_cursor()
  176. try:
  177. if status:
  178. cursor.execute(f"""
  179. SELECT {_SELECT_FIELDS}
  180. FROM tool
  181. WHERE status = %s
  182. ORDER BY updated_time DESC
  183. LIMIT %s OFFSET %s
  184. """, (status, limit, offset))
  185. else:
  186. cursor.execute(f"""
  187. SELECT {_SELECT_FIELDS}
  188. FROM tool
  189. ORDER BY updated_time DESC
  190. LIMIT %s OFFSET %s
  191. """, (limit, offset))
  192. results = cursor.fetchall()
  193. return [self._format_result(r) for r in results]
  194. finally:
  195. cursor.close()
  196. def update(self, tool_id: str, updates: Dict):
  197. """更新工具字段"""
  198. cursor = self._get_cursor()
  199. try:
  200. # 分离关联字段
  201. cap_ids = updates.pop('capability_ids', None)
  202. knowledge_ids = updates.pop('knowledge_ids', None)
  203. knowledge_links = updates.pop('knowledge_links', None)
  204. provider_ids = updates.pop('provider_ids', None)
  205. if updates:
  206. set_parts = []
  207. params = []
  208. json_fields = ('input', 'output')
  209. for key, value in updates.items():
  210. set_parts.append(f"{key} = %s")
  211. if key in json_fields:
  212. params.append(json.dumps(value))
  213. else:
  214. params.append(value)
  215. params.append(tool_id)
  216. cursor.execute(
  217. f"UPDATE tool SET {', '.join(set_parts)} WHERE id = %s",
  218. params
  219. )
  220. # 更新关联表
  221. if cap_ids is not None:
  222. cursor.execute("DELETE FROM capability_tool WHERE tool_id = %s", (tool_id,))
  223. for cap_id in cap_ids:
  224. cursor.execute(
  225. "INSERT INTO capability_tool (capability_id, tool_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  226. (cap_id, tool_id))
  227. k_links = _normalize_links(
  228. {'knowledge_links': knowledge_links, 'knowledge_ids': knowledge_ids},
  229. 'knowledge_links', 'knowledge_ids', 'related')
  230. if k_links is not None:
  231. cursor.execute("DELETE FROM tool_knowledge WHERE tool_id = %s", (tool_id,))
  232. for kid, rtype in k_links:
  233. cursor.execute(
  234. "INSERT INTO tool_knowledge (tool_id, knowledge_id, relation_type) "
  235. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  236. (tool_id, kid, rtype))
  237. if provider_ids is not None:
  238. cursor.execute("DELETE FROM tool_provider WHERE tool_id = %s", (tool_id,))
  239. for pid in provider_ids:
  240. cursor.execute(
  241. "INSERT INTO tool_provider (tool_id, provider_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  242. (tool_id, pid))
  243. self.conn.commit()
  244. finally:
  245. cursor.close()
  246. def add_knowledge(self, tool_id: str, knowledge_id: str, relation_type: str = 'related'):
  247. """向工具添加一条知识关联(不删除已有关联)"""
  248. cursor = self._get_cursor()
  249. try:
  250. cursor.execute(
  251. "INSERT INTO tool_knowledge (tool_id, knowledge_id, relation_type) "
  252. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  253. (tool_id, knowledge_id, relation_type))
  254. self.conn.commit()
  255. finally:
  256. cursor.close()
  257. def delete(self, tool_id: str):
  258. """删除工具及其关联表记录"""
  259. cursor = self._get_cursor()
  260. try:
  261. cascade_delete(cursor, 'tool', tool_id)
  262. self.conn.commit()
  263. finally:
  264. cursor.close()
  265. def count(self, status: Optional[str] = None) -> int:
  266. """统计工具总数"""
  267. cursor = self._get_cursor()
  268. try:
  269. if status:
  270. cursor.execute("SELECT COUNT(*) as count FROM tool WHERE status = %s", (status,))
  271. else:
  272. cursor.execute("SELECT COUNT(*) as count FROM tool")
  273. return cursor.fetchone()['count']
  274. finally:
  275. cursor.close()
  276. def _format_result(self, row: Dict) -> Dict:
  277. """格式化查询结果,将 JSON 字符串解析为对象"""
  278. if not row:
  279. return None
  280. result = dict(row)
  281. for field in ('input', 'output'):
  282. if field in result and isinstance(result[field], str):
  283. try:
  284. result[field] = json.loads(result[field]) if result[field].strip() else None
  285. except json.JSONDecodeError:
  286. result[field] = None
  287. # 关联字段(来自 junction table 子查询)
  288. for field in ('capability_ids', 'knowledge_ids', 'provider_ids', 'knowledge_links'):
  289. if field in result and isinstance(result[field], str):
  290. result[field] = json.loads(result[field])
  291. elif field in result and result[field] is None:
  292. result[field] = []
  293. return result
  294. def close(self):
  295. if self.conn:
  296. self.conn.close()