pg_tool_store.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. """
  2. PostgreSQL tool_table 存储封装
  3. 用于存储和检索工具数据,支持向量检索
  4. """
  5. import os
  6. import json
  7. import psycopg2
  8. from psycopg2.extras import RealDictCursor
  9. from typing import List, Dict, Optional
  10. from dotenv import load_dotenv
# Populate the process environment from a .env file (python-dotenv) at import
# time, so the KNOWHUB_* variables read in PostgreSQLToolStore.__init__ are set.
load_dotenv()
  12. class PostgreSQLToolStore:
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL Tool] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _get_cursor(self):
  25. return self.conn.cursor(cursor_factory=RealDictCursor)
  26. def insert_or_update(self, tool: Dict):
  27. """插入或更新工具"""
  28. cursor = self._get_cursor()
  29. try:
  30. cursor.execute("""
  31. INSERT INTO tool_table (
  32. id, name, version, introduction, tutorial, input, output,
  33. updated_time, status, capabilities, tool_knowledge,
  34. case_knowledge, process_knowledge, embedding
  35. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  36. ON CONFLICT (id) DO UPDATE SET
  37. name = EXCLUDED.name,
  38. version = EXCLUDED.version,
  39. introduction = EXCLUDED.introduction,
  40. tutorial = EXCLUDED.tutorial,
  41. input = EXCLUDED.input,
  42. output = EXCLUDED.output,
  43. updated_time = EXCLUDED.updated_time,
  44. status = EXCLUDED.status,
  45. capabilities = EXCLUDED.capabilities,
  46. tool_knowledge = EXCLUDED.tool_knowledge,
  47. case_knowledge = EXCLUDED.case_knowledge,
  48. process_knowledge = EXCLUDED.process_knowledge,
  49. embedding = EXCLUDED.embedding
  50. """, (
  51. tool['id'],
  52. tool.get('name', ''),
  53. tool.get('version'),
  54. tool.get('introduction', ''),
  55. tool.get('tutorial', ''),
  56. json.dumps(tool.get('input', '')),
  57. json.dumps(tool.get('output', '')),
  58. tool.get('updated_time', 0),
  59. tool.get('status', '未接入'),
  60. json.dumps(tool.get('capabilities', [])),
  61. json.dumps(tool.get('tool_knowledge', [])),
  62. json.dumps(tool.get('case_knowledge', [])),
  63. json.dumps(tool.get('process_knowledge', [])),
  64. tool.get('embedding'),
  65. ))
  66. self.conn.commit()
  67. finally:
  68. cursor.close()
  69. def get_by_id(self, tool_id: str) -> Optional[Dict]:
  70. """根据 ID 获取工具"""
  71. cursor = self._get_cursor()
  72. try:
  73. cursor.execute("""
  74. SELECT id, name, version, introduction, tutorial, input, output,
  75. updated_time, status, capabilities, tool_knowledge,
  76. case_knowledge, process_knowledge
  77. FROM tool_table WHERE id = %s
  78. """, (tool_id,))
  79. result = cursor.fetchone()
  80. return self._format_result(result) if result else None
  81. finally:
  82. cursor.close()
  83. def search(self, query_embedding: List[float], limit: int = 10, status: Optional[str] = None) -> List[Dict]:
  84. """向量检索工具"""
  85. cursor = self._get_cursor()
  86. try:
  87. where = "WHERE embedding IS NOT NULL"
  88. params = [query_embedding, query_embedding, limit]
  89. if status:
  90. where += " AND status = %s"
  91. params = [query_embedding, status, query_embedding, limit]
  92. sql = f"""
  93. SELECT id, name, version, introduction, tutorial, input, output,
  94. updated_time, status, capabilities, tool_knowledge,
  95. case_knowledge, process_knowledge,
  96. 1 - (embedding <=> %s::real[]) as score
  97. FROM tool_table
  98. WHERE embedding IS NOT NULL AND status = %s
  99. ORDER BY embedding <=> %s::real[]
  100. LIMIT %s
  101. """
  102. else:
  103. sql = f"""
  104. SELECT id, name, version, introduction, tutorial, input, output,
  105. updated_time, status, capabilities, tool_knowledge,
  106. case_knowledge, process_knowledge,
  107. 1 - (embedding <=> %s::real[]) as score
  108. FROM tool_table
  109. WHERE embedding IS NOT NULL
  110. ORDER BY embedding <=> %s::real[]
  111. LIMIT %s
  112. """
  113. cursor.execute(sql, params)
  114. results = cursor.fetchall()
  115. return [self._format_result(r) for r in results]
  116. finally:
  117. cursor.close()
  118. def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
  119. """列出工具"""
  120. cursor = self._get_cursor()
  121. try:
  122. if status:
  123. cursor.execute("""
  124. SELECT id, name, version, introduction, tutorial, input, output,
  125. updated_time, status, capabilities, tool_knowledge,
  126. case_knowledge, process_knowledge
  127. FROM tool_table
  128. WHERE status = %s
  129. ORDER BY updated_time DESC
  130. LIMIT %s OFFSET %s
  131. """, (status, limit, offset))
  132. else:
  133. cursor.execute("""
  134. SELECT id, name, version, introduction, tutorial, input, output,
  135. updated_time, status, capabilities, tool_knowledge,
  136. case_knowledge, process_knowledge
  137. FROM tool_table
  138. ORDER BY updated_time DESC
  139. LIMIT %s OFFSET %s
  140. """, (limit, offset))
  141. results = cursor.fetchall()
  142. return [self._format_result(r) for r in results]
  143. finally:
  144. cursor.close()
  145. def update(self, tool_id: str, updates: Dict):
  146. """更新工具字段"""
  147. cursor = self._get_cursor()
  148. try:
  149. set_parts = []
  150. params = []
  151. json_fields = ('input', 'output', 'capabilities', 'tool_knowledge',
  152. 'case_knowledge', 'process_knowledge')
  153. for key, value in updates.items():
  154. set_parts.append(f"{key} = %s")
  155. if key in json_fields:
  156. params.append(json.dumps(value))
  157. else:
  158. params.append(value)
  159. params.append(tool_id)
  160. cursor.execute(
  161. f"UPDATE tool_table SET {', '.join(set_parts)} WHERE id = %s",
  162. params
  163. )
  164. self.conn.commit()
  165. finally:
  166. cursor.close()
  167. def delete(self, tool_id: str):
  168. """删除工具"""
  169. cursor = self._get_cursor()
  170. try:
  171. cursor.execute("DELETE FROM tool_table WHERE id = %s", (tool_id,))
  172. self.conn.commit()
  173. finally:
  174. cursor.close()
  175. def count(self, status: Optional[str] = None) -> int:
  176. """统计工具总数"""
  177. cursor = self._get_cursor()
  178. try:
  179. if status:
  180. cursor.execute("SELECT COUNT(*) as count FROM tool_table WHERE status = %s", (status,))
  181. else:
  182. cursor.execute("SELECT COUNT(*) as count FROM tool_table")
  183. return cursor.fetchone()['count']
  184. finally:
  185. cursor.close()
  186. def _format_result(self, row: Dict) -> Dict:
  187. """格式化查询结果,将 JSON 字符串解析为对象"""
  188. if not row:
  189. return None
  190. result = dict(row)
  191. for field in ('input', 'output', 'capabilities', 'tool_knowledge',
  192. 'case_knowledge', 'process_knowledge'):
  193. if field in result and isinstance(result[field], str):
  194. try:
  195. result[field] = json.loads(result[field]) if result[field].strip() else None
  196. except json.JSONDecodeError:
  197. result[field] = None
  198. return result
  199. def close(self):
  200. if self.conn:
  201. self.conn.close()