"""PostgreSQL capability store wrapper.

Stores and retrieves atomic-capability data, with vector search support.
Table name: capability (migrated from atomic_capability).
"""

import os
import json
from contextlib import contextmanager

import psycopg2
from psycopg2 import sql
from psycopg2.extras import RealDictCursor
from typing import List, Dict, Optional
from dotenv import load_dotenv

from knowhub.knowhub_db.cascade import cascade_delete

load_dotenv()

# Correlated sub-queries that attach the relation columns to each capability
# row: requirement_ids / tool_ids / knowledge_ids as JSON arrays, and
# implements as a JSON object {tool_id: description} built only from
# capability_tool rows whose description is not the empty string.
_REL_SUBQUERIES = """
    (SELECT COALESCE(json_agg(rc.requirement_id), '[]'::json)
     FROM requirement_capability rc
     WHERE rc.capability_id = capability.id) AS requirement_ids,
    (SELECT COALESCE(json_agg(ct.tool_id), '[]'::json)
     FROM capability_tool ct
     WHERE ct.capability_id = capability.id) AS tool_ids,
    (SELECT COALESCE(
        json_object_agg(ct2.tool_id, ct2.description), '{}'::json)
     FROM capability_tool ct2
     WHERE ct2.capability_id = capability.id
       AND ct2.description != '') AS implements,
    (SELECT COALESCE(json_agg(ck.knowledge_id), '[]'::json)
     FROM capability_knowledge ck
     WHERE ck.capability_id = capability.id) AS knowledge_ids
"""

_BASE_FIELDS = "id, name, criterion, description"
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"


