"""
PostgreSQL requirement store wrapper.

Stores and retrieves requirement records, with vector-search support.
Table name: requirement (migrated from requirement_table).
"""

import os
import json
import re
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Dict, Optional, Tuple
from dotenv import load_dotenv
from knowhub.knowhub_db.cascade import cascade_delete

load_dotenv()

# Correlated subqueries for the relation fields. The knowledge edge is exposed
# through two views: knowledge_ids (flat list) and knowledge_links (with type).
_REL_SUBQUERY = """
    (SELECT COALESCE(json_agg(rc.capability_id), '[]'::json)
       FROM requirement_capability rc
      WHERE rc.requirement_id = requirement.id) AS capability_ids,
    (SELECT COALESCE(json_agg(rk.knowledge_id), '[]'::json)
       FROM requirement_knowledge rk
      WHERE rk.requirement_id = requirement.id) AS knowledge_ids,
    (SELECT COALESCE(json_agg(json_build_object(
                'id', rk2.knowledge_id,
                'relation_type', rk2.relation_type
            )), '[]'::json)
       FROM requirement_knowledge rk2
      WHERE rk2.requirement_id = requirement.id) AS knowledge_links,
    (SELECT COALESCE(json_agg(rr.resource_id), '[]'::json)
       FROM requirement_resource rr
      WHERE rr.requirement_id = requirement.id) AS resource_ids,
    (SELECT COALESCE(json_agg(rs.strategy_id), '[]'::json)
       FROM requirement_strategy rs
      WHERE rs.requirement_id = requirement.id) AS strategy_ids,
    (SELECT COALESCE(json_agg(rp.itemset_id), '[]'::json)
       FROM requirement_pattern rp
      WHERE rp.requirement_id = requirement.id) AS pattern_ids,
    (SELECT COALESCE(json_agg(rn.node_id), '[]'::json)
       FROM requirement_node rn
      WHERE rn.requirement_id = requirement.id) AS node_ids
"""

_BASE_FIELDS = "id, description, source_nodes, status, match_result, version"

_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERY}"

# Keys of an input dict that describe junction-table edges rather than
# columns of the requirement table itself.
_RELATION_KEYS = (
    'capability_ids',
    'strategy_ids',
    'knowledge_ids',
    'knowledge_links',
    'resource_ids',
)

# Column names are interpolated into the UPDATE statement (identifiers cannot
# be bound as parameters), so only plain identifiers are accepted.
_COLUMN_NAME_RE = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')


def _normalize_links(data: Dict, links_key: str, ids_key: str,
                     default_type: str) -> Optional[List[Tuple]]:
    """Unify the two accepted input shapes for typed edges.

    Accepts either ``{links_key: [{id, relation_type}, ...]}`` or
    ``{ids_key: [id, ...]}``. ``links_key`` wins when both are present and
    not ``None``.

    Returns:
        A list of ``(id, relation_type)`` tuples, or ``None`` when neither
        key carries a non-None value (meaning "leave the edges untouched").

    Raises:
        KeyError: if a dict item under ``links_key`` has no ``'id'`` entry.
    """
    links = data.get(links_key)
    if links is not None:
        pairs = []
        for item in links:
            if isinstance(item, dict):
                pairs.append((item['id'], item.get('relation_type', default_type)))
            else:
                # Bare ids are allowed inside the links list as well.
                pairs.append((item, default_type))
        return pairs

    ids = data.get(ids_key)
    if ids is not None:
        return [(i, default_type) for i in ids]
    return None


def _replace_edges(cursor, table: str, columns: Tuple[str, ...], req_id, rows):
    """Replace every edge of ``req_id`` in junction ``table`` with ``rows``.

    ``table`` and ``columns`` are internal constants (never user input), so
    interpolating them into the SQL text is safe. Each element of ``rows`` is
    a tuple holding the values for ``columns`` in order.
    """
    cursor.execute(f"DELETE FROM {table} WHERE requirement_id = %s", (req_id,))
    column_list = ", ".join(("requirement_id", *columns))
    placeholders = ", ".join(["%s"] * (len(columns) + 1))
    insert_sql = (
        f"INSERT INTO {table} ({column_list}) "
        f"VALUES ({placeholders}) ON CONFLICT DO NOTHING"
    )
    for row in rows:
        cursor.execute(insert_sql, (req_id, *row))


