pg_strategy_store.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
"""
PostgreSQL storage wrapper for strategies.

Stores and retrieves "production strategies". A strategy is a composition of
atomic capabilities and carries its own ``body`` (an executable description)
together with the knowledge it was derived from.

Relations (junction tables):
- strategy_capability  (default relation_type='compose')
- strategy_knowledge   (default relation_type='source'; may also be 'case', etc.)
- strategy_resource    (directly used materials; no relation type)
- requirement_strategy (requirements this strategy satisfies; no relation type)
"""
import json
import os
from contextlib import contextmanager
from typing import List, Dict, Optional

import psycopg2
from dotenv import load_dotenv
from psycopg2.extras import RealDictCursor

from knowhub.knowhub_db.cascade import cascade_delete
from knowhub.knowhub_db.version_context import version_where
# Load variables from a .env file into the process environment;
# PostgreSQLStrategyStore reads its KNOWHUB_* connection settings via os.getenv.
load_dotenv()
  18. # 读取路径:同时暴露扁平 ids 和带 type 的 links
  19. _REL_SUBQUERIES = """
  20. (SELECT COALESCE(json_agg(rs.requirement_id), '[]'::json)
  21. FROM requirement_strategy rs WHERE rs.strategy_id = strategy.id) AS requirement_ids,
  22. (SELECT COALESCE(json_agg(sc.capability_id), '[]'::json)
  23. FROM strategy_capability sc WHERE sc.strategy_id = strategy.id) AS capability_ids,
  24. (SELECT COALESCE(json_agg(json_build_object(
  25. 'id', sc2.capability_id, 'relation_type', sc2.relation_type
  26. )), '[]'::json)
  27. FROM strategy_capability sc2 WHERE sc2.strategy_id = strategy.id) AS capability_links,
  28. (SELECT COALESCE(json_agg(sk.knowledge_id), '[]'::json)
  29. FROM strategy_knowledge sk WHERE sk.strategy_id = strategy.id) AS knowledge_ids,
  30. (SELECT COALESCE(json_agg(json_build_object(
  31. 'id', sk2.knowledge_id, 'relation_type', sk2.relation_type
  32. )), '[]'::json)
  33. FROM strategy_knowledge sk2 WHERE sk2.strategy_id = strategy.id) AS knowledge_links,
  34. (SELECT COALESCE(json_agg(sr.resource_id), '[]'::json)
  35. FROM strategy_resource sr WHERE sr.strategy_id = strategy.id) AS resource_ids
  36. """
  37. _BASE_FIELDS = "id, name, description, body, status, created_at, updated_at, version"
  38. _SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
  39. class PostgreSQLStrategyStore:
  40. def __init__(self):
  41. self.conn = psycopg2.connect(
  42. host=os.getenv('KNOWHUB_DB'),
  43. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  44. user=os.getenv('KNOWHUB_USER'),
  45. password=os.getenv('KNOWHUB_PASSWORD'),
  46. database=os.getenv('KNOWHUB_DB_NAME')
  47. )
  48. self.conn.autocommit = True
  49. print(f"[PostgreSQL Strategy] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  50. def _reconnect(self):
  51. self.conn = psycopg2.connect(
  52. host=os.getenv('KNOWHUB_DB'),
  53. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  54. user=os.getenv('KNOWHUB_USER'),
  55. password=os.getenv('KNOWHUB_PASSWORD'),
  56. database=os.getenv('KNOWHUB_DB_NAME')
  57. )
  58. self.conn.autocommit = True
  59. def _ensure_connection(self):
  60. if self.conn.closed != 0:
  61. self._reconnect()
  62. else:
  63. try:
  64. c = self.conn.cursor()
  65. c.execute("SELECT 1")
  66. c.close()
  67. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  68. self._reconnect()
  69. def _get_cursor(self):
  70. self._ensure_connection()
  71. return self.conn.cursor(cursor_factory=RealDictCursor)
  72. # ─── 关联写入 ────────────────────────────────────────────────
  73. @staticmethod
  74. def _normalize_links(data: Dict, links_key: str, ids_key: str, default_type: str):
  75. """
  76. 统一两种输入:
  77. - {links_key: [{id, relation_type}, ...]} → 使用给定 type
  78. - {ids_key: [id1, id2, ...]} → 使用 default_type
  79. 返回 [(id, relation_type), ...];若两个 key 都不存在返回 None(表示不更新)
  80. """
  81. if links_key in data and data[links_key] is not None:
  82. out = []
  83. for item in data[links_key]:
  84. if isinstance(item, dict):
  85. out.append((item['id'], item.get('relation_type', default_type)))
  86. else: # 容错:允许混用
  87. out.append((item, default_type))
  88. return out
  89. if ids_key in data and data[ids_key] is not None:
  90. return [(i, default_type) for i in data[ids_key]]
  91. return None
  92. def _save_relations(self, cursor, strategy_id: str, data: Dict):
  93. """全量替换 strategy 的 junction"""
  94. cap_links = self._normalize_links(data, 'capability_links', 'capability_ids', 'compose')
  95. if cap_links is not None:
  96. cursor.execute("DELETE FROM strategy_capability WHERE strategy_id = %s", (strategy_id,))
  97. for cap_id, rtype in cap_links:
  98. cursor.execute(
  99. "INSERT INTO strategy_capability (strategy_id, capability_id, relation_type) "
  100. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  101. (strategy_id, cap_id, rtype))
  102. k_links = self._normalize_links(data, 'knowledge_links', 'knowledge_ids', 'source')
  103. if k_links is not None:
  104. cursor.execute("DELETE FROM strategy_knowledge WHERE strategy_id = %s", (strategy_id,))
  105. for kid, rtype in k_links:
  106. cursor.execute(
  107. "INSERT INTO strategy_knowledge (strategy_id, knowledge_id, relation_type) "
  108. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  109. (strategy_id, kid, rtype))
  110. if 'resource_ids' in data and data['resource_ids'] is not None:
  111. cursor.execute("DELETE FROM strategy_resource WHERE strategy_id = %s", (strategy_id,))
  112. for rid in data['resource_ids']:
  113. cursor.execute(
  114. "INSERT INTO strategy_resource (strategy_id, resource_id) "
  115. "VALUES (%s, %s) ON CONFLICT DO NOTHING",
  116. (strategy_id, rid))
  117. if 'requirement_ids' in data and data['requirement_ids'] is not None:
  118. cursor.execute("DELETE FROM requirement_strategy WHERE strategy_id = %s", (strategy_id,))
  119. for req_id in data['requirement_ids']:
  120. cursor.execute(
  121. "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
  122. "VALUES (%s, %s) ON CONFLICT DO NOTHING",
  123. (req_id, strategy_id))
  124. # ─── 核心 CRUD ───────────────────────────────────────────────
  125. def insert_or_update(self, strategy: Dict):
  126. """插入或更新 strategy(含关联)。AnalyticDB beam 表不支持 ON CONFLICT UPDATE,改用 DELETE+INSERT。"""
  127. cursor = self._get_cursor()
  128. try:
  129. cursor.execute("DELETE FROM strategy WHERE id = %s", (strategy['id'],))
  130. cursor.execute("""
  131. INSERT INTO strategy (
  132. id, name, description, body, status, created_at, updated_at, embedding, version
  133. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
  134. """, (
  135. strategy['id'],
  136. strategy.get('name', ''),
  137. strategy.get('description', ''),
  138. strategy.get('body', ''),
  139. strategy.get('status', 'draft'),
  140. strategy.get('created_at'),
  141. strategy.get('updated_at'),
  142. strategy.get('embedding'),
  143. strategy.get('version', 'v0'),
  144. ))
  145. self._save_relations(cursor, strategy['id'], strategy)
  146. self.conn.commit()
  147. finally:
  148. cursor.close()
  149. def get_by_id(self, strategy_id: str) -> Optional[Dict]:
  150. cursor = self._get_cursor()
  151. try:
  152. vf, vp = version_where()
  153. cursor.execute(f"SELECT {_SELECT_FIELDS} FROM strategy WHERE id = %s AND {vf}",
  154. (strategy_id, *vp))
  155. result = cursor.fetchone()
  156. return self._format_result(result) if result else None
  157. finally:
  158. cursor.close()
  159. def search(self, query_embedding: List[float], limit: int = 10,
  160. status: Optional[str] = None) -> List[Dict]:
  161. """向量检索 strategy"""
  162. cursor = self._get_cursor()
  163. try:
  164. vf, vp = version_where()
  165. if status:
  166. sql = f"""
  167. SELECT {_SELECT_FIELDS},
  168. 1 - (embedding <=> %s::real[]) as score
  169. FROM strategy
  170. WHERE embedding IS NOT NULL AND status = %s AND {vf}
  171. ORDER BY embedding <=> %s::real[]
  172. LIMIT %s
  173. """
  174. params = (query_embedding, status, *vp, query_embedding, limit)
  175. else:
  176. sql = f"""
  177. SELECT {_SELECT_FIELDS},
  178. 1 - (embedding <=> %s::real[]) as score
  179. FROM strategy
  180. WHERE embedding IS NOT NULL AND {vf}
  181. ORDER BY embedding <=> %s::real[]
  182. LIMIT %s
  183. """
  184. params = (query_embedding, *vp, query_embedding, limit)
  185. cursor.execute(sql, params)
  186. results = cursor.fetchall()
  187. return [self._format_result(r) for r in results]
  188. finally:
  189. cursor.close()
  190. def list_all(self, limit: int = 100, offset: int = 0,
  191. status: Optional[str] = None) -> List[Dict]:
  192. cursor = self._get_cursor()
  193. try:
  194. vf, vp = version_where()
  195. if status:
  196. cursor.execute(f"""
  197. SELECT {_SELECT_FIELDS} FROM strategy
  198. WHERE status = %s AND {vf}
  199. ORDER BY id
  200. LIMIT %s OFFSET %s
  201. """, (status, *vp, limit, offset))
  202. else:
  203. cursor.execute(f"""
  204. SELECT {_SELECT_FIELDS} FROM strategy
  205. WHERE {vf}
  206. ORDER BY id
  207. LIMIT %s OFFSET %s
  208. """, (*vp, limit, offset))
  209. results = cursor.fetchall()
  210. return [self._format_result(r) for r in results]
  211. finally:
  212. cursor.close()
  213. def update(self, strategy_id: str, updates: Dict):
  214. """更新 strategy(关联字段可选)"""
  215. cursor = self._get_cursor()
  216. try:
  217. # 分离关联字段
  218. rel_keys = ('requirement_ids',
  219. 'capability_ids', 'capability_links',
  220. 'knowledge_ids', 'knowledge_links', 'resource_ids')
  221. rel_fields = {k: updates.pop(k) for k in rel_keys if k in updates}
  222. if updates:
  223. set_parts = []
  224. params = []
  225. for key, value in updates.items():
  226. set_parts.append(f"{key} = %s")
  227. params.append(value)
  228. params.append(strategy_id)
  229. cursor.execute(
  230. f"UPDATE strategy SET {', '.join(set_parts)} WHERE id = %s",
  231. params)
  232. if rel_fields:
  233. self._save_relations(cursor, strategy_id, rel_fields)
  234. self.conn.commit()
  235. finally:
  236. cursor.close()
  237. def delete(self, strategy_id: str):
  238. """删除 strategy 及其所有 junction 行"""
  239. cursor = self._get_cursor()
  240. try:
  241. cascade_delete(cursor, 'strategy', strategy_id)
  242. self.conn.commit()
  243. finally:
  244. cursor.close()
  245. def count(self, status: Optional[str] = None) -> int:
  246. cursor = self._get_cursor()
  247. try:
  248. vf, vp = version_where()
  249. if status:
  250. cursor.execute(f"SELECT COUNT(*) as count FROM strategy WHERE status = %s AND {vf}",
  251. (status, *vp))
  252. else:
  253. cursor.execute(f"SELECT COUNT(*) as count FROM strategy WHERE {vf}", vp)
  254. return cursor.fetchone()['count']
  255. finally:
  256. cursor.close()
  257. # ─── 增量关联 API(不删已有)─────────────────────────────────
  258. def add_capability(self, strategy_id: str, capability_id: str,
  259. relation_type: str = 'compose'):
  260. cursor = self._get_cursor()
  261. try:
  262. cursor.execute(
  263. "INSERT INTO strategy_capability (strategy_id, capability_id, relation_type) "
  264. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  265. (strategy_id, capability_id, relation_type))
  266. self.conn.commit()
  267. finally:
  268. cursor.close()
  269. def add_knowledge(self, strategy_id: str, knowledge_id: str,
  270. relation_type: str = 'source'):
  271. cursor = self._get_cursor()
  272. try:
  273. cursor.execute(
  274. "INSERT INTO strategy_knowledge (strategy_id, knowledge_id, relation_type) "
  275. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  276. (strategy_id, knowledge_id, relation_type))
  277. self.conn.commit()
  278. finally:
  279. cursor.close()
  280. def add_resource(self, strategy_id: str, resource_id: str):
  281. cursor = self._get_cursor()
  282. try:
  283. cursor.execute(
  284. "INSERT INTO strategy_resource (strategy_id, resource_id) "
  285. "VALUES (%s, %s) ON CONFLICT DO NOTHING",
  286. (strategy_id, resource_id))
  287. self.conn.commit()
  288. finally:
  289. cursor.close()
  290. def add_requirement(self, strategy_id: str, requirement_id: str):
  291. """增量挂接 requirement-strategy 边(这个 strategy 满足该 requirement)"""
  292. cursor = self._get_cursor()
  293. try:
  294. cursor.execute(
  295. "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
  296. "VALUES (%s, %s) ON CONFLICT DO NOTHING",
  297. (requirement_id, strategy_id))
  298. self.conn.commit()
  299. finally:
  300. cursor.close()
  301. # ─── 辅助 ────────────────────────────────────────────────────
  302. def _format_result(self, row: Dict) -> Optional[Dict]:
  303. if not row:
  304. return None
  305. import json
  306. result = dict(row)
  307. for field in ('requirement_ids', 'capability_ids', 'knowledge_ids', 'resource_ids'):
  308. if field in result and isinstance(result[field], str):
  309. result[field] = json.loads(result[field])
  310. elif field in result and result[field] is None:
  311. result[field] = []
  312. for field in ('capability_links', 'knowledge_links'):
  313. if field in result and isinstance(result[field], str):
  314. result[field] = json.loads(result[field])
  315. elif field in result and result[field] is None:
  316. result[field] = []
  317. return result
  318. def close(self):
  319. if self.conn:
  320. self.conn.close()