pg_tool_store.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. """
  2. PostgreSQL tool_table 存储封装
  3. 用于存储和检索工具数据,支持向量检索
  4. """
  5. import os
  6. import json
  7. import psycopg2
  8. from psycopg2.extras import RealDictCursor
  9. from typing import List, Dict, Optional
  10. from dotenv import load_dotenv
  11. load_dotenv()
  class PostgreSQLToolStore:
      """Storage wrapper around the PostgreSQL ``tool_table``.

      Stores and retrieves tool records and supports vector similarity search
      over the ``embedding`` column. Connection settings are read from the
      ``KNOWHUB_DB``, ``KNOWHUB_PORT`` (default 5432), ``KNOWHUB_USER``,
      ``KNOWHUB_PASSWORD`` and ``KNOWHUB_DB_NAME`` environment variables.

      The connection runs with ``autocommit = False``. Every write is committed
      (or rolled back on failure) before its method returns. Every read ends
      its implicit transaction afterwards, so the session is never left
      "idle in transaction" or stuck in an aborted transaction.
      """

      # Columns returned by every read query (``embedding`` is not read back).
      _SELECT_COLUMNS = (
          "id, name, version, introduction, tutorial, input, output, "
          "updated_time, status, capabilities, tool_knowledge, "
          "case_knowledge, process_knowledge, implemented_tool_ids"
      )
      # Columns whose values are serialized with json.dumps on write and
      # decoded by _format_result when the driver hands them back as strings.
      _JSON_FIELDS = (
          "input", "output", "capabilities", "tool_knowledge",
          "case_knowledge", "process_knowledge", "implemented_tool_ids",
      )

      def __init__(self):
          """Open the PostgreSQL connection described by the KNOWHUB_* env vars."""
          self._connect()
          print(f"[PostgreSQL Tool] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

      def _connect(self):
          """Create a new connection (autocommit off) and store it on ``self.conn``."""
          self.conn = psycopg2.connect(
              host=os.getenv('KNOWHUB_DB'),
              port=int(os.getenv('KNOWHUB_PORT', 5432)),
              user=os.getenv('KNOWHUB_USER'),
              password=os.getenv('KNOWHUB_PASSWORD'),
              database=os.getenv('KNOWHUB_DB_NAME'),
          )
          self.conn.autocommit = False

      def _reconnect(self):
          """Replace the current connection with a fresh one.

          The old connection is closed first (best effort), so a broken but
          not-yet-closed socket is not leaked.
          """
          old = getattr(self, "conn", None)
          if old is not None and not old.closed:
              try:
                  old.close()
              except psycopg2.Error:
                  pass  # it is being discarded either way
          self._connect()

      def _ensure_connection(self):
          """Make sure ``self.conn`` is usable, reconnecting if it is not.

          After a failed statement the session sits in an aborted transaction.
          In that state every query, including the ``SELECT 1`` probe, raises
          ``InFailedSqlTransaction``. That state is cleared with a rollback so
          it cannot poison all later calls.
          """
          if self.conn.closed != 0:
              self._reconnect()
              return
          try:
              status = self.conn.get_transaction_status()
              if status == psycopg2.extensions.TRANSACTION_STATUS_INERROR:
                  self.conn.rollback()
              probe = self.conn.cursor()
              try:
                  probe.execute("SELECT 1")
              finally:
                  probe.close()
          except (psycopg2.OperationalError, psycopg2.InterfaceError):
              self._reconnect()

      def _get_cursor(self):
          """Return a RealDictCursor (rows as dicts) on a verified connection."""
          self._ensure_connection()
          return self.conn.cursor(cursor_factory=RealDictCursor)

      def _rollback_quietly(self):
          """Roll back the current transaction, ignoring driver errors.

          Used on cleanup paths, where the caller's own exception (if any) is
          the one worth propagating. A dead connection is detected and replaced
          by _ensure_connection on the next call.
          """
          try:
              self.conn.rollback()
          except psycopg2.Error:
              pass

      def _execute_write(self, query: str, params) -> None:
          """Execute one write statement and commit it; roll back on any failure."""
          cursor = self._get_cursor()
          try:
              cursor.execute(query, params)
              self.conn.commit()
          except Exception:
              # Do not leave the session in an aborted transaction.
              self._rollback_quietly()
              raise
          finally:
              cursor.close()

      def _fetch(self, query: str, params=None, one: bool = False):
          """Run a read query and return ``fetchone()`` if *one*, else ``fetchall()``.

          The implicit transaction opened by the query (autocommit is off) is
          ended afterwards. Otherwise the session would stay "idle in
          transaction", keeping its table lock until some later commit.
          """
          cursor = self._get_cursor()
          try:
              cursor.execute(query, params)
              return cursor.fetchone() if one else cursor.fetchall()
          finally:
              cursor.close()
              self._rollback_quietly()

      def insert_or_update(self, tool: Dict):
          """Insert *tool*, or overwrite the existing row that has the same ``id``.

          ``tool['id']`` is required (KeyError otherwise). Other missing keys get
          the defaults below. On conflict EVERY column is overwritten, so keys
          absent from *tool* reset the stored value to its default.
          List/object fields are serialized with json.dumps.
          """
          self._execute_write(
              """
              INSERT INTO tool_table (
                  id, name, version, introduction, tutorial, input, output,
                  updated_time, status, capabilities, tool_knowledge,
                  case_knowledge, process_knowledge, embedding, implemented_tool_ids
              ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
              ON CONFLICT (id) DO UPDATE SET
                  name = EXCLUDED.name,
                  version = EXCLUDED.version,
                  introduction = EXCLUDED.introduction,
                  tutorial = EXCLUDED.tutorial,
                  input = EXCLUDED.input,
                  output = EXCLUDED.output,
                  updated_time = EXCLUDED.updated_time,
                  status = EXCLUDED.status,
                  capabilities = EXCLUDED.capabilities,
                  tool_knowledge = EXCLUDED.tool_knowledge,
                  case_knowledge = EXCLUDED.case_knowledge,
                  process_knowledge = EXCLUDED.process_knowledge,
                  embedding = EXCLUDED.embedding,
                  implemented_tool_ids = EXCLUDED.implemented_tool_ids
              """,
              (
                  tool['id'],
                  tool.get('name', ''),
                  tool.get('version'),
                  tool.get('introduction', ''),
                  tool.get('tutorial', ''),
                  json.dumps(tool.get('input', '')),
                  json.dumps(tool.get('output', '')),
                  tool.get('updated_time', 0),
                  tool.get('status', '未接入'),  # default status: "not integrated"
                  json.dumps(tool.get('capabilities', [])),
                  json.dumps(tool.get('tool_knowledge', [])),
                  json.dumps(tool.get('case_knowledge', [])),
                  json.dumps(tool.get('process_knowledge', [])),
                  tool.get('embedding'),
                  json.dumps(tool.get('implemented_tool_ids', [])),
              ),
          )

      def get_by_id(self, tool_id: str) -> Optional[Dict]:
          """Return the tool whose id is *tool_id*, or None if there is no such row."""
          row = self._fetch(
              f"SELECT {self._SELECT_COLUMNS} FROM tool_table WHERE id = %s",
              (tool_id,),
              one=True,
          )
          return self._format_result(row) if row else None

      def search(self, query_embedding: List[float], limit: int = 10, status: Optional[str] = None) -> List[Dict]:
          """Return up to *limit* tools ordered by ``<=>`` distance to *query_embedding*.

          Each result carries ``score = 1 - distance``. Rows without an
          embedding are skipped, and a truthy *status* additionally filters
          rows by status.
          """
          # Only fixed literals are interpolated into the SQL text; all values
          # are bound parameters, in placeholder order.
          conditions = "embedding IS NOT NULL"
          params: list = [query_embedding]  # score expression
          if status:
              conditions += " AND status = %s"
              params.append(status)
          params += [query_embedding, limit]  # ORDER BY and LIMIT
          rows = self._fetch(
              f"""
              SELECT {self._SELECT_COLUMNS},
                     1 - (embedding <=> %s::real[]) as score
              FROM tool_table
              WHERE {conditions}
              ORDER BY embedding <=> %s::real[]
              LIMIT %s
              """,
              params,
          )
          return [self._format_result(r) for r in rows]

      def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
          """Return one page of tools, newest ``updated_time`` first.

          A truthy *status* restricts the listing to rows with that status.
          """
          where, params = ("WHERE status = %s", [status]) if status else ("", [])
          rows = self._fetch(
              f"SELECT {self._SELECT_COLUMNS} FROM tool_table {where} "
              "ORDER BY updated_time DESC LIMIT %s OFFSET %s",
              params + [limit, offset],
          )
          return [self._format_result(r) for r in rows]

      def update(self, tool_id: str, updates: Dict):
          """Set the given columns on the row whose id is *tool_id*.

          Keys of *updates* are column names; values for the JSON columns
          (_JSON_FIELDS, matching insert_or_update) are serialized with
          json.dumps. An empty *updates* is a no-op.

          Raises:
              ValueError: if a key is not a plain identifier. Column names are
                  interpolated into the SQL text, so this blocks injection.
          """
          if not updates:
              return  # "SET  WHERE" would be invalid SQL
          assignments = []
          params = []
          for column, value in updates.items():
              if not isinstance(column, str) or not column.isidentifier():
                  raise ValueError(f"invalid column name: {column!r}")
              assignments.append(f"{column} = %s")
              params.append(json.dumps(value) if column in self._JSON_FIELDS else value)
          params.append(tool_id)
          self._execute_write(
              f"UPDATE tool_table SET {', '.join(assignments)} WHERE id = %s",
              params,
          )

      def delete(self, tool_id: str):
          """Delete the row whose id is *tool_id* (nothing happens if it does not exist)."""
          self._execute_write("DELETE FROM tool_table WHERE id = %s", (tool_id,))

      def count(self, status: Optional[str] = None) -> int:
          """Return the number of tools, counting only *status* rows if it is truthy."""
          if status:
              row = self._fetch(
                  "SELECT COUNT(*) as count FROM tool_table WHERE status = %s",
                  (status,),
                  one=True,
              )
          else:
              row = self._fetch("SELECT COUNT(*) as count FROM tool_table", one=True)
          return row['count']

      def _format_result(self, row: Optional[Dict]) -> Optional[Dict]:
          """Return *row* as a plain dict with its JSON-text fields decoded.

          String values in the JSON columns are parsed with json.loads; blank
          or malformed strings become None, and non-string values are left
          unchanged. Returns None for a falsy *row*.
          """
          if not row:
              return None
          result = dict(row)
          for field in self._JSON_FIELDS:
              value = result.get(field)
              if isinstance(value, str):
                  try:
                      result[field] = json.loads(value) if value.strip() else None
                  except json.JSONDecodeError:
                      result[field] = None
          return result

      def close(self):
          """Close the database connection."""
          if self.conn:
              self.conn.close()