"""
PostgreSQL storage wrapper for tools.

Stores and retrieves tool records and supports vector similarity search over
the ``embedding`` column. Table: ``tool`` (migrated from ``tool_table``).
Relations to capabilities, knowledge entries and providers are kept in the
junction tables ``capability_tool``, ``tool_knowledge`` and ``tool_provider``.
"""
import json
import os
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional

import psycopg2
from dotenv import load_dotenv
from psycopg2 import sql as pgsql
from psycopg2.extras import RealDictCursor

from knowhub.knowhub_db.cascade import cascade_delete

load_dotenv()

# Correlated subqueries that aggregate each tool's related ids from the
# junction tables into JSON arrays ('[]' when the tool has no links).
_REL_SUBQUERIES = """
    (SELECT COALESCE(json_agg(ct.capability_id), '[]'::json)
       FROM capability_tool ct WHERE ct.tool_id = tool.id) AS capability_ids,
    (SELECT COALESCE(json_agg(tk.knowledge_id), '[]'::json)
       FROM tool_knowledge tk WHERE tk.tool_id = tool.id) AS knowledge_ids,
    (SELECT COALESCE(json_agg(tp.provider_id), '[]'::json)
       FROM tool_provider tp WHERE tp.tool_id = tool.id) AS provider_ids
"""

_BASE_FIELDS = "id, name, version, introduction, tutorial, input, output, updated_time, status"
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"

# Columns whose values are serialized with json.dumps on write and parsed
# back with json.loads (when returned as strings) on read.
_JSON_FIELDS = ('input', 'output')

# (key in the tool dict, junction table, column holding the linked id).
# All names are internal constants, so interpolating them into SQL is safe.
_LINK_TABLES = (
    ('capability_ids', 'capability_tool', 'capability_id'),
    ('knowledge_ids', 'tool_knowledge', 'knowledge_id'),
    ('provider_ids', 'tool_provider', 'provider_id'),
)


