"""PostgreSQL storage wrapper for strategies.

Stores and retrieves "production strategies". A strategy is a composition of
atomic capabilities, and additionally carries its own body (an executable
description) and source knowledge.

Associations:
- strategy_capability (default relation_type='compose')
- strategy_knowledge (default relation_type='source'; may also be 'case', etc.)
- strategy_resource (direct material, no type)
"""

import json
import os
from typing import Dict, List, Optional, Tuple

import psycopg2
from dotenv import load_dotenv
from psycopg2 import sql
from psycopg2.extras import RealDictCursor

from knowhub.knowhub_db.cascade import cascade_delete
from knowhub.knowhub_db.version_context import version_where

load_dotenv()

# Read path: expose both the flat id lists and the typed link lists.
_REL_SUBQUERIES = """
    (SELECT COALESCE(json_agg(rs.requirement_id), '[]'::json)
     FROM requirement_strategy rs
     WHERE rs.strategy_id = strategy.id) AS requirement_ids,
    (SELECT COALESCE(json_agg(sc.capability_id), '[]'::json)
     FROM strategy_capability sc
     WHERE sc.strategy_id = strategy.id) AS capability_ids,
    (SELECT COALESCE(json_agg(json_build_object(
        'id', sc2.capability_id, 'relation_type', sc2.relation_type
     )), '[]'::json)
     FROM strategy_capability sc2
     WHERE sc2.strategy_id = strategy.id) AS capability_links,
    (SELECT COALESCE(json_agg(sk.knowledge_id), '[]'::json)
     FROM strategy_knowledge sk
     WHERE sk.strategy_id = strategy.id) AS knowledge_ids,
    (SELECT COALESCE(json_agg(json_build_object(
        'id', sk2.knowledge_id, 'relation_type', sk2.relation_type
     )), '[]'::json)
     FROM strategy_knowledge sk2
     WHERE sk2.strategy_id = strategy.id) AS knowledge_links,
    (SELECT COALESCE(json_agg(sr.resource_id), '[]'::json)
     FROM strategy_resource sr
     WHERE sr.strategy_id = strategy.id) AS resource_ids
"""

_BASE_FIELDS = "id, name, description, body, status, created_at, updated_at, version"

_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"

# Keys of an update payload that describe junction rows rather than columns
# of the strategy table.
_RELATION_KEYS = (
    'requirement_ids',
    'capability_ids',
    'capability_links',
    'knowledge_ids',
    'knowledge_links',
    'resource_ids',
)

# Aggregated columns produced by _REL_SUBQUERIES; some drivers/servers hand
# them back as JSON text instead of decoded lists.
_JSON_LIST_FIELDS = (
    'requirement_ids',
    'capability_ids',
    'knowledge_ids',
    'resource_ids',
    'capability_links',
    'knowledge_links',
)


