""" PostgreSQL tool_table 存储封装 用于存储和检索工具数据,支持向量检索 """ import os import json import psycopg2 from psycopg2.extras import RealDictCursor from typing import List, Dict, Optional from dotenv import load_dotenv load_dotenv() class PostgreSQLToolStore: def __init__(self): """初始化 PostgreSQL 连接""" self.conn = psycopg2.connect( host=os.getenv('KNOWHUB_DB'), port=int(os.getenv('KNOWHUB_PORT', 5432)), user=os.getenv('KNOWHUB_USER'), password=os.getenv('KNOWHUB_PASSWORD'), database=os.getenv('KNOWHUB_DB_NAME') ) self.conn.autocommit = False print(f"[PostgreSQL Tool] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}") def _reconnect(self): self.conn = psycopg2.connect( host=os.getenv('KNOWHUB_DB'), port=int(os.getenv('KNOWHUB_PORT', 5432)), user=os.getenv('KNOWHUB_USER'), password=os.getenv('KNOWHUB_PASSWORD'), database=os.getenv('KNOWHUB_DB_NAME') ) self.conn.autocommit = False def _ensure_connection(self): if self.conn.closed != 0: self._reconnect() else: try: c = self.conn.cursor() c.execute("SELECT 1") c.close() except (psycopg2.OperationalError, psycopg2.InterfaceError): self._reconnect() def _get_cursor(self): self._ensure_connection() return self.conn.cursor(cursor_factory=RealDictCursor) def insert_or_update(self, tool: Dict): """插入或更新工具""" cursor = self._get_cursor() try: cursor.execute(""" INSERT INTO tool_table ( id, name, version, introduction, tutorial, input, output, updated_time, status, capabilities, tool_knowledge, case_knowledge, process_knowledge, embedding, implemented_tool_ids ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, version = EXCLUDED.version, introduction = EXCLUDED.introduction, tutorial = EXCLUDED.tutorial, input = EXCLUDED.input, output = EXCLUDED.output, updated_time = EXCLUDED.updated_time, status = EXCLUDED.status, capabilities = EXCLUDED.capabilities, tool_knowledge = EXCLUDED.tool_knowledge, case_knowledge = EXCLUDED.case_knowledge, process_knowledge = EXCLUDED.process_knowledge, embedding = EXCLUDED.embedding, implemented_tool_ids = EXCLUDED.implemented_tool_ids """, ( tool['id'], tool.get('name', ''), tool.get('version'), tool.get('introduction', ''), tool.get('tutorial', ''), json.dumps(tool.get('input', '')), json.dumps(tool.get('output', '')), tool.get('updated_time', 0), tool.get('status', '未接入'), json.dumps(tool.get('capabilities', [])), json.dumps(tool.get('tool_knowledge', [])), json.dumps(tool.get('case_knowledge', [])), json.dumps(tool.get('process_knowledge', [])), tool.get('embedding'), json.dumps(tool.get('implemented_tool_ids', [])), )) self.conn.commit() finally: cursor.close() def get_by_id(self, tool_id: str) -> Optional[Dict]: """根据 ID 获取工具""" cursor = self._get_cursor() try: cursor.execute(""" SELECT id, name, version, introduction, tutorial, input, output, updated_time, status, capabilities, tool_knowledge, case_knowledge, process_knowledge, implemented_tool_ids FROM tool_table WHERE id = %s """, (tool_id,)) result = cursor.fetchone() return self._format_result(result) if result else None finally: cursor.close() def search(self, query_embedding: List[float], limit: int = 10, status: Optional[str] = None) -> List[Dict]: """向量检索工具""" cursor = self._get_cursor() try: where = "WHERE embedding IS NOT NULL" params = [query_embedding, query_embedding, limit] if status: where += " AND status = %s" params = [query_embedding, status, query_embedding, limit] sql = f""" SELECT id, name, version, introduction, tutorial, input, output, updated_time, status, capabilities, tool_knowledge, case_knowledge, process_knowledge, implemented_tool_ids, 1 - (embedding <=> %s::real[]) as score FROM tool_table WHERE embedding IS NOT NULL AND status = %s ORDER BY embedding <=> %s::real[] LIMIT %s """ else: sql = f""" SELECT id, name, version, introduction, tutorial, input, output, updated_time, status, capabilities, tool_knowledge, case_knowledge, process_knowledge, implemented_tool_ids, 1 - (embedding <=> %s::real[]) as score FROM tool_table WHERE embedding IS NOT NULL ORDER BY embedding <=> %s::real[] LIMIT %s """ cursor.execute(sql, params) results = cursor.fetchall() return [self._format_result(r) for r in results] finally: cursor.close() def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]: """列出工具""" cursor = self._get_cursor() try: if status: cursor.execute(""" SELECT id, name, version, introduction, tutorial, input, output, updated_time, status, capabilities, tool_knowledge, case_knowledge, process_knowledge, implemented_tool_ids FROM tool_table WHERE status = %s ORDER BY updated_time DESC LIMIT %s OFFSET %s """, (status, limit, offset)) else: cursor.execute(""" SELECT id, name, version, introduction, tutorial, input, output, updated_time, status, capabilities, tool_knowledge, case_knowledge, process_knowledge, implemented_tool_ids FROM tool_table ORDER BY updated_time DESC LIMIT %s OFFSET %s """, (limit, offset)) results = cursor.fetchall() return [self._format_result(r) for r in results] finally: cursor.close() def update(self, tool_id: str, updates: Dict): """更新工具字段""" cursor = self._get_cursor() try: set_parts = [] params = [] json_fields = ('input', 'output', 'capabilities', 'tool_knowledge', 'case_knowledge', 'process_knowledge') for key, value in updates.items(): set_parts.append(f"{key} = %s") if key in json_fields: params.append(json.dumps(value)) else: params.append(value) params.append(tool_id) cursor.execute( f"UPDATE tool_table SET {', '.join(set_parts)} WHERE id = %s", params ) self.conn.commit() finally: cursor.close() def delete(self, tool_id: str): """删除工具""" cursor = self._get_cursor() try: cursor.execute("DELETE FROM tool_table WHERE id = %s", (tool_id,)) self.conn.commit() finally: cursor.close() def count(self, status: Optional[str] = None) -> int: """统计工具总数""" cursor = self._get_cursor() try: if status: cursor.execute("SELECT COUNT(*) as count FROM tool_table WHERE status = %s", (status,)) else: cursor.execute("SELECT COUNT(*) as count FROM tool_table") return cursor.fetchone()['count'] finally: cursor.close() def _format_result(self, row: Dict) -> Dict: """格式化查询结果,将 JSON 字符串解析为对象""" if not row: return None result = dict(row) for field in ('input', 'output', 'capabilities', 'tool_knowledge', 'case_knowledge', 'process_knowledge', 'implemented_tool_ids'): if field in result and isinstance(result[field], str): try: result[field] = json.loads(result[field]) if result[field].strip() else None except json.JSONDecodeError: result[field] = None return result def close(self): if self.conn: self.conn.close()