def _sync_relations(cursor, req_id, data: Dict):
    """Rewrite the junction tables for every relation key present in ``data``.

    A relation whose key is missing or ``None`` is left untouched; an empty
    list clears all edges of that relation.
    """
    cap_ids = data.get('capability_ids')
    if cap_ids is not None:
        _replace_edges(cursor, 'requirement_capability', ('capability_id',),
                       req_id, [(cap_id,) for cap_id in cap_ids])

    k_links = _normalize_links(data, 'knowledge_links', 'knowledge_ids', 'related')
    if k_links is not None:
        _replace_edges(cursor, 'requirement_knowledge',
                       ('knowledge_id', 'relation_type'), req_id, k_links)

    resource_ids = data.get('resource_ids')
    if resource_ids is not None:
        _replace_edges(cursor, 'requirement_resource', ('resource_id',),
                       req_id, [(rid,) for rid in resource_ids])

    strategy_ids = data.get('strategy_ids')
    if strategy_ids is not None:
        _replace_edges(cursor, 'requirement_strategy', ('strategy_id',),
                       req_id, [(sid,) for sid in strategy_ids])


class PostgreSQLRequirementStore:
    """Store for the ``requirement`` table and its junction tables."""

    def __init__(self):
        """Open the PostgreSQL connection using the KNOWHUB_* environment variables."""
        self.conn = self._connect()
        print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    def _connect(self):
        """Create a new autocommit connection from the environment settings."""
        conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME')
        )
        conn.autocommit = True
        return conn

    def _reconnect(self):
        """Replace ``self.conn`` with a fresh connection."""
        stale = getattr(self, 'conn', None)
        if stale is not None and not stale.closed:
            # Best-effort release of the broken connection; a failure here
            # must not prevent the reconnect.
            try:
                stale.close()
            except psycopg2.Error:
                pass
        self.conn = self._connect()

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` probe."""
        if self.conn.closed != 0:
            self._reconnect()
            return
        try:
            probe = self.conn.cursor()
            try:
                probe.execute("SELECT 1")
            finally:
                # Close the probe cursor even when the probe itself fails.
                probe.close()
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a ``RealDictCursor`` on a verified-live connection."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def insert_or_update(self, requirement: Dict):
        """Insert a requirement, replacing any existing row with the same id.

        AnalyticDB beam tables do not support ON CONFLICT UPDATE when the
        table has columns added via ALTER, so DELETE + INSERT is used instead.

        Relation keys (``capability_ids``, ``knowledge_links`` /
        ``knowledge_ids``, ``resource_ids``, ``strategy_ids``) rewrite the
        matching junction table only when present and not ``None``.

        Raises:
            KeyError: if ``requirement`` has no ``'id'``.
        """
        req_id = requirement['id']
        cursor = self._get_cursor()
        try:
            # NOTE(review): the connection runs with autocommit enabled, so
            # the DELETE and the INSERT below are not one atomic unit.
            cursor.execute("DELETE FROM requirement WHERE id = %s", (req_id,))
            cursor.execute("""
                INSERT INTO requirement (
                    id, description, source_nodes, status,
                    match_result, embedding, version
                ) VALUES (%s, %s, %s, %s, %s, %s, %s)
            """, (
                req_id,
                requirement.get('description', ''),
                json.dumps(requirement.get('source_nodes', [])),
                requirement.get('status', '未满足'),
                requirement.get('match_result', ''),
                requirement.get('embedding'),
                requirement.get('version', 'v0'),
            ))
            # Write the junction tables.
            _sync_relations(cursor, req_id, requirement)
            self.conn.commit()
        finally:
            cursor.close()

    def get_by_id(self, req_id: str) -> Optional[Dict]:
        """Fetch one requirement by id, or ``None`` if no visible row matches."""
        # Imported lazily, as in the original code.
        from knowhub.knowhub_db.version_context import req_version_where
        cursor = self._get_cursor()
        try:
            # req_version_where() supplies a WHERE fragment and its parameters.
            vf, vp = req_version_where()
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS}
                FROM requirement
                WHERE id = %s AND {vf}
            """, (req_id, *vp))
            row = cursor.fetchone()
            return self._format_result(row) if row else None
        finally:
            cursor.close()

    def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
        """Vector search: return up to ``limit`` rows ordered by ``<=>`` distance.

        Each row carries an extra ``score`` column computed as
        ``1 - (embedding <=> query)``.
        """
        from knowhub.knowhub_db.version_context import req_version_where
        cursor = self._get_cursor()
        try:
            vf, vp = req_version_where()
            # Parameter order follows placeholder order: score expression,
            # version filter, ORDER BY expression, LIMIT.
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS},
                       1 - (embedding <=> %s::real[]) as score
                FROM requirement
                WHERE embedding IS NOT NULL AND {vf}
                ORDER BY embedding <=> %s::real[]
                LIMIT %s
            """, (query_embedding, *vp, query_embedding, limit))
            return [self._format_result(r) for r in cursor.fetchall()]
        finally:
            cursor.close()

    def list_all(self, limit: int = 100, offset: int = 0,
                 status: Optional[str] = None) -> List[Dict]:
        """List requirements ordered by id, optionally filtered by ``status``.

        A falsy ``status`` (``None`` or ``''``) disables the status filter.
        """
        from knowhub.knowhub_db.version_context import req_version_where
        cursor = self._get_cursor()
        try:
            vf, vp = req_version_where()
            conditions = []
            params = []
            if status:
                conditions.append("status = %s")
                params.append(status)
            conditions.append(vf)
            params.extend(vp)
            params.extend((limit, offset))
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS}
                FROM requirement
                WHERE {' AND '.join(conditions)}
                ORDER BY id
                LIMIT %s OFFSET %s
            """, tuple(params))
            return [self._format_result(r) for r in cursor.fetchall()]
        finally:
            cursor.close()

    def update(self, req_id: str, updates: Dict):
        """Update columns and/or relations of one requirement.

        Keys listed in ``_RELATION_KEYS`` rewrite the matching junction table
        (when not ``None``); every other key is treated as a column of the
        ``requirement`` table. ``source_nodes`` is JSON-encoded before being
        stored. The caller's ``updates`` dict is not modified.

        Raises:
            ValueError: if a column key is not a plain SQL identifier.
        """
        column_updates = {
            key: value for key, value in updates.items()
            if key not in _RELATION_KEYS
        }
        # Validate before touching the database: column names end up in the
        # SQL text, so anything that is not a bare identifier is rejected.
        for key in column_updates:
            if not (isinstance(key, str) and _COLUMN_NAME_RE.fullmatch(key)):
                raise ValueError(f"invalid column name: {key!r}")

        cursor = self._get_cursor()
        try:
            if column_updates:
                json_fields = ('source_nodes',)
                set_parts = []
                params = []
                for key, value in column_updates.items():
                    set_parts.append(f"{key} = %s")
                    params.append(json.dumps(value) if key in json_fields else value)
                params.append(req_id)
                cursor.execute(
                    f"UPDATE requirement SET {', '.join(set_parts)} WHERE id = %s",
                    params
                )

            _sync_relations(cursor, req_id, updates)
            self.conn.commit()
        finally:
            cursor.close()

    def add_knowledge(self, req_id: str, knowledge_id: str, relation_type: str = 'related'):
        """Incrementally attach a requirement-knowledge edge."""
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO requirement_knowledge (requirement_id, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (req_id, knowledge_id, relation_type))
            self.conn.commit()
        finally:
            cursor.close()

    def add_resource(self, req_id: str, resource_id: str):
        """Incrementally attach a requirement-resource edge."""
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO requirement_resource (requirement_id, resource_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (req_id, resource_id))
            self.conn.commit()
        finally:
            cursor.close()

    def add_strategy(self, req_id: str, strategy_id: str):
        """Incrementally attach a requirement-strategy edge (the strategy satisfies this requirement)."""
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (req_id, strategy_id))
            self.conn.commit()
        finally:
            cursor.close()

    def delete(self, req_id: str):
        """Delete a requirement and its junction-table records via ``cascade_delete``."""
        cursor = self._get_cursor()
        try:
            cascade_delete(cursor, 'requirement', req_id)
            self.conn.commit()
        finally:
            cursor.close()

    def count(self, status: Optional[str] = None) -> int:
        """Count requirements, optionally restricted to one ``status``."""
        from knowhub.knowhub_db.version_context import req_version_where
        cursor = self._get_cursor()
        try:
            vf, vp = req_version_where()
            if status:
                cursor.execute(
                    f"SELECT COUNT(*) as count FROM requirement WHERE status = %s AND {vf}",
                    (status, *vp))
            else:
                cursor.execute(
                    f"SELECT COUNT(*) as count FROM requirement WHERE {vf}",
                    vp)
            return cursor.fetchone()['count']
        finally:
            cursor.close()

    def _format_result(self, row: Dict) -> Optional[Dict]:
        """Convert a database row to a plain dict with decoded JSON fields.

        Returns ``None`` for an empty/missing row. ``source_nodes`` and the
        relation fields are decoded when they arrive as JSON strings; a
        relation field that is ``None`` becomes an empty list.
        """
        if not row:
            return None
        result = dict(row)
        if isinstance(result.get('source_nodes'), str):
            result['source_nodes'] = json.loads(result['source_nodes'])

        # Relation fields (produced by the junction-table subqueries).
        relation_fields = (
            'capability_ids', 'knowledge_ids', 'resource_ids', 'strategy_ids',
            'knowledge_links', 'pattern_ids', 'node_ids',
        )
        for field in relation_fields:
            if field not in result:
                continue
            value = result[field]
            if isinstance(value, str):
                result[field] = json.loads(value)
            elif value is None:
                result[field] = []
        return result

    def close(self):
        """Close the underlying connection if there is one."""
        if self.conn:
            self.conn.close()