""" PostgreSQL atomic_capability 存储封装 用于存储和检索原子能力数据,支持向量检索 """ import os import json import psycopg2 from psycopg2.extras import RealDictCursor from typing import List, Dict, Optional from dotenv import load_dotenv load_dotenv() class PostgreSQLCapabilityStore: def __init__(self): """初始化 PostgreSQL 连接""" self.conn = psycopg2.connect( host=os.getenv('KNOWHUB_DB'), port=int(os.getenv('KNOWHUB_PORT', 5432)), user=os.getenv('KNOWHUB_USER'), password=os.getenv('KNOWHUB_PASSWORD'), database=os.getenv('KNOWHUB_DB_NAME') ) self.conn.autocommit = False print(f"[PostgreSQL Capability] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}") def _reconnect(self): self.conn = psycopg2.connect( host=os.getenv('KNOWHUB_DB'), port=int(os.getenv('KNOWHUB_PORT', 5432)), user=os.getenv('KNOWHUB_USER'), password=os.getenv('KNOWHUB_PASSWORD'), database=os.getenv('KNOWHUB_DB_NAME') ) self.conn.autocommit = False def _ensure_connection(self): if self.conn.closed != 0: self._reconnect() else: try: c = self.conn.cursor() c.execute("SELECT 1") c.close() except (psycopg2.OperationalError, psycopg2.InterfaceError): self._reconnect() def _get_cursor(self): self._ensure_connection() return self.conn.cursor(cursor_factory=RealDictCursor) def insert_or_update(self, cap: Dict): """插入或更新原子能力""" cursor = self._get_cursor() try: cursor.execute(""" INSERT INTO atomic_capability ( id, name, criterion, description, requirements, implements, tools, source_knowledge, embedding ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, criterion = EXCLUDED.criterion, description = EXCLUDED.description, requirements = EXCLUDED.requirements, implements = EXCLUDED.implements, tools = EXCLUDED.tools, source_knowledge = EXCLUDED.source_knowledge, embedding = EXCLUDED.embedding """, ( cap['id'], cap.get('name', ''), cap.get('criterion', ''), cap.get('description', ''), json.dumps(cap.get('requirements', [])), json.dumps(cap.get('implements', {})), json.dumps(cap.get('tools', [])), json.dumps(cap.get('source_knowledge', [])), cap.get('embedding'), )) self.conn.commit() finally: cursor.close() def get_by_id(self, cap_id: str) -> Optional[Dict]: """根据 ID 获取原子能力""" cursor = self._get_cursor() try: cursor.execute(""" SELECT id, name, criterion, description, requirements, implements, tools, source_knowledge FROM atomic_capability WHERE id = %s """, (cap_id,)) result = cursor.fetchone() return self._format_result(result) if result else None finally: cursor.close() def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]: """向量检索原子能力""" cursor = self._get_cursor() try: cursor.execute(""" SELECT id, name, criterion, description, requirements, implements, tools, source_knowledge, 1 - (embedding <=> %s::real[]) as score FROM atomic_capability WHERE embedding IS NOT NULL ORDER BY embedding <=> %s::real[] LIMIT %s """, (query_embedding, query_embedding, limit)) results = cursor.fetchall() return [self._format_result(r) for r in results] finally: cursor.close() def list_all(self, limit: int = 100, offset: int = 0) -> List[Dict]: """列出原子能力""" cursor = self._get_cursor() try: cursor.execute(""" SELECT id, name, criterion, description, requirements, implements, tools, source_knowledge FROM atomic_capability ORDER BY id LIMIT %s OFFSET %s """, (limit, offset)) results = cursor.fetchall() return [self._format_result(r) for r in results] finally: cursor.close() def update(self, cap_id: str, updates: Dict): """更新原子能力字段""" cursor = self._get_cursor() try: set_parts = [] params = [] json_fields = ('requirements', 'implements', 'tools', 'source_knowledge') for key, value in updates.items(): set_parts.append(f"{key} = %s") if key in json_fields: params.append(json.dumps(value)) else: params.append(value) params.append(cap_id) cursor.execute( f"UPDATE atomic_capability SET {', '.join(set_parts)} WHERE id = %s", params ) self.conn.commit() finally: cursor.close() def delete(self, cap_id: str): """删除原子能力""" cursor = self._get_cursor() try: cursor.execute("DELETE FROM atomic_capability WHERE id = %s", (cap_id,)) self.conn.commit() finally: cursor.close() def count(self) -> int: """统计原子能力总数""" cursor = self._get_cursor() try: cursor.execute("SELECT COUNT(*) as count FROM atomic_capability") return cursor.fetchone()['count'] finally: cursor.close() def _format_result(self, row: Dict) -> Dict: """格式化查询结果""" if not row: return None result = dict(row) for field in ('requirements', 'implements', 'tools', 'source_knowledge'): if field in result and isinstance(result[field], str): result[field] = json.loads(result[field]) return result def close(self): if self.conn: self.conn.close()