class PostgreSQLToolStore:
    """Tool repository backed by a single PostgreSQL connection.

    Connection settings are read from the ``KNOWHUB_DB``, ``KNOWHUB_PORT``,
    ``KNOWHUB_USER``, ``KNOWHUB_PASSWORD`` and ``KNOWHUB_DB_NAME``
    environment variables (a ``.env`` file is loaded at import time).
    Every public method runs in its own transaction: it commits on success
    and rolls back on any exception, so a failed call never leaves pending
    or aborted work on the shared connection.
    """

    def __init__(self):
        """Open the PostgreSQL connection."""
        self.conn = self._connect()
        print(f"[PostgreSQL Tool] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    @staticmethod
    def _connect():
        """Create a new connection from the KNOWHUB_* environment variables."""
        conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME'),
        )
        # Explicit transactions: each public method commits or rolls back.
        conn.autocommit = False
        return conn

    def _reconnect(self):
        """Replace ``self.conn`` with a fresh connection, closing the old one."""
        old = getattr(self, 'conn', None)
        if old is not None and not old.closed:
            try:
                old.close()
            except psycopg2.Error:
                # Best effort: the old connection is already considered unusable.
                pass
        self.conn = self._connect()

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` probe."""
        if self.conn.closed != 0:
            self._reconnect()
            return
        try:
            with self.conn.cursor() as probe:
                probe.execute("SELECT 1")
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a dict-row cursor on a live connection (caller must close it)."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def _rollback_quietly(self):
        """Roll back the current transaction, ignoring errors from a dead connection."""
        try:
            self.conn.rollback()
        except psycopg2.Error:
            # The connection may be broken; _ensure_connection will replace it
            # on the next call. The original exception is what matters.
            pass

    @contextmanager
    def _transaction(self) -> Iterator[RealDictCursor]:
        """Yield a cursor; commit on success, roll back on any exception.

        Without the rollback, a failed statement would leave the connection in
        an aborted transaction (breaking every later call), and partial writes
        from a failed method would be committed by the next successful one.
        """
        cursor = self._get_cursor()
        try:
            yield cursor
            self.conn.commit()
        except BaseException:
            self._rollback_quietly()
            raise
        finally:
            cursor.close()

    @staticmethod
    def _replace_links(cursor, tool_id: str, links: Dict):
        """Replace the junction-table rows of ``tool_id`` for each relation in ``links``.

        For every relation key (``capability_ids``, ``knowledge_ids``,
        ``provider_ids``) whose value in ``links`` is not None, all existing
        rows for the tool are deleted and the given ids inserted. Missing or
        None keys leave that relation untouched.
        """
        for field, table, column in _LINK_TABLES:
            linked_ids = links.get(field)
            if linked_ids is None:
                continue
            cursor.execute(f"DELETE FROM {table} WHERE tool_id = %s", (tool_id,))
            cursor.executemany(
                f"INSERT INTO {table} (tool_id, {column}) VALUES (%s, %s) ON CONFLICT DO NOTHING",
                [(tool_id, linked_id) for linked_id in linked_ids])

    def insert_or_update(self, tool: Dict):
        """Insert a tool, or overwrite all its columns if ``tool['id']`` exists.

        Args:
            tool: Must contain ``id``. Missing columns fall back to defaults
                (``status`` defaults to ``'未接入'``). ``input``/``output`` are
                stored via ``json.dumps``. If ``capability_ids``,
                ``knowledge_ids`` or ``provider_ids`` are present and not None,
                the corresponding junction rows are replaced.
        """
        with self._transaction() as cursor:
            cursor.execute("""
                INSERT INTO tool (
                    id, name, version, introduction, tutorial, input, output,
                    updated_time, status, embedding
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (id) DO UPDATE SET
                    name = EXCLUDED.name,
                    version = EXCLUDED.version,
                    introduction = EXCLUDED.introduction,
                    tutorial = EXCLUDED.tutorial,
                    input = EXCLUDED.input,
                    output = EXCLUDED.output,
                    updated_time = EXCLUDED.updated_time,
                    status = EXCLUDED.status,
                    embedding = EXCLUDED.embedding
            """, (
                tool['id'],
                tool.get('name', ''),
                tool.get('version'),
                tool.get('introduction', ''),
                tool.get('tutorial', ''),
                json.dumps(tool.get('input', '')),
                json.dumps(tool.get('output', '')),
                tool.get('updated_time', 0),
                tool.get('status', '未接入'),
                tool.get('embedding'),
            ))
            self._replace_links(cursor, tool['id'], tool)

    def get_by_id(self, tool_id: str) -> Optional[Dict]:
        """Return the tool with ``tool_id`` (including relation ids), or None."""
        with self._transaction() as cursor:
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS}
                FROM tool WHERE id = %s
            """, (tool_id,))
            row = cursor.fetchone()
        return self._format_result(row) if row else None

    def search(self, query_embedding: List[float], limit: int = 10,
               status: Optional[str] = None) -> List[Dict]:
        """Vector search: tools ordered by ``embedding <=> query_embedding``.

        Each result carries ``score = 1 - distance``. Tools without an
        embedding are skipped; a truthy ``status`` restricts the results to
        that status.
        """
        conditions = ["embedding IS NOT NULL"]
        filter_params: List = []
        if status:
            conditions.append("status = %s")
            filter_params.append(status)
        query = f"""
            SELECT {_SELECT_FIELDS},
                   1 - (embedding <=> %s::real[]) AS score
            FROM tool
            WHERE {' AND '.join(conditions)}
            ORDER BY embedding <=> %s::real[]
            LIMIT %s
        """
        # Placeholder order: score expression, WHERE filters, ORDER BY, LIMIT.
        params = [query_embedding, *filter_params, query_embedding, limit]
        with self._transaction() as cursor:
            cursor.execute(query, params)
            rows = cursor.fetchall()
        return [self._format_result(r) for r in rows]

    def list_all(self, limit: int = 100, offset: int = 0,
                 status: Optional[str] = None) -> List[Dict]:
        """List tools, newest ``updated_time`` first, optionally filtered by status."""
        where = "WHERE status = %s" if status else ""
        params = ([status] if status else []) + [limit, offset]
        with self._transaction() as cursor:
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS}
                FROM tool
                {where}
                ORDER BY updated_time DESC
                LIMIT %s OFFSET %s
            """, params)
            rows = cursor.fetchall()
        return [self._format_result(r) for r in rows]

    def update(self, tool_id: str, updates: Dict):
        """Update selected columns and/or relations of a tool.

        Args:
            tool_id: Id of the tool to update.
            updates: Column name -> new value. ``input``/``output`` are stored
                via ``json.dumps``. The keys ``capability_ids``,
                ``knowledge_ids`` and ``provider_ids`` (when not None) replace
                the corresponding junction rows instead of updating a column.
                The caller's dict is not modified.
        """
        # Copy so popping the relation keys does not mutate the caller's dict.
        columns = dict(updates)
        links = {field: columns.pop(field, None) for field, _, _ in _LINK_TABLES}

        with self._transaction() as cursor:
            if columns:
                # Column names come from the caller; quote them as identifiers
                # rather than interpolating raw text into the statement.
                assignments = pgsql.SQL(', ').join([
                    pgsql.SQL("{} = %s").format(pgsql.Identifier(key))
                    for key in columns
                ])
                params = [json.dumps(value) if key in _JSON_FIELDS else value
                          for key, value in columns.items()]
                params.append(tool_id)
                cursor.execute(
                    pgsql.SQL("UPDATE tool SET {} WHERE id = %s").format(assignments),
                    params)
            self._replace_links(cursor, tool_id, links)

    def add_knowledge(self, tool_id: str, knowledge_id: str):
        """Link one knowledge entry to a tool, keeping existing links."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO tool_knowledge (tool_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (tool_id, knowledge_id))

    def delete(self, tool_id: str):
        """Delete a tool and its junction-table rows (via ``cascade_delete``)."""
        with self._transaction() as cursor:
            cascade_delete(cursor, 'tool', tool_id)

    def count(self, status: Optional[str] = None) -> int:
        """Count tools, optionally only those with a (truthy) ``status``."""
        with self._transaction() as cursor:
            if status:
                cursor.execute("SELECT COUNT(*) as count FROM tool WHERE status = %s", (status,))
            else:
                cursor.execute("SELECT COUNT(*) as count FROM tool")
            return cursor.fetchone()['count']

    def _format_result(self, row: Optional[Dict]) -> Optional[Dict]:
        """Convert a DB row to a plain dict, decoding JSON text fields.

        ``input``/``output`` strings are parsed with ``json.loads`` (blank or
        invalid JSON becomes None). Relation id fields returned as strings are
        parsed, and None becomes an empty list.
        """
        if not row:
            return None
        result = dict(row)
        for field in _JSON_FIELDS:
            if field in result and isinstance(result[field], str):
                try:
                    result[field] = json.loads(result[field]) if result[field].strip() else None
                except json.JSONDecodeError:
                    result[field] = None
        # Relation fields produced by the junction-table subqueries.
        for field, _, _ in _LINK_TABLES:
            if field in result and isinstance(result[field], str):
                result[field] = json.loads(result[field])
            elif field in result and result[field] is None:
                result[field] = []
        return result

    def close(self):
        """Close the database connection."""
        if self.conn:
            self.conn.close()