pg_strategy_store.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. """
  2. PostgreSQL strategy 存储封装
  3. 用于存储和检索「制作策略」。strategy 是一组原子 capability 的组合,
  4. 附带自身的 body(可执行描述)与 source 知识。
  5. 关联:
  6. - strategy_capability(默认 relation_type='compose')
  7. - strategy_knowledge(默认 relation_type='source',也可为 'case' 等)
  8. - strategy_resource(直接素材,无 type)
  9. """
  10. import os
  11. import psycopg2
  12. from psycopg2.extras import RealDictCursor
  13. from typing import List, Dict, Optional
  14. from dotenv import load_dotenv
  15. from knowhub.knowhub_db.cascade import cascade_delete
  16. load_dotenv()
# Read path: every strategy row exposes both the flat id lists and the typed
# link lists. Each correlated subquery aggregates one junction table for the
# current `strategy` row and falls back to an empty JSON array ('[]') when the
# strategy has no rows in that table.
_REL_SUBQUERIES = """
(SELECT COALESCE(json_agg(rs.requirement_id), '[]'::json)
FROM requirement_strategy rs WHERE rs.strategy_id = strategy.id) AS requirement_ids,
(SELECT COALESCE(json_agg(sc.capability_id), '[]'::json)
FROM strategy_capability sc WHERE sc.strategy_id = strategy.id) AS capability_ids,
(SELECT COALESCE(json_agg(json_build_object(
'id', sc2.capability_id, 'relation_type', sc2.relation_type
)), '[]'::json)
FROM strategy_capability sc2 WHERE sc2.strategy_id = strategy.id) AS capability_links,
(SELECT COALESCE(json_agg(sk.knowledge_id), '[]'::json)
FROM strategy_knowledge sk WHERE sk.strategy_id = strategy.id) AS knowledge_ids,
(SELECT COALESCE(json_agg(json_build_object(
'id', sk2.knowledge_id, 'relation_type', sk2.relation_type
)), '[]'::json)
FROM strategy_knowledge sk2 WHERE sk2.strategy_id = strategy.id) AS knowledge_links,
(SELECT COALESCE(json_agg(sr.resource_id), '[]'::json)
FROM strategy_resource sr WHERE sr.strategy_id = strategy.id) AS resource_ids
"""
# Plain columns returned for a strategy row; `embedding` is not part of this list.
_BASE_FIELDS = "id, name, description, body, status, created_at, updated_at"
# Complete SELECT list: the base columns followed by the relation subqueries.
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
  38. class PostgreSQLStrategyStore:
  39. def __init__(self):
  40. self.conn = psycopg2.connect(
  41. host=os.getenv('KNOWHUB_DB'),
  42. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  43. user=os.getenv('KNOWHUB_USER'),
  44. password=os.getenv('KNOWHUB_PASSWORD'),
  45. database=os.getenv('KNOWHUB_DB_NAME')
  46. )
  47. self.conn.autocommit = True
  48. print(f"[PostgreSQL Strategy] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  49. def _reconnect(self):
  50. self.conn = psycopg2.connect(
  51. host=os.getenv('KNOWHUB_DB'),
  52. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  53. user=os.getenv('KNOWHUB_USER'),
  54. password=os.getenv('KNOWHUB_PASSWORD'),
  55. database=os.getenv('KNOWHUB_DB_NAME')
  56. )
  57. self.conn.autocommit = True
  58. def _ensure_connection(self):
  59. if self.conn.closed != 0:
  60. self._reconnect()
  61. else:
  62. try:
  63. c = self.conn.cursor()
  64. c.execute("SELECT 1")
  65. c.close()
  66. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  67. self._reconnect()
  68. def _get_cursor(self):
  69. self._ensure_connection()
  70. return self.conn.cursor(cursor_factory=RealDictCursor)
  71. # ─── 关联写入 ────────────────────────────────────────────────
  72. @staticmethod
  73. def _normalize_links(data: Dict, links_key: str, ids_key: str, default_type: str):
  74. """
  75. 统一两种输入:
  76. - {links_key: [{id, relation_type}, ...]} → 使用给定 type
  77. - {ids_key: [id1, id2, ...]} → 使用 default_type
  78. 返回 [(id, relation_type), ...];若两个 key 都不存在返回 None(表示不更新)
  79. """
  80. if links_key in data and data[links_key] is not None:
  81. out = []
  82. for item in data[links_key]:
  83. if isinstance(item, dict):
  84. out.append((item['id'], item.get('relation_type', default_type)))
  85. else: # 容错:允许混用
  86. out.append((item, default_type))
  87. return out
  88. if ids_key in data and data[ids_key] is not None:
  89. return [(i, default_type) for i in data[ids_key]]
  90. return None
  91. def _save_relations(self, cursor, strategy_id: str, data: Dict):
  92. """全量替换 strategy 的 junction"""
  93. cap_links = self._normalize_links(data, 'capability_links', 'capability_ids', 'compose')
  94. if cap_links is not None:
  95. cursor.execute("DELETE FROM strategy_capability WHERE strategy_id = %s", (strategy_id,))
  96. for cap_id, rtype in cap_links:
  97. cursor.execute(
  98. "INSERT INTO strategy_capability (strategy_id, capability_id, relation_type) "
  99. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  100. (strategy_id, cap_id, rtype))
  101. k_links = self._normalize_links(data, 'knowledge_links', 'knowledge_ids', 'source')
  102. if k_links is not None:
  103. cursor.execute("DELETE FROM strategy_knowledge WHERE strategy_id = %s", (strategy_id,))
  104. for kid, rtype in k_links:
  105. cursor.execute(
  106. "INSERT INTO strategy_knowledge (strategy_id, knowledge_id, relation_type) "
  107. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  108. (strategy_id, kid, rtype))
  109. if 'resource_ids' in data and data['resource_ids'] is not None:
  110. cursor.execute("DELETE FROM strategy_resource WHERE strategy_id = %s", (strategy_id,))
  111. for rid in data['resource_ids']:
  112. cursor.execute(
  113. "INSERT INTO strategy_resource (strategy_id, resource_id) "
  114. "VALUES (%s, %s) ON CONFLICT DO NOTHING",
  115. (strategy_id, rid))
  116. if 'requirement_ids' in data and data['requirement_ids'] is not None:
  117. cursor.execute("DELETE FROM requirement_strategy WHERE strategy_id = %s", (strategy_id,))
  118. for req_id in data['requirement_ids']:
  119. cursor.execute(
  120. "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
  121. "VALUES (%s, %s) ON CONFLICT DO NOTHING",
  122. (req_id, strategy_id))
  123. # ─── 核心 CRUD ───────────────────────────────────────────────
  124. def insert_or_update(self, strategy: Dict):
  125. """插入或更新 strategy(含关联)"""
  126. cursor = self._get_cursor()
  127. try:
  128. cursor.execute("""
  129. INSERT INTO strategy (
  130. id, name, description, body, status, created_at, updated_at, embedding
  131. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
  132. ON CONFLICT (id) DO UPDATE SET
  133. name = EXCLUDED.name,
  134. description = EXCLUDED.description,
  135. body = EXCLUDED.body,
  136. status = EXCLUDED.status,
  137. updated_at = EXCLUDED.updated_at,
  138. embedding = EXCLUDED.embedding
  139. """, (
  140. strategy['id'],
  141. strategy.get('name', ''),
  142. strategy.get('description', ''),
  143. strategy.get('body', ''),
  144. strategy.get('status', 'draft'),
  145. strategy.get('created_at'),
  146. strategy.get('updated_at'),
  147. strategy.get('embedding'),
  148. ))
  149. self._save_relations(cursor, strategy['id'], strategy)
  150. self.conn.commit()
  151. finally:
  152. cursor.close()
  153. def get_by_id(self, strategy_id: str) -> Optional[Dict]:
  154. cursor = self._get_cursor()
  155. try:
  156. cursor.execute(f"SELECT {_SELECT_FIELDS} FROM strategy WHERE id = %s", (strategy_id,))
  157. result = cursor.fetchone()
  158. return self._format_result(result) if result else None
  159. finally:
  160. cursor.close()
  161. def search(self, query_embedding: List[float], limit: int = 10,
  162. status: Optional[str] = None) -> List[Dict]:
  163. """向量检索 strategy"""
  164. cursor = self._get_cursor()
  165. try:
  166. if status:
  167. sql = f"""
  168. SELECT {_SELECT_FIELDS},
  169. 1 - (embedding <=> %s::real[]) as score
  170. FROM strategy
  171. WHERE embedding IS NOT NULL AND status = %s
  172. ORDER BY embedding <=> %s::real[]
  173. LIMIT %s
  174. """
  175. params = (query_embedding, status, query_embedding, limit)
  176. else:
  177. sql = f"""
  178. SELECT {_SELECT_FIELDS},
  179. 1 - (embedding <=> %s::real[]) as score
  180. FROM strategy
  181. WHERE embedding IS NOT NULL
  182. ORDER BY embedding <=> %s::real[]
  183. LIMIT %s
  184. """
  185. params = (query_embedding, query_embedding, limit)
  186. cursor.execute(sql, params)
  187. results = cursor.fetchall()
  188. return [self._format_result(r) for r in results]
  189. finally:
  190. cursor.close()
  191. def list_all(self, limit: int = 100, offset: int = 0,
  192. status: Optional[str] = None) -> List[Dict]:
  193. cursor = self._get_cursor()
  194. try:
  195. if status:
  196. cursor.execute(f"""
  197. SELECT {_SELECT_FIELDS} FROM strategy
  198. WHERE status = %s
  199. ORDER BY id
  200. LIMIT %s OFFSET %s
  201. """, (status, limit, offset))
  202. else:
  203. cursor.execute(f"""
  204. SELECT {_SELECT_FIELDS} FROM strategy
  205. ORDER BY id
  206. LIMIT %s OFFSET %s
  207. """, (limit, offset))
  208. results = cursor.fetchall()
  209. return [self._format_result(r) for r in results]
  210. finally:
  211. cursor.close()
  212. def update(self, strategy_id: str, updates: Dict):
  213. """更新 strategy(关联字段可选)"""
  214. cursor = self._get_cursor()
  215. try:
  216. # 分离关联字段
  217. rel_keys = ('requirement_ids',
  218. 'capability_ids', 'capability_links',
  219. 'knowledge_ids', 'knowledge_links', 'resource_ids')
  220. rel_fields = {k: updates.pop(k) for k in rel_keys if k in updates}
  221. if updates:
  222. set_parts = []
  223. params = []
  224. for key, value in updates.items():
  225. set_parts.append(f"{key} = %s")
  226. params.append(value)
  227. params.append(strategy_id)
  228. cursor.execute(
  229. f"UPDATE strategy SET {', '.join(set_parts)} WHERE id = %s",
  230. params)
  231. if rel_fields:
  232. self._save_relations(cursor, strategy_id, rel_fields)
  233. self.conn.commit()
  234. finally:
  235. cursor.close()
  236. def delete(self, strategy_id: str):
  237. """删除 strategy 及其所有 junction 行"""
  238. cursor = self._get_cursor()
  239. try:
  240. cascade_delete(cursor, 'strategy', strategy_id)
  241. self.conn.commit()
  242. finally:
  243. cursor.close()
  244. def count(self, status: Optional[str] = None) -> int:
  245. cursor = self._get_cursor()
  246. try:
  247. if status:
  248. cursor.execute("SELECT COUNT(*) as count FROM strategy WHERE status = %s", (status,))
  249. else:
  250. cursor.execute("SELECT COUNT(*) as count FROM strategy")
  251. return cursor.fetchone()['count']
  252. finally:
  253. cursor.close()
  254. # ─── 增量关联 API(不删已有)─────────────────────────────────
  255. def add_capability(self, strategy_id: str, capability_id: str,
  256. relation_type: str = 'compose'):
  257. cursor = self._get_cursor()
  258. try:
  259. cursor.execute(
  260. "INSERT INTO strategy_capability (strategy_id, capability_id, relation_type) "
  261. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  262. (strategy_id, capability_id, relation_type))
  263. self.conn.commit()
  264. finally:
  265. cursor.close()
  266. def add_knowledge(self, strategy_id: str, knowledge_id: str,
  267. relation_type: str = 'source'):
  268. cursor = self._get_cursor()
  269. try:
  270. cursor.execute(
  271. "INSERT INTO strategy_knowledge (strategy_id, knowledge_id, relation_type) "
  272. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  273. (strategy_id, knowledge_id, relation_type))
  274. self.conn.commit()
  275. finally:
  276. cursor.close()
  277. def add_resource(self, strategy_id: str, resource_id: str):
  278. cursor = self._get_cursor()
  279. try:
  280. cursor.execute(
  281. "INSERT INTO strategy_resource (strategy_id, resource_id) "
  282. "VALUES (%s, %s) ON CONFLICT DO NOTHING",
  283. (strategy_id, resource_id))
  284. self.conn.commit()
  285. finally:
  286. cursor.close()
  287. def add_requirement(self, strategy_id: str, requirement_id: str):
  288. """增量挂接 requirement-strategy 边(这个 strategy 满足该 requirement)"""
  289. cursor = self._get_cursor()
  290. try:
  291. cursor.execute(
  292. "INSERT INTO requirement_strategy (requirement_id, strategy_id) "
  293. "VALUES (%s, %s) ON CONFLICT DO NOTHING",
  294. (requirement_id, strategy_id))
  295. self.conn.commit()
  296. finally:
  297. cursor.close()
  298. # ─── 辅助 ────────────────────────────────────────────────────
  299. def _format_result(self, row: Dict) -> Optional[Dict]:
  300. if not row:
  301. return None
  302. import json
  303. result = dict(row)
  304. for field in ('requirement_ids', 'capability_ids', 'knowledge_ids', 'resource_ids'):
  305. if field in result and isinstance(result[field], str):
  306. result[field] = json.loads(result[field])
  307. elif field in result and result[field] is None:
  308. result[field] = []
  309. for field in ('capability_links', 'knowledge_links'):
  310. if field in result and isinstance(result[field], str):
  311. result[field] = json.loads(result[field])
  312. elif field in result and result[field] is None:
  313. result[field] = []
  314. return result
  315. def close(self):
  316. if self.conn:
  317. self.conn.close()