| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353 |
- """
- PostgreSQL strategy 存储封装
- 用于存储和检索「制作策略」。strategy 是一组原子 capability 的组合,
- 附带自身的 body(可执行描述)与 source 知识。
- 关联:
- - strategy_capability(默认 relation_type='compose')
- - strategy_knowledge(默认 relation_type='source',也可为 'case' 等)
- - strategy_resource(直接素材,无 type)
- """
import json
import os
from contextlib import contextmanager
from typing import Dict, List, Optional

import psycopg2
from dotenv import load_dotenv
from psycopg2.extras import RealDictCursor

from knowhub.knowhub_db.cascade import cascade_delete
# Populate os.environ from a .env file (python-dotenv) so the KNOWHUB_*
# connection settings read by the store are available.
load_dotenv()
# Read path: expose both flat id lists and typed links (id + relation_type).
# Each correlated subquery aggregates one junction table into a JSON array,
# falling back to '[]' when there are no rows. The subqueries reference the
# outer table by its bare name `strategy`, so any query using these fields must
# select FROM strategy without an alias.
_REL_SUBQUERIES = """
(SELECT COALESCE(json_agg(rs.requirement_id), '[]'::json)
FROM requirement_strategy rs WHERE rs.strategy_id = strategy.id) AS requirement_ids,
(SELECT COALESCE(json_agg(sc.capability_id), '[]'::json)
FROM strategy_capability sc WHERE sc.strategy_id = strategy.id) AS capability_ids,
(SELECT COALESCE(json_agg(json_build_object(
'id', sc2.capability_id, 'relation_type', sc2.relation_type
)), '[]'::json)
FROM strategy_capability sc2 WHERE sc2.strategy_id = strategy.id) AS capability_links,
(SELECT COALESCE(json_agg(sk.knowledge_id), '[]'::json)
FROM strategy_knowledge sk WHERE sk.strategy_id = strategy.id) AS knowledge_ids,
(SELECT COALESCE(json_agg(json_build_object(
'id', sk2.knowledge_id, 'relation_type', sk2.relation_type
)), '[]'::json)
FROM strategy_knowledge sk2 WHERE sk2.strategy_id = strategy.id) AS knowledge_links,
(SELECT COALESCE(json_agg(sr.resource_id), '[]'::json)
FROM strategy_resource sr WHERE sr.strategy_id = strategy.id) AS resource_ids
"""
# Scalar columns returned by reads (the embedding column is deliberately omitted).
_BASE_FIELDS = "id, name, description, body, status, created_at, updated_at"
# Full SELECT list: scalar columns followed by the relation aggregates.
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
class PostgreSQLStrategyStore:
    """PostgreSQL-backed store for strategies and their junction-table relations.

    The connection runs in autocommit mode for reads. Every write method wraps
    its statements in an explicit transaction (``_transaction``). A
    multi-statement write (row upsert + relation replacement, cascade delete)
    therefore either commits completely or rolls back completely.
    """

    # Keys of an ``update`` payload that describe relations rather than columns.
    _REL_KEYS = ('requirement_ids',
                 'capability_ids', 'capability_links',
                 'knowledge_ids', 'knowledge_links', 'resource_ids')

    def __init__(self):
        self.conn = None
        self._reconnect()
        print(f"[PostgreSQL Strategy] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    # ─── Connection management ───────────────────────────────────

    def _reconnect(self):
        """Open a fresh autocommit connection configured from KNOWHUB_* env vars.

        If a previous connection is still marked open, it is closed first so
        that replacing a dead connection does not leak its socket.
        """
        stale = getattr(self, 'conn', None)
        if stale is not None and not stale.closed:
            try:
                stale.close()
            except psycopg2.Error:
                pass  # best effort: the old connection is discarded regardless
        self.conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME'),
        )
        self.conn.autocommit = True

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` ping."""
        if self.conn is None or self.conn.closed != 0:
            self._reconnect()
            return
        try:
            # The cursor context manager closes the ping cursor even if execute fails.
            with self.conn.cursor() as ping:
                ping.execute("SELECT 1")
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a RealDictCursor on a live connection, reconnecting if needed."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    @contextmanager
    def _transaction(self):
        """Yield a cursor whose statements all commit together or all roll back.

        The connection is normally in autocommit mode, where each statement
        commits on its own. Autocommit is switched off for the duration of the
        block and restored afterwards. This means, for example, that the DELETE
        half of a relation replacement is never committed without its INSERTs.
        """
        cursor = self._get_cursor()
        conn = self.conn
        try:
            conn.autocommit = False
            yield cursor
            conn.commit()
        except BaseException:
            # Roll back on any failure (including KeyboardInterrupt): leaving the
            # transaction open would also make restoring autocommit fail below.
            if not conn.closed:
                conn.rollback()
            raise
        finally:
            cursor.close()
            if not conn.closed:
                conn.autocommit = True

    # ─── Relation writes ─────────────────────────────────────────

    @staticmethod
    def _normalize_links(data: Dict, links_key: str, ids_key: str, default_type: str):
        """Normalize the two accepted relation input shapes.

        - ``{links_key: [{id, relation_type}, ...]}`` uses each item's type
          (items without one, or bare ids mixed in, get *default_type*).
        - ``{ids_key: [id1, id2, ...]}`` gives every id *default_type*.

        *links_key* takes precedence over *ids_key*.

        Returns:
            A list of ``(id, relation_type)`` tuples, or None when neither key
            is present with a non-None value (meaning "leave this relation as is").
        """
        links = data.get(links_key)
        if links is not None:
            return [
                (item['id'], item.get('relation_type', default_type))
                if isinstance(item, dict) else (item, default_type)
                for item in links
            ]
        ids = data.get(ids_key)
        if ids is not None:
            return [(i, default_type) for i in ids]
        return None

    @staticmethod
    def _replace_rows(cursor, table: str, strategy_id: str, insert_sql: str, rows):
        """Delete every *table* row of the strategy, then insert *rows* via *insert_sql*.

        *table* is always one of this class's fixed junction-table names, never
        caller input.
        """
        cursor.execute(f"DELETE FROM {table} WHERE strategy_id = %s", (strategy_id,))
        for row in rows:
            cursor.execute(insert_sql, row)

    def _save_relations(self, cursor, strategy_id: str, data: Dict):
        """Fully replace the junction rows of each relation present in *data*.

        A relation whose key(s) are absent or None is left untouched. An empty
        list clears that relation. Duplicate ids are absorbed by
        ``ON CONFLICT DO NOTHING``. Runs on the caller's cursor, so atomicity
        comes from the caller's transaction.
        """
        cap_links = self._normalize_links(data, 'capability_links', 'capability_ids', 'compose')
        if cap_links is not None:
            self._replace_rows(
                cursor, 'strategy_capability', strategy_id,
                "INSERT INTO strategy_capability (strategy_id, capability_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                [(strategy_id, cap_id, rtype) for cap_id, rtype in cap_links])

        k_links = self._normalize_links(data, 'knowledge_links', 'knowledge_ids', 'source')
        if k_links is not None:
            self._replace_rows(
                cursor, 'strategy_knowledge', strategy_id,
                "INSERT INTO strategy_knowledge (strategy_id, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                [(strategy_id, kid, rtype) for kid, rtype in k_links])

        resource_ids = data.get('resource_ids')
        if resource_ids is not None:
            self._replace_rows(
                cursor, 'strategy_resource', strategy_id,
                "INSERT INTO strategy_resource (strategy_id, resource_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                [(strategy_id, rid) for rid in resource_ids])

        requirement_ids = data.get('requirement_ids')
        if requirement_ids is not None:
            # requirement_strategy stores the requirement first.
            self._replace_rows(
                cursor, 'requirement_strategy', strategy_id,
                "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                [(req_id, strategy_id) for req_id in requirement_ids])

    # ─── Core CRUD ───────────────────────────────────────────────

    def insert_or_update(self, strategy: Dict):
        """Upsert a strategy row by id and replace any relations it carries, atomically.

        Args:
            strategy: must contain 'id'. name/description/body default to '',
                status to 'draft', and created_at/updated_at/embedding to None.
                Relation keys are applied by ``_save_relations``.

        On an id conflict, every column except created_at is overwritten.
        """
        with self._transaction() as cursor:
            cursor.execute("""
                INSERT INTO strategy (
                    id, name, description, body, status, created_at, updated_at, embedding
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (id) DO UPDATE SET
                    name = EXCLUDED.name,
                    description = EXCLUDED.description,
                    body = EXCLUDED.body,
                    status = EXCLUDED.status,
                    updated_at = EXCLUDED.updated_at,
                    embedding = EXCLUDED.embedding
            """, (
                strategy['id'],
                strategy.get('name', ''),
                strategy.get('description', ''),
                strategy.get('body', ''),
                strategy.get('status', 'draft'),
                strategy.get('created_at'),
                strategy.get('updated_at'),
                strategy.get('embedding'),
            ))
            self._save_relations(cursor, strategy['id'], strategy)

    def get_by_id(self, strategy_id: str) -> Optional[Dict]:
        """Fetch one strategy with its relation ids/links, or None if it does not exist."""
        with self._get_cursor() as cursor:
            cursor.execute(f"SELECT {_SELECT_FIELDS} FROM strategy WHERE id = %s", (strategy_id,))
            return self._format_result(cursor.fetchone())

    def search(self, query_embedding: List[float], limit: int = 10,
               status: Optional[str] = None) -> List[Dict]:
        """Vector search over strategies that have a stored embedding.

        Rows are ordered by ``embedding <=> query`` ascending and carry
        ``score = 1 - (embedding <=> query)``. The operator is presumably a
        pgvector distance, so the nearest rows come first; TODO confirm.
        When *status* is truthy, only strategies with that status are searched.
        """
        where = "embedding IS NOT NULL"
        params = [query_embedding]
        if status:
            where += " AND status = %s"
            params.append(status)
        params += [query_embedding, limit]
        sql = f"""
            SELECT {_SELECT_FIELDS},
                   1 - (embedding <=> %s::real[]) as score
            FROM strategy
            WHERE {where}
            ORDER BY embedding <=> %s::real[]
            LIMIT %s
        """
        with self._get_cursor() as cursor:
            cursor.execute(sql, params)
            return [self._format_result(r) for r in cursor.fetchall()]

    def list_all(self, limit: int = 100, offset: int = 0,
                 status: Optional[str] = None) -> List[Dict]:
        """Page through strategies ordered by id, optionally filtered by a truthy *status*."""
        where, params = ("WHERE status = %s", [status]) if status else ("", [])
        params += [limit, offset]
        with self._get_cursor() as cursor:
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS} FROM strategy
                {where}
                ORDER BY id
                LIMIT %s OFFSET %s
            """, params)
            return [self._format_result(r) for r in cursor.fetchall()]

    @staticmethod
    def _quote_ident(name) -> str:
        """Quote *name* as a PostgreSQL identifier for use in a psycopg2 query string.

        Embedded double quotes are doubled (SQL quoted-identifier escaping). '%'
        is doubled so psycopg2 does not treat it as a placeholder.
        """
        return '"' + str(name).replace('"', '""').replace('%', '%%') + '"'

    def update(self, strategy_id: str, updates: Dict):
        """Update selected columns and/or relations of one strategy, atomically.

        Keys listed in ``_REL_KEYS`` replace the corresponding relations (see
        ``_save_relations``). Every other key is written to the column of the
        same name. Column names are quoted, so a key cannot inject SQL; an
        unknown column still fails in the database. The caller's *updates*
        dict is not modified.
        """
        columns = {k: v for k, v in updates.items() if k not in self._REL_KEYS}
        rel_fields = {k: v for k, v in updates.items() if k in self._REL_KEYS}
        with self._transaction() as cursor:
            if columns:
                assignments = ', '.join(f"{self._quote_ident(k)} = %s" for k in columns)
                cursor.execute(
                    f"UPDATE strategy SET {assignments} WHERE id = %s",
                    [*columns.values(), strategy_id])
            if rel_fields:
                self._save_relations(cursor, strategy_id, rel_fields)

    def delete(self, strategy_id: str):
        """Delete a strategy and its dependent rows via ``cascade_delete``, in one transaction."""
        with self._transaction() as cursor:
            cascade_delete(cursor, 'strategy', strategy_id)

    def count(self, status: Optional[str] = None) -> int:
        """Return the number of strategies, optionally only those with a truthy *status*."""
        with self._get_cursor() as cursor:
            if status:
                cursor.execute("SELECT COUNT(*) as count FROM strategy WHERE status = %s", (status,))
            else:
                cursor.execute("SELECT COUNT(*) as count FROM strategy")
            return cursor.fetchone()['count']

    # ─── Incremental relation API (existing rows are kept) ───────

    def add_capability(self, strategy_id: str, capability_id: str,
                       relation_type: str = 'compose'):
        """Add one strategy→capability edge; adding an existing edge is a no-op."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO strategy_capability (strategy_id, capability_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (strategy_id, capability_id, relation_type))

    def add_knowledge(self, strategy_id: str, knowledge_id: str,
                      relation_type: str = 'source'):
        """Add one strategy→knowledge edge; adding an existing edge is a no-op."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO strategy_knowledge (strategy_id, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (strategy_id, knowledge_id, relation_type))

    def add_resource(self, strategy_id: str, resource_id: str):
        """Add one strategy→resource edge; adding an existing edge is a no-op."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO strategy_resource (strategy_id, resource_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (strategy_id, resource_id))

    def add_requirement(self, strategy_id: str, requirement_id: str):
        """Add one requirement→strategy edge (this strategy satisfies that requirement)."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (requirement_id, strategy_id))

    # ─── Helpers ─────────────────────────────────────────────────

    def _format_result(self, row: Dict) -> Optional[Dict]:
        """Convert a DB row into a plain dict whose relation fields are lists.

        Relation fields that arrive as JSON text are decoded, and NULL becomes [].
        psycopg2 usually returns json columns already decoded, so this is a
        fallback. Returns None for an empty or missing row.
        """
        if not row:
            return None
        result = dict(row)
        for field in ('requirement_ids', 'capability_ids', 'knowledge_ids',
                      'resource_ids', 'capability_links', 'knowledge_links'):
            if field not in result:
                continue
            value = result[field]
            if isinstance(value, str):
                result[field] = json.loads(value)
            elif value is None:
                result[field] = []
        return result

    def close(self):
        """Close the underlying connection, if one was opened."""
        if self.conn:
            self.conn.close()
|