pg_requirement_store.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
"""
PostgreSQL requirement store wrapper.

Stores and retrieves requirement data, with vector similarity search.
Table: requirement (migrated from requirement_table).
"""
import json
import os
from contextlib import contextmanager
from typing import Dict, List, Optional

import psycopg2
from dotenv import load_dotenv
from psycopg2.extras import RealDictCursor

from knowhub.knowhub_db.cascade import cascade_delete
load_dotenv()

# Relation-column subqueries. Each correlated subquery aggregates one junction table into
# a JSON array for the outer ``requirement`` row, so this fragment is only valid inside a
# ``SELECT ... FROM requirement`` query. COALESCE(..., '[]'::json) yields an empty array
# when a requirement has no rows in that table.
# The knowledge edge is exposed in two views: knowledge_ids (flat id list) and
# knowledge_links (objects carrying the relation_type as well).
_REL_SUBQUERY = """
(SELECT COALESCE(json_agg(rc.capability_id), '[]'::json)
FROM requirement_capability rc WHERE rc.requirement_id = requirement.id) AS capability_ids,
(SELECT COALESCE(json_agg(rk.knowledge_id), '[]'::json)
FROM requirement_knowledge rk WHERE rk.requirement_id = requirement.id) AS knowledge_ids,
(SELECT COALESCE(json_agg(json_build_object(
'id', rk2.knowledge_id, 'relation_type', rk2.relation_type
)), '[]'::json)
FROM requirement_knowledge rk2 WHERE rk2.requirement_id = requirement.id) AS knowledge_links,
(SELECT COALESCE(json_agg(rr.resource_id), '[]'::json)
FROM requirement_resource rr WHERE rr.requirement_id = requirement.id) AS resource_ids,
(SELECT COALESCE(json_agg(rs.strategy_id), '[]'::json)
FROM requirement_strategy rs WHERE rs.requirement_id = requirement.id) AS strategy_ids,
(SELECT COALESCE(json_agg(rp.itemset_id), '[]'::json)
FROM requirement_pattern rp WHERE rp.requirement_id = requirement.id) AS pattern_ids,
(SELECT COALESCE(json_agg(rn.node_id), '[]'::json)
FROM requirement_node rn WHERE rn.requirement_id = requirement.id) AS node_ids
"""
# Plain requirement columns returned by reads (note: ``embedding`` is not selected).
_BASE_FIELDS = "id, description, source_nodes, status, match_result, version"
# Full projection for read queries: plain columns followed by the relation subqueries.
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERY}"
  35. def _normalize_links(data: Dict, links_key: str, ids_key: str, default_type: str):
  36. """两种输入格式统一:{links_key: [{id, relation_type}]} 或 {ids_key: [id]}"""
  37. if links_key in data and data[links_key] is not None:
  38. out = []
  39. for item in data[links_key]:
  40. if isinstance(item, dict):
  41. out.append((item['id'], item.get('relation_type', default_type)))
  42. else:
  43. out.append((item, default_type))
  44. return out
  45. if ids_key in data and data[ids_key] is not None:
  46. return [(i, default_type) for i in data[ids_key]]
  47. return None
  48. class PostgreSQLRequirementStore:
  49. def __init__(self):
  50. """初始化 PostgreSQL 连接"""
  51. self.conn = psycopg2.connect(
  52. host=os.getenv('KNOWHUB_DB'),
  53. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  54. user=os.getenv('KNOWHUB_USER'),
  55. password=os.getenv('KNOWHUB_PASSWORD'),
  56. database=os.getenv('KNOWHUB_DB_NAME')
  57. )
  58. self.conn.autocommit = True
  59. print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  60. def _reconnect(self):
  61. self.conn = psycopg2.connect(
  62. host=os.getenv('KNOWHUB_DB'),
  63. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  64. user=os.getenv('KNOWHUB_USER'),
  65. password=os.getenv('KNOWHUB_PASSWORD'),
  66. database=os.getenv('KNOWHUB_DB_NAME')
  67. )
  68. self.conn.autocommit = True
  69. def _ensure_connection(self):
  70. if self.conn.closed != 0:
  71. self._reconnect()
  72. else:
  73. try:
  74. c = self.conn.cursor()
  75. c.execute("SELECT 1")
  76. c.close()
  77. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  78. self._reconnect()
  79. def _get_cursor(self):
  80. self._ensure_connection()
  81. return self.conn.cursor(cursor_factory=RealDictCursor)
  82. def insert_or_update(self, requirement: Dict):
  83. """插入或更新需求记录。AnalyticDB beam 表不支持 ON CONFLICT UPDATE 当含 ALTER 新增列,改用 DELETE+INSERT。"""
  84. cursor = self._get_cursor()
  85. try:
  86. cursor.execute("DELETE FROM requirement WHERE id = %s", (requirement['id'],))
  87. cursor.execute("""
  88. INSERT INTO requirement (
  89. id, description, source_nodes, status, match_result, embedding, version
  90. ) VALUES (%s, %s, %s, %s, %s, %s, %s)
  91. """, (
  92. requirement['id'],
  93. requirement.get('description', ''),
  94. json.dumps(requirement.get('source_nodes', [])),
  95. requirement.get('status', '未满足'),
  96. requirement.get('match_result', ''),
  97. requirement.get('embedding'),
  98. requirement.get('version', 'v0'),
  99. ))
  100. # 写入关联表
  101. req_id = requirement['id']
  102. if 'capability_ids' in requirement:
  103. cursor.execute("DELETE FROM requirement_capability WHERE requirement_id = %s", (req_id,))
  104. for cap_id in requirement['capability_ids']:
  105. cursor.execute(
  106. "INSERT INTO requirement_capability (requirement_id, capability_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  107. (req_id, cap_id))
  108. k_links = _normalize_links(requirement, 'knowledge_links', 'knowledge_ids', 'related')
  109. if k_links is not None:
  110. cursor.execute("DELETE FROM requirement_knowledge WHERE requirement_id = %s", (req_id,))
  111. for kid, rtype in k_links:
  112. cursor.execute(
  113. "INSERT INTO requirement_knowledge (requirement_id, knowledge_id, relation_type) "
  114. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  115. (req_id, kid, rtype))
  116. if 'resource_ids' in requirement and requirement['resource_ids'] is not None:
  117. cursor.execute("DELETE FROM requirement_resource WHERE requirement_id = %s", (req_id,))
  118. for rid in requirement['resource_ids']:
  119. cursor.execute(
  120. "INSERT INTO requirement_resource (requirement_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  121. (req_id, rid))
  122. if 'strategy_ids' in requirement and requirement['strategy_ids'] is not None:
  123. cursor.execute("DELETE FROM requirement_strategy WHERE requirement_id = %s", (req_id,))
  124. for sid in requirement['strategy_ids']:
  125. cursor.execute(
  126. "INSERT INTO requirement_strategy (requirement_id, strategy_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  127. (req_id, sid))
  128. self.conn.commit()
  129. finally:
  130. cursor.close()
  131. def get_by_id(self, req_id: str) -> Optional[Dict]:
  132. """根据 ID 获取需求"""
  133. from knowhub.knowhub_db.version_context import req_version_where
  134. cursor = self._get_cursor()
  135. try:
  136. vf, vp = req_version_where()
  137. cursor.execute(f"""
  138. SELECT {_SELECT_FIELDS}
  139. FROM requirement WHERE id = %s AND {vf}
  140. """, (req_id, *vp))
  141. result = cursor.fetchone()
  142. return self._format_result(result) if result else None
  143. finally:
  144. cursor.close()
  145. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  146. """向量检索需求"""
  147. from knowhub.knowhub_db.version_context import req_version_where
  148. cursor = self._get_cursor()
  149. try:
  150. vf, vp = req_version_where()
  151. cursor.execute(f"""
  152. SELECT {_SELECT_FIELDS},
  153. 1 - (embedding <=> %s::real[]) as score
  154. FROM requirement
  155. WHERE embedding IS NOT NULL AND {vf}
  156. ORDER BY embedding <=> %s::real[]
  157. LIMIT %s
  158. """, (query_embedding, *vp, query_embedding, limit))
  159. results = cursor.fetchall()
  160. return [self._format_result(r) for r in results]
  161. finally:
  162. cursor.close()
  163. def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
  164. """列出需求"""
  165. from knowhub.knowhub_db.version_context import req_version_where
  166. cursor = self._get_cursor()
  167. try:
  168. vf, vp = req_version_where()
  169. if status:
  170. cursor.execute(f"""
  171. SELECT {_SELECT_FIELDS}
  172. FROM requirement
  173. WHERE status = %s AND {vf}
  174. ORDER BY id
  175. LIMIT %s OFFSET %s
  176. """, (status, *vp, limit, offset))
  177. else:
  178. cursor.execute(f"""
  179. SELECT {_SELECT_FIELDS}
  180. FROM requirement
  181. WHERE {vf}
  182. ORDER BY id
  183. LIMIT %s OFFSET %s
  184. """, (*vp, limit, offset))
  185. results = cursor.fetchall()
  186. return [self._format_result(r) for r in results]
  187. finally:
  188. cursor.close()
  189. def update(self, req_id: str, updates: Dict):
  190. """更新需求字段"""
  191. cursor = self._get_cursor()
  192. try:
  193. # 分离关联字段
  194. cap_ids = updates.pop('capability_ids', None)
  195. strategy_ids = updates.pop('strategy_ids', None)
  196. rel_data = {}
  197. for k in ('knowledge_ids', 'knowledge_links', 'resource_ids'):
  198. if k in updates:
  199. rel_data[k] = updates.pop(k)
  200. if updates:
  201. set_parts = []
  202. params = []
  203. json_fields = ('source_nodes',)
  204. for key, value in updates.items():
  205. set_parts.append(f"{key} = %s")
  206. if key in json_fields:
  207. params.append(json.dumps(value))
  208. else:
  209. params.append(value)
  210. params.append(req_id)
  211. cursor.execute(
  212. f"UPDATE requirement SET {', '.join(set_parts)} WHERE id = %s",
  213. params
  214. )
  215. if cap_ids is not None:
  216. cursor.execute("DELETE FROM requirement_capability WHERE requirement_id = %s", (req_id,))
  217. for cap_id in cap_ids:
  218. cursor.execute(
  219. "INSERT INTO requirement_capability (requirement_id, capability_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  220. (req_id, cap_id))
  221. k_links = _normalize_links(rel_data, 'knowledge_links', 'knowledge_ids', 'related')
  222. if k_links is not None:
  223. cursor.execute("DELETE FROM requirement_knowledge WHERE requirement_id = %s", (req_id,))
  224. for kid, rtype in k_links:
  225. cursor.execute(
  226. "INSERT INTO requirement_knowledge (requirement_id, knowledge_id, relation_type) "
  227. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  228. (req_id, kid, rtype))
  229. if 'resource_ids' in rel_data and rel_data['resource_ids'] is not None:
  230. cursor.execute("DELETE FROM requirement_resource WHERE requirement_id = %s", (req_id,))
  231. for rid in rel_data['resource_ids']:
  232. cursor.execute(
  233. "INSERT INTO requirement_resource (requirement_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  234. (req_id, rid))
  235. if strategy_ids is not None:
  236. cursor.execute("DELETE FROM requirement_strategy WHERE requirement_id = %s", (req_id,))
  237. for sid in strategy_ids:
  238. cursor.execute(
  239. "INSERT INTO requirement_strategy (requirement_id, strategy_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  240. (req_id, sid))
  241. self.conn.commit()
  242. finally:
  243. cursor.close()
  244. def add_knowledge(self, req_id: str, knowledge_id: str, relation_type: str = 'related'):
  245. """增量挂接 requirement-knowledge 边"""
  246. cursor = self._get_cursor()
  247. try:
  248. cursor.execute(
  249. "INSERT INTO requirement_knowledge (requirement_id, knowledge_id, relation_type) "
  250. "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  251. (req_id, knowledge_id, relation_type))
  252. self.conn.commit()
  253. finally:
  254. cursor.close()
  255. def add_resource(self, req_id: str, resource_id: str):
  256. """增量挂接 requirement-resource 边"""
  257. cursor = self._get_cursor()
  258. try:
  259. cursor.execute(
  260. "INSERT INTO requirement_resource (requirement_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  261. (req_id, resource_id))
  262. self.conn.commit()
  263. finally:
  264. cursor.close()
  265. def add_strategy(self, req_id: str, strategy_id: str):
  266. """增量挂接 requirement-strategy 边(该 strategy 满足此 requirement)"""
  267. cursor = self._get_cursor()
  268. try:
  269. cursor.execute(
  270. "INSERT INTO requirement_strategy (requirement_id, strategy_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  271. (req_id, strategy_id))
  272. self.conn.commit()
  273. finally:
  274. cursor.close()
  275. def delete(self, req_id: str):
  276. """删除需求及其关联表记录"""
  277. cursor = self._get_cursor()
  278. try:
  279. cascade_delete(cursor, 'requirement', req_id)
  280. self.conn.commit()
  281. finally:
  282. cursor.close()
  283. def count(self, status: Optional[str] = None) -> int:
  284. """统计需求总数"""
  285. from knowhub.knowhub_db.version_context import req_version_where
  286. cursor = self._get_cursor()
  287. try:
  288. vf, vp = req_version_where()
  289. if status:
  290. cursor.execute(f"SELECT COUNT(*) as count FROM requirement WHERE status = %s AND {vf}",
  291. (status, *vp))
  292. else:
  293. cursor.execute(f"SELECT COUNT(*) as count FROM requirement WHERE {vf}", vp)
  294. return cursor.fetchone()['count']
  295. finally:
  296. cursor.close()
  297. def _format_result(self, row: Dict) -> Dict:
  298. """格式化查询结果"""
  299. if not row:
  300. return None
  301. result = dict(row)
  302. if 'source_nodes' in result and isinstance(result['source_nodes'], str):
  303. result['source_nodes'] = json.loads(result['source_nodes'])
  304. # 关联字段(来自 junction table 子查询)
  305. for field in ('capability_ids', 'knowledge_ids', 'resource_ids', 'strategy_ids', 'knowledge_links', 'pattern_ids', 'node_ids'):
  306. if field in result and isinstance(result[field], str):
  307. result[field] = json.loads(result[field])
  308. elif field in result and result[field] is None:
  309. result[field] = []
  310. return result
  311. def close(self):
  312. if self.conn:
  313. self.conn.close()