# pg_requirement_store.py
  1. """
  2. PostgreSQL requirement_table 存储封装(v2 新 schema)
  3. 字段:id, description, atomics, source_nodes, status, match_result, embedding
  4. """
  5. import os
  6. import json
  7. import psycopg2
  8. from psycopg2.extras import RealDictCursor
  9. from typing import List, Dict, Optional
  10. from dotenv import load_dotenv
  11. load_dotenv()
  12. class PostgreSQLRequirementStore:
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _get_cursor(self):
  25. return self.conn.cursor(cursor_factory=RealDictCursor)
  26. def insert_or_update(self, requirement: Dict):
  27. """插入或更新需求记录"""
  28. cursor = self._get_cursor()
  29. try:
  30. cursor.execute("""
  31. INSERT INTO requirement_table (
  32. id, description, atomics, source_nodes, status, match_result, embedding
  33. ) VALUES (%s, %s, %s, %s, %s, %s, %s)
  34. ON CONFLICT (id) DO UPDATE SET
  35. description = EXCLUDED.description,
  36. atomics = EXCLUDED.atomics,
  37. source_nodes = EXCLUDED.source_nodes,
  38. status = EXCLUDED.status,
  39. match_result = EXCLUDED.match_result,
  40. embedding = EXCLUDED.embedding
  41. """, (
  42. requirement['id'],
  43. requirement.get('description', ''),
  44. json.dumps(requirement.get('atomics', [])),
  45. json.dumps(requirement.get('source_nodes', [])),
  46. requirement.get('status', '未满足'),
  47. requirement.get('match_result', ''),
  48. requirement.get('embedding'),
  49. ))
  50. self.conn.commit()
  51. finally:
  52. cursor.close()
  53. def get_by_id(self, req_id: str) -> Optional[Dict]:
  54. """根据 ID 获取需求"""
  55. cursor = self._get_cursor()
  56. try:
  57. cursor.execute("""
  58. SELECT id, description, atomics, source_nodes, status, match_result
  59. FROM requirement_table WHERE id = %s
  60. """, (req_id,))
  61. result = cursor.fetchone()
  62. return self._format_result(result) if result else None
  63. finally:
  64. cursor.close()
  65. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  66. """向量检索需求"""
  67. cursor = self._get_cursor()
  68. try:
  69. cursor.execute("""
  70. SELECT id, description, atomics, source_nodes, status, match_result,
  71. 1 - (embedding <=> %s::real[]) as score
  72. FROM requirement_table
  73. WHERE embedding IS NOT NULL
  74. ORDER BY embedding <=> %s::real[]
  75. LIMIT %s
  76. """, (query_embedding, query_embedding, limit))
  77. results = cursor.fetchall()
  78. return [self._format_result(r) for r in results]
  79. finally:
  80. cursor.close()
  81. def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
  82. """列出需求"""
  83. cursor = self._get_cursor()
  84. try:
  85. if status:
  86. cursor.execute("""
  87. SELECT id, description, atomics, source_nodes, status, match_result
  88. FROM requirement_table
  89. WHERE status = %s
  90. ORDER BY id
  91. LIMIT %s OFFSET %s
  92. """, (status, limit, offset))
  93. else:
  94. cursor.execute("""
  95. SELECT id, description, atomics, source_nodes, status, match_result
  96. FROM requirement_table
  97. ORDER BY id
  98. LIMIT %s OFFSET %s
  99. """, (limit, offset))
  100. results = cursor.fetchall()
  101. return [self._format_result(r) for r in results]
  102. finally:
  103. cursor.close()
  104. def update(self, req_id: str, updates: Dict):
  105. """更新需求字段"""
  106. cursor = self._get_cursor()
  107. try:
  108. set_parts = []
  109. params = []
  110. json_fields = ('atomics', 'source_nodes')
  111. for key, value in updates.items():
  112. set_parts.append(f"{key} = %s")
  113. if key in json_fields:
  114. params.append(json.dumps(value))
  115. else:
  116. params.append(value)
  117. params.append(req_id)
  118. cursor.execute(
  119. f"UPDATE requirement_table SET {', '.join(set_parts)} WHERE id = %s",
  120. params
  121. )
  122. self.conn.commit()
  123. finally:
  124. cursor.close()
  125. def delete(self, req_id: str):
  126. """删除需求"""
  127. cursor = self._get_cursor()
  128. try:
  129. cursor.execute("DELETE FROM requirement_table WHERE id = %s", (req_id,))
  130. self.conn.commit()
  131. finally:
  132. cursor.close()
  133. def count(self, status: Optional[str] = None) -> int:
  134. """统计需求总数"""
  135. cursor = self._get_cursor()
  136. try:
  137. if status:
  138. cursor.execute("SELECT COUNT(*) as count FROM requirement_table WHERE status = %s", (status,))
  139. else:
  140. cursor.execute("SELECT COUNT(*) as count FROM requirement_table")
  141. return cursor.fetchone()['count']
  142. finally:
  143. cursor.close()
  144. def _format_result(self, row: Dict) -> Dict:
  145. """格式化查询结果"""
  146. if not row:
  147. return None
  148. result = dict(row)
  149. for field in ('atomics', 'source_nodes'):
  150. if field in result and isinstance(result[field], str):
  151. result[field] = json.loads(result[field])
  152. return result
  153. def close(self):
  154. if self.conn:
  155. self.conn.close()