# pg_requirement_store.py
"""
Storage wrapper for the PostgreSQL requirement_table (v2, new schema).

Columns: id, description, atomics, source_nodes, status, match_result, embedding
"""
import os
import json
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Dict, Optional
from dotenv import load_dotenv

# Load variables from a .env file into the process environment at import time,
# so the KNOWHUB_* settings read via os.getenv() below are available.
load_dotenv()
  12. class PostgreSQLRequirementStore:
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _reconnect(self):
  25. self.conn = psycopg2.connect(
  26. host=os.getenv('KNOWHUB_DB'),
  27. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  28. user=os.getenv('KNOWHUB_USER'),
  29. password=os.getenv('KNOWHUB_PASSWORD'),
  30. database=os.getenv('KNOWHUB_DB_NAME')
  31. )
  32. self.conn.autocommit = False
  33. def _ensure_connection(self):
  34. if self.conn.closed != 0:
  35. self._reconnect()
  36. else:
  37. try:
  38. c = self.conn.cursor()
  39. c.execute("SELECT 1")
  40. c.close()
  41. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  42. self._reconnect()
  43. def _get_cursor(self):
  44. self._ensure_connection()
  45. return self.conn.cursor(cursor_factory=RealDictCursor)
  46. def insert_or_update(self, requirement: Dict):
  47. """插入或更新需求记录"""
  48. cursor = self._get_cursor()
  49. try:
  50. cursor.execute("""
  51. INSERT INTO requirement_table (
  52. id, description, atomics, source_nodes, status, match_result, embedding
  53. ) VALUES (%s, %s, %s, %s, %s, %s, %s)
  54. ON CONFLICT (id) DO UPDATE SET
  55. description = EXCLUDED.description,
  56. atomics = EXCLUDED.atomics,
  57. source_nodes = EXCLUDED.source_nodes,
  58. status = EXCLUDED.status,
  59. match_result = EXCLUDED.match_result,
  60. embedding = EXCLUDED.embedding
  61. """, (
  62. requirement['id'],
  63. requirement.get('description', ''),
  64. json.dumps(requirement.get('atomics', [])),
  65. json.dumps(requirement.get('source_nodes', [])),
  66. requirement.get('status', '未满足'),
  67. requirement.get('match_result', ''),
  68. requirement.get('embedding'),
  69. ))
  70. self.conn.commit()
  71. finally:
  72. cursor.close()
  73. def get_by_id(self, req_id: str) -> Optional[Dict]:
  74. """根据 ID 获取需求"""
  75. cursor = self._get_cursor()
  76. try:
  77. cursor.execute("""
  78. SELECT id, description, atomics, source_nodes, status, match_result
  79. FROM requirement_table WHERE id = %s
  80. """, (req_id,))
  81. result = cursor.fetchone()
  82. return self._format_result(result) if result else None
  83. finally:
  84. cursor.close()
  85. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  86. """向量检索需求"""
  87. cursor = self._get_cursor()
  88. try:
  89. cursor.execute("""
  90. SELECT id, description, atomics, source_nodes, status, match_result,
  91. 1 - (embedding <=> %s::real[]) as score
  92. FROM requirement_table
  93. WHERE embedding IS NOT NULL
  94. ORDER BY embedding <=> %s::real[]
  95. LIMIT %s
  96. """, (query_embedding, query_embedding, limit))
  97. results = cursor.fetchall()
  98. return [self._format_result(r) for r in results]
  99. finally:
  100. cursor.close()
  101. def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
  102. """列出需求"""
  103. cursor = self._get_cursor()
  104. try:
  105. if status:
  106. cursor.execute("""
  107. SELECT id, description, atomics, source_nodes, status, match_result
  108. FROM requirement_table
  109. WHERE status = %s
  110. ORDER BY id
  111. LIMIT %s OFFSET %s
  112. """, (status, limit, offset))
  113. else:
  114. cursor.execute("""
  115. SELECT id, description, atomics, source_nodes, status, match_result
  116. FROM requirement_table
  117. ORDER BY id
  118. LIMIT %s OFFSET %s
  119. """, (limit, offset))
  120. results = cursor.fetchall()
  121. return [self._format_result(r) for r in results]
  122. finally:
  123. cursor.close()
  124. def update(self, req_id: str, updates: Dict):
  125. """更新需求字段"""
  126. cursor = self._get_cursor()
  127. try:
  128. set_parts = []
  129. params = []
  130. json_fields = ('atomics', 'source_nodes')
  131. for key, value in updates.items():
  132. set_parts.append(f"{key} = %s")
  133. if key in json_fields:
  134. params.append(json.dumps(value))
  135. else:
  136. params.append(value)
  137. params.append(req_id)
  138. cursor.execute(
  139. f"UPDATE requirement_table SET {', '.join(set_parts)} WHERE id = %s",
  140. params
  141. )
  142. self.conn.commit()
  143. finally:
  144. cursor.close()
  145. def delete(self, req_id: str):
  146. """删除需求"""
  147. cursor = self._get_cursor()
  148. try:
  149. cursor.execute("DELETE FROM requirement_table WHERE id = %s", (req_id,))
  150. self.conn.commit()
  151. finally:
  152. cursor.close()
  153. def count(self, status: Optional[str] = None) -> int:
  154. """统计需求总数"""
  155. cursor = self._get_cursor()
  156. try:
  157. if status:
  158. cursor.execute("SELECT COUNT(*) as count FROM requirement_table WHERE status = %s", (status,))
  159. else:
  160. cursor.execute("SELECT COUNT(*) as count FROM requirement_table")
  161. return cursor.fetchone()['count']
  162. finally:
  163. cursor.close()
  164. def _format_result(self, row: Dict) -> Dict:
  165. """格式化查询结果"""
  166. if not row:
  167. return None
  168. result = dict(row)
  169. for field in ('atomics', 'source_nodes'):
  170. if field in result and isinstance(result[field], str):
  171. result[field] = json.loads(result[field])
  172. return result
  173. def close(self):
  174. if self.conn:
  175. self.conn.close()