pg_tool_store.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
"""
PostgreSQL tool storage wrapper.
Stores and retrieves tool data, with support for vector search.
Table name: tool (migrated from tool_table).
"""
  6. import os
  7. import json
  8. import psycopg2
  9. from psycopg2.extras import RealDictCursor
  10. from typing import List, Dict, Optional
  11. from dotenv import load_dotenv
  12. from knowhub.knowhub_db.cascade import cascade_delete
  13. load_dotenv()
  14. # 关联字段子查询
  15. _REL_SUBQUERIES = """
  16. (SELECT COALESCE(json_agg(ct.capability_id), '[]'::json)
  17. FROM capability_tool ct WHERE ct.tool_id = tool.id) AS capability_ids,
  18. (SELECT COALESCE(json_agg(tk.knowledge_id), '[]'::json)
  19. FROM tool_knowledge tk WHERE tk.tool_id = tool.id) AS knowledge_ids,
  20. (SELECT COALESCE(json_agg(tp.provider_id), '[]'::json)
  21. FROM tool_provider tp WHERE tp.tool_id = tool.id) AS provider_ids
  22. """
  23. _BASE_FIELDS = "id, name, version, introduction, tutorial, input, output, updated_time, status"
  24. _SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
  25. class PostgreSQLToolStore:
  26. def __init__(self):
  27. """初始化 PostgreSQL 连接"""
  28. self.conn = psycopg2.connect(
  29. host=os.getenv('KNOWHUB_DB'),
  30. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  31. user=os.getenv('KNOWHUB_USER'),
  32. password=os.getenv('KNOWHUB_PASSWORD'),
  33. database=os.getenv('KNOWHUB_DB_NAME')
  34. )
  35. self.conn.autocommit = False
  36. print(f"[PostgreSQL Tool] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  37. def _reconnect(self):
  38. self.conn = psycopg2.connect(
  39. host=os.getenv('KNOWHUB_DB'),
  40. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  41. user=os.getenv('KNOWHUB_USER'),
  42. password=os.getenv('KNOWHUB_PASSWORD'),
  43. database=os.getenv('KNOWHUB_DB_NAME')
  44. )
  45. self.conn.autocommit = False
  46. def _ensure_connection(self):
  47. if self.conn.closed != 0:
  48. self._reconnect()
  49. else:
  50. try:
  51. c = self.conn.cursor()
  52. c.execute("SELECT 1")
  53. c.close()
  54. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  55. self._reconnect()
  56. def _get_cursor(self):
  57. self._ensure_connection()
  58. return self.conn.cursor(cursor_factory=RealDictCursor)
  59. def insert_or_update(self, tool: Dict):
  60. """插入或更新工具"""
  61. cursor = self._get_cursor()
  62. try:
  63. cursor.execute("""
  64. INSERT INTO tool (
  65. id, name, version, introduction, tutorial, input, output,
  66. updated_time, status, embedding
  67. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  68. ON CONFLICT (id) DO UPDATE SET
  69. name = EXCLUDED.name,
  70. version = EXCLUDED.version,
  71. introduction = EXCLUDED.introduction,
  72. tutorial = EXCLUDED.tutorial,
  73. input = EXCLUDED.input,
  74. output = EXCLUDED.output,
  75. updated_time = EXCLUDED.updated_time,
  76. status = EXCLUDED.status,
  77. embedding = EXCLUDED.embedding
  78. """, (
  79. tool['id'],
  80. tool.get('name', ''),
  81. tool.get('version'),
  82. tool.get('introduction', ''),
  83. tool.get('tutorial', ''),
  84. json.dumps(tool.get('input', '')),
  85. json.dumps(tool.get('output', '')),
  86. tool.get('updated_time', 0),
  87. tool.get('status', '未接入'),
  88. tool.get('embedding'),
  89. ))
  90. # 写入关联表
  91. tool_id = tool['id']
  92. if 'capability_ids' in tool:
  93. cursor.execute("DELETE FROM capability_tool WHERE tool_id = %s", (tool_id,))
  94. for cap_id in tool['capability_ids']:
  95. cursor.execute(
  96. "INSERT INTO capability_tool (capability_id, tool_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  97. (cap_id, tool_id))
  98. if 'knowledge_ids' in tool:
  99. cursor.execute("DELETE FROM tool_knowledge WHERE tool_id = %s", (tool_id,))
  100. for kid in tool['knowledge_ids']:
  101. cursor.execute(
  102. "INSERT INTO tool_knowledge (tool_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  103. (tool_id, kid))
  104. if 'provider_ids' in tool:
  105. cursor.execute("DELETE FROM tool_provider WHERE tool_id = %s", (tool_id,))
  106. for pid in tool['provider_ids']:
  107. cursor.execute(
  108. "INSERT INTO tool_provider (tool_id, provider_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  109. (tool_id, pid))
  110. self.conn.commit()
  111. finally:
  112. cursor.close()
  113. def get_by_id(self, tool_id: str) -> Optional[Dict]:
  114. """根据 ID 获取工具"""
  115. cursor = self._get_cursor()
  116. try:
  117. cursor.execute(f"""
  118. SELECT {_SELECT_FIELDS}
  119. FROM tool WHERE id = %s
  120. """, (tool_id,))
  121. result = cursor.fetchone()
  122. return self._format_result(result) if result else None
  123. finally:
  124. cursor.close()
  125. def search(self, query_embedding: List[float], limit: int = 10, status: Optional[str] = None) -> List[Dict]:
  126. """向量检索工具"""
  127. cursor = self._get_cursor()
  128. try:
  129. if status:
  130. sql = f"""
  131. SELECT {_SELECT_FIELDS},
  132. 1 - (embedding <=> %s::real[]) as score
  133. FROM tool
  134. WHERE embedding IS NOT NULL AND status = %s
  135. ORDER BY embedding <=> %s::real[]
  136. LIMIT %s
  137. """
  138. params = [query_embedding, status, query_embedding, limit]
  139. else:
  140. sql = f"""
  141. SELECT {_SELECT_FIELDS},
  142. 1 - (embedding <=> %s::real[]) as score
  143. FROM tool
  144. WHERE embedding IS NOT NULL
  145. ORDER BY embedding <=> %s::real[]
  146. LIMIT %s
  147. """
  148. params = [query_embedding, query_embedding, limit]
  149. cursor.execute(sql, params)
  150. results = cursor.fetchall()
  151. return [self._format_result(r) for r in results]
  152. finally:
  153. cursor.close()
  154. def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
  155. """列出工具"""
  156. cursor = self._get_cursor()
  157. try:
  158. if status:
  159. cursor.execute(f"""
  160. SELECT {_SELECT_FIELDS}
  161. FROM tool
  162. WHERE status = %s
  163. ORDER BY updated_time DESC
  164. LIMIT %s OFFSET %s
  165. """, (status, limit, offset))
  166. else:
  167. cursor.execute(f"""
  168. SELECT {_SELECT_FIELDS}
  169. FROM tool
  170. ORDER BY updated_time DESC
  171. LIMIT %s OFFSET %s
  172. """, (limit, offset))
  173. results = cursor.fetchall()
  174. return [self._format_result(r) for r in results]
  175. finally:
  176. cursor.close()
  177. def update(self, tool_id: str, updates: Dict):
  178. """更新工具字段"""
  179. cursor = self._get_cursor()
  180. try:
  181. # 分离关联字段
  182. cap_ids = updates.pop('capability_ids', None)
  183. knowledge_ids = updates.pop('knowledge_ids', None)
  184. provider_ids = updates.pop('provider_ids', None)
  185. if updates:
  186. set_parts = []
  187. params = []
  188. json_fields = ('input', 'output')
  189. for key, value in updates.items():
  190. set_parts.append(f"{key} = %s")
  191. if key in json_fields:
  192. params.append(json.dumps(value))
  193. else:
  194. params.append(value)
  195. params.append(tool_id)
  196. cursor.execute(
  197. f"UPDATE tool SET {', '.join(set_parts)} WHERE id = %s",
  198. params
  199. )
  200. # 更新关联表
  201. if cap_ids is not None:
  202. cursor.execute("DELETE FROM capability_tool WHERE tool_id = %s", (tool_id,))
  203. for cap_id in cap_ids:
  204. cursor.execute(
  205. "INSERT INTO capability_tool (capability_id, tool_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  206. (cap_id, tool_id))
  207. if knowledge_ids is not None:
  208. cursor.execute("DELETE FROM tool_knowledge WHERE tool_id = %s", (tool_id,))
  209. for kid in knowledge_ids:
  210. cursor.execute(
  211. "INSERT INTO tool_knowledge (tool_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  212. (tool_id, kid))
  213. if provider_ids is not None:
  214. cursor.execute("DELETE FROM tool_provider WHERE tool_id = %s", (tool_id,))
  215. for pid in provider_ids:
  216. cursor.execute(
  217. "INSERT INTO tool_provider (tool_id, provider_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  218. (tool_id, pid))
  219. self.conn.commit()
  220. finally:
  221. cursor.close()
  222. def add_knowledge(self, tool_id: str, knowledge_id: str):
  223. """向工具添加一条知识关联(不删除已有关联)"""
  224. cursor = self._get_cursor()
  225. try:
  226. cursor.execute(
  227. "INSERT INTO tool_knowledge (tool_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  228. (tool_id, knowledge_id))
  229. self.conn.commit()
  230. finally:
  231. cursor.close()
  232. def delete(self, tool_id: str):
  233. """删除工具及其关联表记录"""
  234. cursor = self._get_cursor()
  235. try:
  236. cascade_delete(cursor, 'tool', tool_id)
  237. self.conn.commit()
  238. finally:
  239. cursor.close()
  240. def count(self, status: Optional[str] = None) -> int:
  241. """统计工具总数"""
  242. cursor = self._get_cursor()
  243. try:
  244. if status:
  245. cursor.execute("SELECT COUNT(*) as count FROM tool WHERE status = %s", (status,))
  246. else:
  247. cursor.execute("SELECT COUNT(*) as count FROM tool")
  248. return cursor.fetchone()['count']
  249. finally:
  250. cursor.close()
  251. def _format_result(self, row: Dict) -> Dict:
  252. """格式化查询结果,将 JSON 字符串解析为对象"""
  253. if not row:
  254. return None
  255. result = dict(row)
  256. for field in ('input', 'output'):
  257. if field in result and isinstance(result[field], str):
  258. try:
  259. result[field] = json.loads(result[field]) if result[field].strip() else None
  260. except json.JSONDecodeError:
  261. result[field] = None
  262. # 关联字段(来自 junction table 子查询)
  263. for field in ('capability_ids', 'knowledge_ids', 'provider_ids'):
  264. if field in result and isinstance(result[field], str):
  265. result[field] = json.loads(result[field])
  266. elif field in result and result[field] is None:
  267. result[field] = []
  268. return result
  269. def close(self):
  270. if self.conn:
  271. self.conn.close()