pg_capability_store.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. """
  2. PostgreSQL atomic_capability 存储封装
  3. 用于存储和检索原子能力数据,支持向量检索
  4. """
  5. import os
  6. import json
  7. import psycopg2
  8. from psycopg2.extras import RealDictCursor
  9. from typing import List, Dict, Optional
  10. from dotenv import load_dotenv
# Load variables from a .env file into os.environ (python-dotenv), so the
# KNOWHUB_* connection settings read by PostgreSQLCapabilityStore can be
# supplied that way. Runs once, at import time.
load_dotenv()


class PostgreSQLCapabilityStore:
    """Storage wrapper for the ``atomic_capability`` table in PostgreSQL.

    Holds one psycopg2 connection (``self.conn``, autocommit disabled) and
    exposes upsert, lookup by id, embedding similarity search, paged
    listing, partial update, delete and count. The JSON-valued columns
    (requirements, implements, tools, source_knowledge) are serialized with
    json.dumps on write and decoded on read when they come back as strings.
    """
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL Capability] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _reconnect(self):
  25. self.conn = psycopg2.connect(
  26. host=os.getenv('KNOWHUB_DB'),
  27. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  28. user=os.getenv('KNOWHUB_USER'),
  29. password=os.getenv('KNOWHUB_PASSWORD'),
  30. database=os.getenv('KNOWHUB_DB_NAME')
  31. )
  32. self.conn.autocommit = False
  33. def _ensure_connection(self):
  34. if self.conn.closed != 0:
  35. self._reconnect()
  36. else:
  37. try:
  38. c = self.conn.cursor()
  39. c.execute("SELECT 1")
  40. c.close()
  41. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  42. self._reconnect()
  43. def _get_cursor(self):
  44. self._ensure_connection()
  45. return self.conn.cursor(cursor_factory=RealDictCursor)
  46. def insert_or_update(self, cap: Dict):
  47. """插入或更新原子能力"""
  48. cursor = self._get_cursor()
  49. try:
  50. cursor.execute("""
  51. INSERT INTO atomic_capability (
  52. id, name, criterion, description, requirements,
  53. implements, tools, source_knowledge, embedding
  54. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
  55. ON CONFLICT (id) DO UPDATE SET
  56. name = EXCLUDED.name,
  57. criterion = EXCLUDED.criterion,
  58. description = EXCLUDED.description,
  59. requirements = EXCLUDED.requirements,
  60. implements = EXCLUDED.implements,
  61. tools = EXCLUDED.tools,
  62. source_knowledge = EXCLUDED.source_knowledge,
  63. embedding = EXCLUDED.embedding
  64. """, (
  65. cap['id'],
  66. cap.get('name', ''),
  67. cap.get('criterion', ''),
  68. cap.get('description', ''),
  69. json.dumps(cap.get('requirements', [])),
  70. json.dumps(cap.get('implements', {})),
  71. json.dumps(cap.get('tools', [])),
  72. json.dumps(cap.get('source_knowledge', [])),
  73. cap.get('embedding'),
  74. ))
  75. self.conn.commit()
  76. finally:
  77. cursor.close()
  78. def get_by_id(self, cap_id: str) -> Optional[Dict]:
  79. """根据 ID 获取原子能力"""
  80. cursor = self._get_cursor()
  81. try:
  82. cursor.execute("""
  83. SELECT id, name, criterion, description, requirements,
  84. implements, tools, source_knowledge
  85. FROM atomic_capability WHERE id = %s
  86. """, (cap_id,))
  87. result = cursor.fetchone()
  88. return self._format_result(result) if result else None
  89. finally:
  90. cursor.close()
  91. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  92. """向量检索原子能力"""
  93. cursor = self._get_cursor()
  94. try:
  95. cursor.execute("""
  96. SELECT id, name, criterion, description, requirements,
  97. implements, tools, source_knowledge,
  98. 1 - (embedding <=> %s::real[]) as score
  99. FROM atomic_capability
  100. WHERE embedding IS NOT NULL
  101. ORDER BY embedding <=> %s::real[]
  102. LIMIT %s
  103. """, (query_embedding, query_embedding, limit))
  104. results = cursor.fetchall()
  105. return [self._format_result(r) for r in results]
  106. finally:
  107. cursor.close()
  108. def list_all(self, limit: int = 100, offset: int = 0) -> List[Dict]:
  109. """列出原子能力"""
  110. cursor = self._get_cursor()
  111. try:
  112. cursor.execute("""
  113. SELECT id, name, criterion, description, requirements,
  114. implements, tools, source_knowledge
  115. FROM atomic_capability
  116. ORDER BY id
  117. LIMIT %s OFFSET %s
  118. """, (limit, offset))
  119. results = cursor.fetchall()
  120. return [self._format_result(r) for r in results]
  121. finally:
  122. cursor.close()
  123. def update(self, cap_id: str, updates: Dict):
  124. """更新原子能力字段"""
  125. cursor = self._get_cursor()
  126. try:
  127. set_parts = []
  128. params = []
  129. json_fields = ('requirements', 'implements', 'tools', 'source_knowledge')
  130. for key, value in updates.items():
  131. set_parts.append(f"{key} = %s")
  132. if key in json_fields:
  133. params.append(json.dumps(value))
  134. else:
  135. params.append(value)
  136. params.append(cap_id)
  137. cursor.execute(
  138. f"UPDATE atomic_capability SET {', '.join(set_parts)} WHERE id = %s",
  139. params
  140. )
  141. self.conn.commit()
  142. finally:
  143. cursor.close()
  144. def delete(self, cap_id: str):
  145. """删除原子能力"""
  146. cursor = self._get_cursor()
  147. try:
  148. cursor.execute("DELETE FROM atomic_capability WHERE id = %s", (cap_id,))
  149. self.conn.commit()
  150. finally:
  151. cursor.close()
  152. def count(self) -> int:
  153. """统计原子能力总数"""
  154. cursor = self._get_cursor()
  155. try:
  156. cursor.execute("SELECT COUNT(*) as count FROM atomic_capability")
  157. return cursor.fetchone()['count']
  158. finally:
  159. cursor.close()
  160. def _format_result(self, row: Dict) -> Dict:
  161. """格式化查询结果"""
  162. if not row:
  163. return None
  164. result = dict(row)
  165. for field in ('requirements', 'implements', 'tools', 'source_knowledge'):
  166. if field in result and isinstance(result[field], str):
  167. result[field] = json.loads(result[field])
  168. return result
  169. def close(self):
  170. if self.conn:
  171. self.conn.close()