"""
PostgreSQL storage wrapper for requirement_table (v2 schema).

Columns: id, description, atomics, source_nodes, status, match_result, embedding
"""

import os
import json
from contextlib import contextmanager

import psycopg2
from psycopg2 import sql
from psycopg2.extras import RealDictCursor
from typing import Iterator, List, Dict, Optional
from dotenv import load_dotenv

load_dotenv()


class PostgreSQLRequirementStore:
    """CRUD and similarity-search access to ``requirement_table``.

    One psycopg2 connection is held per instance with ``autocommit`` off.
    Every public operation runs in its own transaction: it is committed on
    success and rolled back on failure, so one failed statement cannot leave
    the connection stuck in an aborted transaction.
    """

    # Columns whose values are serialized with json.dumps before being sent.
    _JSON_FIELDS = ('atomics', 'source_nodes')

    def __init__(self):
        """Connect using the KNOWHUB_* environment variables."""
        self.conn = self._connect()
        print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    def _connect(self):
        """Open and return a new non-autocommit connection.

        Reads KNOWHUB_DB (host), KNOWHUB_PORT (default 5432), KNOWHUB_USER,
        KNOWHUB_PASSWORD and KNOWHUB_DB_NAME from the environment.
        """
        conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME')
        )
        conn.autocommit = False
        return conn

    def _reconnect(self):
        """Replace ``self.conn`` with a fresh connection."""
        stale = getattr(self, 'conn', None)
        if stale is not None and not stale.closed:
            # Best effort: the old connection is being discarded anyway, so a
            # failure to close it must not prevent opening the new one.
            try:
                stale.close()
            except psycopg2.Error:
                pass
        self.conn = self._connect()

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` probe."""
        if self.conn.closed != 0:
            self._reconnect()
            return
        try:
            # The ``with`` block closes the probe cursor even if execute raises.
            with self.conn.cursor() as probe:
                probe.execute("SELECT 1")
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a RealDictCursor on a connection that has been health-checked."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def _rollback_quietly(self):
        """Roll back the current transaction without masking the original error."""
        if self.conn.closed:
            return
        try:
            self.conn.rollback()
        except psycopg2.Error:
            # The connection is unusable; the caller's exception is the one
            # worth reporting, and the next operation will reconnect.
            pass

    @contextmanager
    def _transaction(self) -> Iterator[RealDictCursor]:
        """Yield a cursor; commit on success, roll back on any exception.

        Read-only callers go through this too, so the implicit transaction
        psycopg2 opens for a SELECT is always ended instead of leaving the
        session idle inside a transaction.
        """
        cursor = self._get_cursor()
        try:
            yield cursor
            self.conn.commit()
        except BaseException:
            # Without this rollback a failed statement leaves the connection in
            # an aborted transaction and every later statement is rejected.
            self._rollback_quietly()
            raise
        finally:
            cursor.close()

    def insert_or_update(self, requirement: Dict):
        """Insert a requirement, or overwrite the row with the same ``id``.

        Args:
            requirement: Mapping that must contain ``id``. Optional keys and
                their defaults: ``description`` (''), ``atomics`` ([]),
                ``source_nodes`` ([]), ``status`` ('未满足'),
                ``match_result`` (''), ``embedding`` (None).

        Raises:
            KeyError: If ``requirement`` has no ``id``.
            psycopg2.Error: On database failure (the transaction is rolled back).
        """
        # Build the parameters first so a missing 'id' fails before any I/O.
        params = (
            requirement['id'],
            requirement.get('description', ''),
            json.dumps(requirement.get('atomics', [])),
            json.dumps(requirement.get('source_nodes', [])),
            requirement.get('status', '未满足'),
            requirement.get('match_result', ''),
            requirement.get('embedding'),
        )
        with self._transaction() as cursor:
            cursor.execute("""
                INSERT INTO requirement_table (
                    id, description, atomics, source_nodes, status,
                    match_result, embedding
                ) VALUES (%s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (id) DO UPDATE SET
                    description = EXCLUDED.description,
                    atomics = EXCLUDED.atomics,
                    source_nodes = EXCLUDED.source_nodes,
                    status = EXCLUDED.status,
                    match_result = EXCLUDED.match_result,
                    embedding = EXCLUDED.embedding
            """, params)

    def get_by_id(self, req_id: str) -> Optional[Dict]:
        """Return the requirement with the given id, or ``None`` if absent.

        The ``embedding`` column is not selected.
        """
        with self._transaction() as cursor:
            cursor.execute("""
                SELECT id, description, atomics, source_nodes, status, match_result
                FROM requirement_table
                WHERE id = %s
            """, (req_id,))
            row = cursor.fetchone()
        return self._format_result(row) if row else None

    def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
        """Return up to ``limit`` requirements ordered by embedding distance.

        Rows are ordered by ``embedding <=> query`` ascending and each carries
        a ``score`` equal to ``1 - (embedding <=> query)``. Rows whose
        embedding is NULL are skipped.
        NOTE(review): presumably ``<=>`` is a vector-distance operator provided
        by a database extension — confirm against the schema.
        """
        with self._transaction() as cursor:
            cursor.execute("""
                SELECT id, description, atomics, source_nodes, status, match_result,
                       1 - (embedding <=> %s::real[]) as score
                FROM requirement_table
                WHERE embedding IS NOT NULL
                ORDER BY embedding <=> %s::real[]
                LIMIT %s
            """, (query_embedding, query_embedding, limit))
            rows = cursor.fetchall()
        return [self._format_result(r) for r in rows]

    def list_all(self, limit: int = 100, offset: int = 0,
                 status: Optional[str] = None) -> List[Dict]:
        """List requirements ordered by id, optionally filtered by status.

        A falsy ``status`` (``None`` or ``''``) disables the filter.
        """
        with self._transaction() as cursor:
            if status:
                cursor.execute("""
                    SELECT id, description, atomics, source_nodes, status, match_result
                    FROM requirement_table
                    WHERE status = %s
                    ORDER BY id
                    LIMIT %s OFFSET %s
                """, (status, limit, offset))
            else:
                cursor.execute("""
                    SELECT id, description, atomics, source_nodes, status, match_result
                    FROM requirement_table
                    ORDER BY id
                    LIMIT %s OFFSET %s
                """, (limit, offset))
            rows = cursor.fetchall()
        return [self._format_result(r) for r in rows]

    def update(self, req_id: str, updates: Dict):
        """Update the given columns of one requirement.

        Args:
            req_id: Id of the row to update.
            updates: Mapping of column name to new value. Values for
                ``atomics`` and ``source_nodes`` are JSON-serialized. An empty
                mapping is a no-op.

        Raises:
            psycopg2.Error: On database failure, e.g. an unknown column (the
                transaction is rolled back).
        """
        if not updates:
            # "UPDATE ... SET  WHERE ..." is invalid SQL; nothing to change.
            return

        assignments = []
        params = []
        for column, value in updates.items():
            # Column names come from the caller, so they are quoted as SQL
            # identifiers rather than pasted into the statement text.
            assignments.append(
                sql.SQL("{} = %s").format(sql.Identifier(column))
            )
            params.append(
                json.dumps(value) if column in self._JSON_FIELDS else value
            )
        params.append(req_id)

        statement = sql.SQL(
            "UPDATE requirement_table SET {} WHERE id = %s"
        ).format(sql.SQL(', ').join(assignments))

        with self._transaction() as cursor:
            cursor.execute(statement, params)

    def delete(self, req_id: str):
        """Delete the requirement with the given id (no error if absent)."""
        with self._transaction() as cursor:
            cursor.execute("DELETE FROM requirement_table WHERE id = %s", (req_id,))

    def count(self, status: Optional[str] = None) -> int:
        """Return the number of requirements, optionally filtered by status.

        A falsy ``status`` (``None`` or ``''``) counts every row.
        """
        with self._transaction() as cursor:
            if status:
                cursor.execute(
                    "SELECT COUNT(*) as count FROM requirement_table WHERE status = %s",
                    (status,)
                )
            else:
                cursor.execute("SELECT COUNT(*) as count FROM requirement_table")
            return cursor.fetchone()['count']

    def _format_result(self, row: Dict) -> Optional[Dict]:
        """Copy a row into a plain dict, decoding JSON-encoded string columns.

        ``atomics`` and ``source_nodes`` are passed through ``json.loads`` only
        when they arrive as ``str``; other types are left untouched.
        Returns ``None`` for an empty or missing row.
        """
        if not row:
            return None
        result = dict(row)
        for field in self._JSON_FIELDS:
            if isinstance(result.get(field), str):
                result[field] = json.loads(result[field])
        return result

    def close(self):
        """Close the underlying connection if one exists."""
        if self.conn:
            self.conn.close()