class PostgreSQLCapabilityStore:
    """CRUD and vector-search access to the ``capability`` table.

    Connection settings are read from the environment variables
    KNOWHUB_DB (host), KNOWHUB_PORT (default 5432), KNOWHUB_USER,
    KNOWHUB_PASSWORD and KNOWHUB_DB_NAME.  The connection runs with
    autocommit disabled; write methods commit on success and every method
    rolls the transaction back if it raises.
    """

    def __init__(self):
        """Open the PostgreSQL connection."""
        self.conn = self._connect()
        print(f"[PostgreSQL Capability] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    @staticmethod
    def _connect():
        """Create a new connection from the KNOWHUB_* environment variables."""
        conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME')
        )
        conn.autocommit = False
        return conn

    def _reconnect(self):
        """Replace ``self.conn`` with a freshly opened connection."""
        self.conn = self._connect()

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` probe."""
        if self.conn.closed != 0:
            self._reconnect()
            return
        try:
            # ``with`` closes the probe cursor even when execute() raises.
            with self.conn.cursor() as probe:
                probe.execute("SELECT 1")
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a ``RealDictCursor`` on a verified-live connection."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def _rollback_quietly(self):
        """Roll back the current transaction, ignoring a dead connection."""
        try:
            self.conn.rollback()
        except psycopg2.Error:
            # The connection itself is unusable; the next _ensure_connection()
            # call replaces it, so there is nothing left to roll back here.
            pass

    @contextmanager
    def _cursor_scope(self, commit: bool = False):
        """Yield a cursor; commit on success if requested, roll back on error.

        Without the rollback, a failed statement would leave the shared
        connection inside an aborted (or half-written) transaction that the
        next call would either trip over or accidentally commit.
        """
        cursor = self._get_cursor()
        try:
            yield cursor
            if commit:
                self.conn.commit()
        except Exception:
            self._rollback_quietly()
            raise
        finally:
            cursor.close()

    def _save_relations(self, cursor, cap_id: str, data: Dict):
        """Rewrite the relation tables of one capability.

        For each relation key present in ``data`` the existing rows for
        ``cap_id`` are deleted and replaced.  Keys that are absent leave their
        table untouched; a value of ``None`` is treated as empty.

        Args:
            cursor: Open cursor; the caller owns commit/rollback.
            cap_id: Capability ID the relations belong to.
            data: May contain ``requirement_ids`` (iterable),
                ``tool_ids`` (iterable), ``implements``
                (mapping tool_id -> description) and ``knowledge_ids``
                (iterable).
        """
        if 'requirement_ids' in data:
            cursor.execute("DELETE FROM requirement_capability WHERE capability_id = %s", (cap_id,))
            for req_id in data['requirement_ids'] or []:
                cursor.execute(
                    "INSERT INTO requirement_capability (requirement_id, capability_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
                    (req_id, cap_id))

        # tool_ids and implements are merged into capability_tool.
        if 'tool_ids' in data or 'implements' in data:
            cursor.execute("DELETE FROM capability_tool WHERE capability_id = %s", (cap_id,))
            implements = data.get('implements') or {}
            tool_ids = set(data.get('tool_ids') or [])
            # Every listed tool gets a row (empty description by default);
            # implements supplies descriptions and may add tools that are
            # not in tool_ids.
            descriptions = {tool_id: '' for tool_id in tool_ids}
            descriptions.update(implements)
            for tool_id, desc in descriptions.items():
                cursor.execute(
                    "INSERT INTO capability_tool (capability_id, tool_id, description) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                    (cap_id, tool_id, desc))

        if 'knowledge_ids' in data:
            cursor.execute("DELETE FROM capability_knowledge WHERE capability_id = %s", (cap_id,))
            for kid in data['knowledge_ids'] or []:
                cursor.execute(
                    "INSERT INTO capability_knowledge (capability_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
                    (cap_id, kid))

    def insert_or_update(self, cap: Dict):
        """Insert an atomic capability, or overwrite it if the ID exists.

        Args:
            cap: Must contain ``id``.  ``name``, ``criterion`` and
                ``description`` default to ``''``; ``embedding`` defaults to
                ``None``.  Relation keys, when present, are written through
                ``_save_relations``.

        Raises:
            KeyError: If ``cap`` has no ``id``.
        """
        with self._cursor_scope(commit=True) as cursor:
            cursor.execute("""
                INSERT INTO capability (
                    id, name, criterion, description, embedding
                ) VALUES (%s, %s, %s, %s, %s)
                ON CONFLICT (id) DO UPDATE SET
                    name = EXCLUDED.name,
                    criterion = EXCLUDED.criterion,
                    description = EXCLUDED.description,
                    embedding = EXCLUDED.embedding
            """, (
                cap['id'],
                cap.get('name', ''),
                cap.get('criterion', ''),
                cap.get('description', ''),
                cap.get('embedding'),
            ))
            self._save_relations(cursor, cap['id'], cap)

    def get_by_id(self, cap_id: str) -> Optional[Dict]:
        """Return the capability with the given ID, or ``None`` if absent."""
        with self._cursor_scope() as cursor:
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS}
                FROM capability
                WHERE id = %s
            """, (cap_id,))
            result = cursor.fetchone()
            return self._format_result(result) if result else None

    def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
        """Vector search over capabilities that have an embedding.

        Rows are ordered by ascending ``embedding <=> query`` distance and
        carry ``score = 1 - distance``.

        Args:
            query_embedding: Query vector, sent to the database as ``real[]``.
            limit: Maximum number of rows to return.
        """
        with self._cursor_scope() as cursor:
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS},
                       1 - (embedding <=> %s::real[]) as score
                FROM capability
                WHERE embedding IS NOT NULL
                ORDER BY embedding <=> %s::real[]
                LIMIT %s
            """, (query_embedding, query_embedding, limit))
            return [self._format_result(r) for r in cursor.fetchall()]

    def list_all(self, limit: int = 100, offset: int = 0) -> List[Dict]:
        """Return one page of capabilities ordered by ID."""
        with self._cursor_scope() as cursor:
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS}
                FROM capability
                ORDER BY id
                LIMIT %s OFFSET %s
            """, (limit, offset))
            return [self._format_result(r) for r in cursor.fetchall()]

    def update(self, cap_id: str, updates: Dict):
        """Update columns and/or relations of one capability.

        Args:
            cap_id: ID of the capability to update.
            updates: Mapping of field name to new value.  The keys
                ``requirement_ids``, ``implements``, ``tool_ids`` and
                ``knowledge_ids`` are routed to the relation tables; every
                other key is treated as a column of ``capability``.  The
                mapping is not modified.

        Column names are emitted as quoted SQL identifiers, so a key can never
        inject SQL; an unknown column surfaces as a database error.
        """
        relation_keys = ('requirement_ids', 'implements', 'tool_ids', 'knowledge_ids')
        # Split into two fresh dicts instead of popping from the caller's dict.
        rel_fields = {key: updates[key] for key in relation_keys if key in updates}
        columns = {key: value for key, value in updates.items() if key not in rel_fields}

        with self._cursor_scope(commit=True) as cursor:
            if columns:
                assignments = sql.SQL(', ').join(
                    sql.SQL("{} = %s").format(sql.Identifier(key))
                    for key in columns
                )
                statement = sql.SQL("UPDATE capability SET {} WHERE id = %s").format(assignments)
                cursor.execute(statement, [*columns.values(), cap_id])
            if rel_fields:
                self._save_relations(cursor, cap_id, rel_fields)

    def delete(self, cap_id: str):
        """Delete a capability via ``cascade_delete`` and commit.

        NOTE(review): removal of the related rows is presumably handled
        inside ``cascade_delete``; confirm against that helper.
        """
        with self._cursor_scope(commit=True) as cursor:
            cascade_delete(cursor, 'capability', cap_id)

    def count(self) -> int:
        """Return the total number of rows in ``capability``."""
        with self._cursor_scope() as cursor:
            cursor.execute("SELECT COUNT(*) as count FROM capability")
            return cursor.fetchone()['count']

    def _format_result(self, row: Dict) -> Optional[Dict]:
        """Normalize a query row into a plain dict.

        JSON relation columns that arrive as strings are decoded; ``None``
        becomes ``[]`` for the ID lists and ``{}`` for ``implements``.
        Returns ``None`` for an empty or missing row.
        """
        if not row:
            return None
        result = dict(row)
        # Factory per field so every row gets its own empty container.
        empty_factories = {
            'requirement_ids': list,
            'tool_ids': list,
            'knowledge_ids': list,
            'implements': dict,
        }
        for field, make_empty in empty_factories.items():
            if field not in result:
                continue
            value = result[field]
            if isinstance(value, str):
                result[field] = json.loads(value)
            elif value is None:
                result[field] = make_empty()
        return result

    def close(self):
        """Close the database connection if one is set."""
        if self.conn:
            self.conn.close()