pg_requirement_store.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. """
  2. PostgreSQL requirement 存储封装
  3. 用于存储和检索需求数据,支持向量检索。
  4. 表名:requirement(从 requirement_table 迁移)
  5. """
  6. import os
  7. import json
  8. import psycopg2
  9. from psycopg2.extras import RealDictCursor
  10. from typing import List, Dict, Optional
  11. from dotenv import load_dotenv
  12. from knowhub.knowhub_db.cascade import cascade_delete
  13. load_dotenv()
  14. # 关联字段子查询
  15. _REL_SUBQUERY = """
  16. (SELECT COALESCE(json_agg(rc.capability_id), '[]'::json)
  17. FROM requirement_capability rc WHERE rc.requirement_id = requirement.id) AS capability_ids,
  18. (SELECT COALESCE(json_agg(rk.knowledge_id), '[]'::json)
  19. FROM requirement_knowledge rk WHERE rk.requirement_id = requirement.id) AS knowledge_ids
  20. """
  21. _BASE_FIELDS = "id, description, source_nodes, status, match_result"
  22. _SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERY}"
  23. class PostgreSQLRequirementStore:
  24. def __init__(self):
  25. """初始化 PostgreSQL 连接"""
  26. self.conn = psycopg2.connect(
  27. host=os.getenv('KNOWHUB_DB'),
  28. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  29. user=os.getenv('KNOWHUB_USER'),
  30. password=os.getenv('KNOWHUB_PASSWORD'),
  31. database=os.getenv('KNOWHUB_DB_NAME')
  32. )
  33. self.conn.autocommit = False
  34. print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  35. def _reconnect(self):
  36. self.conn = psycopg2.connect(
  37. host=os.getenv('KNOWHUB_DB'),
  38. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  39. user=os.getenv('KNOWHUB_USER'),
  40. password=os.getenv('KNOWHUB_PASSWORD'),
  41. database=os.getenv('KNOWHUB_DB_NAME')
  42. )
  43. self.conn.autocommit = False
  44. def _ensure_connection(self):
  45. if self.conn.closed != 0:
  46. self._reconnect()
  47. else:
  48. try:
  49. c = self.conn.cursor()
  50. c.execute("SELECT 1")
  51. c.close()
  52. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  53. self._reconnect()
  54. def _get_cursor(self):
  55. self._ensure_connection()
  56. return self.conn.cursor(cursor_factory=RealDictCursor)
  57. def insert_or_update(self, requirement: Dict):
  58. """插入或更新需求记录"""
  59. cursor = self._get_cursor()
  60. try:
  61. cursor.execute("""
  62. INSERT INTO requirement (
  63. id, description, source_nodes, status, match_result, embedding
  64. ) VALUES (%s, %s, %s, %s, %s, %s)
  65. ON CONFLICT (id) DO UPDATE SET
  66. description = EXCLUDED.description,
  67. source_nodes = EXCLUDED.source_nodes,
  68. status = EXCLUDED.status,
  69. match_result = EXCLUDED.match_result,
  70. embedding = EXCLUDED.embedding
  71. """, (
  72. requirement['id'],
  73. requirement.get('description', ''),
  74. json.dumps(requirement.get('source_nodes', [])),
  75. requirement.get('status', '未满足'),
  76. requirement.get('match_result', ''),
  77. requirement.get('embedding'),
  78. ))
  79. # 写入关联表
  80. req_id = requirement['id']
  81. if 'capability_ids' in requirement:
  82. cursor.execute("DELETE FROM requirement_capability WHERE requirement_id = %s", (req_id,))
  83. for cap_id in requirement['capability_ids']:
  84. cursor.execute(
  85. "INSERT INTO requirement_capability (requirement_id, capability_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  86. (req_id, cap_id))
  87. if 'knowledge_ids' in requirement:
  88. cursor.execute("DELETE FROM requirement_knowledge WHERE requirement_id = %s", (req_id,))
  89. for kid in requirement['knowledge_ids']:
  90. cursor.execute(
  91. "INSERT INTO requirement_knowledge (requirement_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  92. (req_id, kid))
  93. self.conn.commit()
  94. finally:
  95. cursor.close()
  96. def get_by_id(self, req_id: str) -> Optional[Dict]:
  97. """根据 ID 获取需求"""
  98. cursor = self._get_cursor()
  99. try:
  100. cursor.execute(f"""
  101. SELECT {_SELECT_FIELDS}
  102. FROM requirement WHERE id = %s
  103. """, (req_id,))
  104. result = cursor.fetchone()
  105. return self._format_result(result) if result else None
  106. finally:
  107. cursor.close()
  108. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  109. """向量检索需求"""
  110. cursor = self._get_cursor()
  111. try:
  112. cursor.execute(f"""
  113. SELECT {_SELECT_FIELDS},
  114. 1 - (embedding <=> %s::real[]) as score
  115. FROM requirement
  116. WHERE embedding IS NOT NULL
  117. ORDER BY embedding <=> %s::real[]
  118. LIMIT %s
  119. """, (query_embedding, query_embedding, limit))
  120. results = cursor.fetchall()
  121. return [self._format_result(r) for r in results]
  122. finally:
  123. cursor.close()
  124. def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
  125. """列出需求"""
  126. cursor = self._get_cursor()
  127. try:
  128. if status:
  129. cursor.execute(f"""
  130. SELECT {_SELECT_FIELDS}
  131. FROM requirement
  132. WHERE status = %s
  133. ORDER BY id
  134. LIMIT %s OFFSET %s
  135. """, (status, limit, offset))
  136. else:
  137. cursor.execute(f"""
  138. SELECT {_SELECT_FIELDS}
  139. FROM requirement
  140. ORDER BY id
  141. LIMIT %s OFFSET %s
  142. """, (limit, offset))
  143. results = cursor.fetchall()
  144. return [self._format_result(r) for r in results]
  145. finally:
  146. cursor.close()
  147. def update(self, req_id: str, updates: Dict):
  148. """更新需求字段"""
  149. cursor = self._get_cursor()
  150. try:
  151. # 分离关联字段
  152. cap_ids = updates.pop('capability_ids', None)
  153. knowledge_ids = updates.pop('knowledge_ids', None)
  154. if updates:
  155. set_parts = []
  156. params = []
  157. json_fields = ('source_nodes',)
  158. for key, value in updates.items():
  159. set_parts.append(f"{key} = %s")
  160. if key in json_fields:
  161. params.append(json.dumps(value))
  162. else:
  163. params.append(value)
  164. params.append(req_id)
  165. cursor.execute(
  166. f"UPDATE requirement SET {', '.join(set_parts)} WHERE id = %s",
  167. params
  168. )
  169. if cap_ids is not None:
  170. cursor.execute("DELETE FROM requirement_capability WHERE requirement_id = %s", (req_id,))
  171. for cap_id in cap_ids:
  172. cursor.execute(
  173. "INSERT INTO requirement_capability (requirement_id, capability_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  174. (req_id, cap_id))
  175. if knowledge_ids is not None:
  176. cursor.execute("DELETE FROM requirement_knowledge WHERE requirement_id = %s", (req_id,))
  177. for kid in knowledge_ids:
  178. cursor.execute(
  179. "INSERT INTO requirement_knowledge (requirement_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  180. (req_id, kid))
  181. self.conn.commit()
  182. finally:
  183. cursor.close()
  184. def delete(self, req_id: str):
  185. """删除需求及其关联表记录"""
  186. cursor = self._get_cursor()
  187. try:
  188. cascade_delete(cursor, 'requirement', req_id)
  189. self.conn.commit()
  190. finally:
  191. cursor.close()
  192. def count(self, status: Optional[str] = None) -> int:
  193. """统计需求总数"""
  194. cursor = self._get_cursor()
  195. try:
  196. if status:
  197. cursor.execute("SELECT COUNT(*) as count FROM requirement WHERE status = %s", (status,))
  198. else:
  199. cursor.execute("SELECT COUNT(*) as count FROM requirement")
  200. return cursor.fetchone()['count']
  201. finally:
  202. cursor.close()
  203. def _format_result(self, row: Dict) -> Dict:
  204. """格式化查询结果"""
  205. if not row:
  206. return None
  207. result = dict(row)
  208. if 'source_nodes' in result and isinstance(result['source_nodes'], str):
  209. result['source_nodes'] = json.loads(result['source_nodes'])
  210. # 关联字段(来自 junction table 子查询)
  211. for field in ('capability_ids', 'knowledge_ids'):
  212. if field in result and isinstance(result[field], str):
  213. result[field] = json.loads(result[field])
  214. elif field in result and result[field] is None:
  215. result[field] = []
  216. return result
  217. def close(self):
  218. if self.conn:
  219. self.conn.close()