class PostgreSQLStrategyStore:
    """CRUD and vector-search access to the ``strategy`` table and its junctions.

    The connection is opened with ``autocommit = True``, so every statement
    takes effect as soon as it is executed. The explicit ``commit()`` calls
    are kept from the original code.

    NOTE(review): because of autocommit, multi-statement operations
    (``insert_or_update``, ``_save_relations``) are not atomic; a failure
    halfway leaves earlier statements applied. Confirm whether that is
    acceptable for the target database before relying on it.
    """

    def __init__(self):
        """Open the database connection using the KNOWHUB_* environment variables."""
        self.conn = self._connect()
        print(f"[PostgreSQL Strategy] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    @staticmethod
    def _connect():
        """Create a new autocommit connection from the KNOWHUB_* environment variables."""
        conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME'),
        )
        conn.autocommit = True
        return conn

    def _reconnect(self):
        """Replace ``self.conn`` with a freshly opened connection."""
        self.conn = self._connect()

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` probe."""
        if self.conn.closed != 0:
            self._reconnect()
            return
        try:
            # The with-block closes the probe cursor even when execute() raises.
            with self.conn.cursor() as probe:
                probe.execute("SELECT 1")
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a RealDictCursor on a connection that has just been health-checked."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    # ─── Relation writes ─────────────────────────────────────────

    @staticmethod
    def _normalize_links(data: Dict, links_key: str, ids_key: str,
                         default_type: str) -> Optional[List[Tuple]]:
        """Normalize the two accepted input shapes into ``(id, relation_type)`` pairs.

        - ``{links_key: [{id, relation_type}, ...]}`` uses the given type
          (falling back to ``default_type`` when an item omits it).
        - ``{ids_key: [id1, id2, ...]}`` uses ``default_type`` for every id.

        ``links_key`` takes precedence when both are present. Returns ``None``
        when neither key is present (or both are ``None``), meaning "do not
        touch this relation".
        """
        links = data.get(links_key)
        if links is not None:
            pairs = []
            for item in links:
                if isinstance(item, dict):
                    pairs.append((item['id'], item.get('relation_type', default_type)))
                else:
                    # Tolerate mixed input: a bare id inside the links list.
                    pairs.append((item, default_type))
            return pairs

        ids = data.get(ids_key)
        if ids is not None:
            return [(item_id, default_type) for item_id in ids]
        return None

    def _save_relations(self, cursor, strategy_id: str, data: Dict):
        """Fully replace the junction rows of a strategy.

        Only relations whose key is present (and not ``None``) in ``data`` are
        replaced; the others are left untouched. An empty list clears the
        relation.
        """
        cap_links = self._normalize_links(data, 'capability_links', 'capability_ids', 'compose')
        if cap_links is not None:
            cursor.execute("DELETE FROM strategy_capability WHERE strategy_id = %s", (strategy_id,))
            cursor.executemany(
                "INSERT INTO strategy_capability (strategy_id, capability_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                [(strategy_id, cap_id, rtype) for cap_id, rtype in cap_links])

        k_links = self._normalize_links(data, 'knowledge_links', 'knowledge_ids', 'source')
        if k_links is not None:
            cursor.execute("DELETE FROM strategy_knowledge WHERE strategy_id = %s", (strategy_id,))
            cursor.executemany(
                "INSERT INTO strategy_knowledge (strategy_id, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                [(strategy_id, kid, rtype) for kid, rtype in k_links])

        resource_ids = data.get('resource_ids')
        if resource_ids is not None:
            cursor.execute("DELETE FROM strategy_resource WHERE strategy_id = %s", (strategy_id,))
            cursor.executemany(
                "INSERT INTO strategy_resource (strategy_id, resource_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                [(strategy_id, rid) for rid in resource_ids])

        requirement_ids = data.get('requirement_ids')
        if requirement_ids is not None:
            cursor.execute("DELETE FROM requirement_strategy WHERE strategy_id = %s", (strategy_id,))
            cursor.executemany(
                "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                [(req_id, strategy_id) for req_id in requirement_ids])

    # ─── Core CRUD ───────────────────────────────────────────────

    def insert_or_update(self, strategy: Dict):
        """Insert or update a strategy, including its relations.

        AnalyticDB beam tables do not support ON CONFLICT UPDATE, so this uses
        DELETE + INSERT instead. Missing fields fall back to '' for
        name/description/body, 'draft' for status and 'v0' for version.

        Raises:
            KeyError: if ``strategy`` has no ``'id'``.
        """
        cursor = self._get_cursor()
        try:
            cursor.execute("DELETE FROM strategy WHERE id = %s", (strategy['id'],))
            cursor.execute("""
                INSERT INTO strategy (
                    id, name, description, body, status,
                    created_at, updated_at, embedding, version
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
            """, (
                strategy['id'],
                strategy.get('name', ''),
                strategy.get('description', ''),
                strategy.get('body', ''),
                strategy.get('status', 'draft'),
                strategy.get('created_at'),
                strategy.get('updated_at'),
                strategy.get('embedding'),
                strategy.get('version', 'v0'),
            ))
            self._save_relations(cursor, strategy['id'], strategy)
            self.conn.commit()
        finally:
            cursor.close()

    def get_by_id(self, strategy_id: str) -> Optional[Dict]:
        """Return the strategy with the given id, or ``None`` if no row matches.

        The lookup is additionally restricted by the SQL fragment and
        parameters returned by ``version_where()``.
        """
        cursor = self._get_cursor()
        try:
            vf, vp = version_where()
            cursor.execute(
                f"SELECT {_SELECT_FIELDS} FROM strategy WHERE id = %s AND {vf}",
                (strategy_id, *vp))
            row = cursor.fetchone()
            return self._format_result(row) if row else None
        finally:
            cursor.close()

    def search(self, query_embedding: List[float], limit: int = 10,
               status: Optional[str] = None) -> List[Dict]:
        """Vector search over strategies.

        Rows without an embedding are skipped. Results are ordered by
        ``embedding <=> query`` ascending and carry a ``score`` column equal
        to ``1 - (embedding <=> query)``. A truthy ``status`` adds an equality
        filter on the status column.
        """
        cursor = self._get_cursor()
        try:
            vf, vp = version_where()
            # Placeholder order: score expression, WHERE filters, ORDER BY, LIMIT.
            conditions = ["embedding IS NOT NULL"]
            params = [query_embedding]
            if status:
                conditions.append("status = %s")
                params.append(status)
            conditions.append(vf)
            params.extend(vp)
            params.extend([query_embedding, limit])

            query = f"""
                SELECT {_SELECT_FIELDS},
                       1 - (embedding <=> %s::real[]) as score
                FROM strategy
                WHERE {' AND '.join(conditions)}
                ORDER BY embedding <=> %s::real[]
                LIMIT %s
            """
            cursor.execute(query, tuple(params))
            return [self._format_result(row) for row in cursor.fetchall()]
        finally:
            cursor.close()

    def list_all(self, limit: int = 100, offset: int = 0,
                 status: Optional[str] = None) -> List[Dict]:
        """Return a page of strategies ordered by id, optionally filtered by status."""
        cursor = self._get_cursor()
        try:
            vf, vp = version_where()
            conditions = []
            params = []
            if status:
                conditions.append("status = %s")
                params.append(status)
            conditions.append(vf)
            params.extend(vp)
            params.extend([limit, offset])

            cursor.execute(f"""
                SELECT {_SELECT_FIELDS}
                FROM strategy
                WHERE {' AND '.join(conditions)}
                ORDER BY id
                LIMIT %s OFFSET %s
            """, tuple(params))
            return [self._format_result(row) for row in cursor.fetchall()]
        finally:
            cursor.close()

    def update(self, strategy_id: str, updates: Dict):
        """Update columns and/or relations of an existing strategy.

        Keys listed in ``_RELATION_KEYS`` are routed to ``_save_relations``;
        every other key is treated as a column name of the strategy table.
        The ``updates`` dict passed by the caller is not modified.
        """
        # Split without mutating the caller's dict.
        rel_fields = {k: updates[k] for k in _RELATION_KEYS if k in updates}
        column_updates = {k: v for k, v in updates.items() if k not in _RELATION_KEYS}

        cursor = self._get_cursor()
        try:
            if column_updates:
                # Column names come from dict keys, so compose them as quoted
                # identifiers instead of interpolating raw text into the SQL.
                # NOTE(review): quoting makes names case-sensitive; keys are
                # assumed to be the lowercase column names of the table.
                assignments = sql.SQL(', ').join(
                    sql.SQL("{} = %s").format(sql.Identifier(column))
                    for column in column_updates
                )
                query = sql.SQL("UPDATE strategy SET {} WHERE id = %s").format(assignments)
                cursor.execute(query, [*column_updates.values(), strategy_id])
            if rel_fields:
                self._save_relations(cursor, strategy_id, rel_fields)
            self.conn.commit()
        finally:
            cursor.close()

    def delete(self, strategy_id: str):
        """Delete a strategy and all of its junction rows via ``cascade_delete``."""
        cursor = self._get_cursor()
        try:
            cascade_delete(cursor, 'strategy', strategy_id)
            self.conn.commit()
        finally:
            cursor.close()

    def count(self, status: Optional[str] = None) -> int:
        """Return the number of strategies, optionally filtered by status."""
        cursor = self._get_cursor()
        try:
            vf, vp = version_where()
            conditions = []
            params = []
            if status:
                conditions.append("status = %s")
                params.append(status)
            conditions.append(vf)
            params.extend(vp)

            cursor.execute(
                f"SELECT COUNT(*) as count FROM strategy WHERE {' AND '.join(conditions)}",
                tuple(params))
            return cursor.fetchone()['count']
        finally:
            cursor.close()

    # ─── Incremental relation API (keeps existing rows) ──────────

    def add_capability(self, strategy_id: str, capability_id: str,
                       relation_type: str = 'compose'):
        """Attach a capability to a strategy; an existing conflicting row is left as is."""
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO strategy_capability (strategy_id, capability_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (strategy_id, capability_id, relation_type))
            self.conn.commit()
        finally:
            cursor.close()

    def add_knowledge(self, strategy_id: str, knowledge_id: str,
                      relation_type: str = 'source'):
        """Attach a knowledge entry to a strategy; an existing conflicting row is left as is."""
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO strategy_knowledge (strategy_id, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (strategy_id, knowledge_id, relation_type))
            self.conn.commit()
        finally:
            cursor.close()

    def add_resource(self, strategy_id: str, resource_id: str):
        """Attach a resource to a strategy; an existing conflicting row is left as is."""
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO strategy_resource (strategy_id, resource_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (strategy_id, resource_id))
            self.conn.commit()
        finally:
            cursor.close()

    def add_requirement(self, strategy_id: str, requirement_id: str):
        """Incrementally attach a requirement-strategy edge.

        The edge records that this strategy satisfies the requirement.
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (requirement_id, strategy_id))
            self.conn.commit()
        finally:
            cursor.close()

    # ─── Helpers ─────────────────────────────────────────────────

    def _format_result(self, row: Dict) -> Optional[Dict]:
        """Copy a row into a plain dict and normalize the aggregated list fields.

        For each relation field that is present: JSON text is decoded, and
        ``None`` becomes an empty list. Other values are left unchanged.
        Returns ``None`` for a falsy ``row``.
        """
        if not row:
            return None
        result = dict(row)
        for field in _JSON_LIST_FIELDS:
            if field not in result:
                continue
            value = result[field]
            if isinstance(value, str):
                result[field] = json.loads(value)
            elif value is None:
                result[field] = []
        return result

    def close(self):
        """Close the database connection if one is set."""
        if self.conn:
            self.conn.close()