pg_capability_store.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. """
  2. PostgreSQL atomic_capability 存储封装
  3. 用于存储和检索原子能力数据,支持向量检索
  4. """
  5. import os
  6. import json
  7. import psycopg2
  8. from psycopg2.extras import RealDictCursor
  9. from typing import List, Dict, Optional
  10. from dotenv import load_dotenv
# Load variables from a .env file (python-dotenv) into the process environment
# at import time, so the KNOWHUB_* settings that
# PostgreSQLCapabilityStore.__init__ reads via os.getenv can come from .env.
load_dotenv()
  12. class PostgreSQLCapabilityStore:
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL Capability] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _get_cursor(self):
  25. return self.conn.cursor(cursor_factory=RealDictCursor)
  26. def insert_or_update(self, cap: Dict):
  27. """插入或更新原子能力"""
  28. cursor = self._get_cursor()
  29. try:
  30. cursor.execute("""
  31. INSERT INTO atomic_capability (
  32. id, name, criterion, description, requirements,
  33. implements, tools, source_knowledge, embedding
  34. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
  35. ON CONFLICT (id) DO UPDATE SET
  36. name = EXCLUDED.name,
  37. criterion = EXCLUDED.criterion,
  38. description = EXCLUDED.description,
  39. requirements = EXCLUDED.requirements,
  40. implements = EXCLUDED.implements,
  41. tools = EXCLUDED.tools,
  42. source_knowledge = EXCLUDED.source_knowledge,
  43. embedding = EXCLUDED.embedding
  44. """, (
  45. cap['id'],
  46. cap.get('name', ''),
  47. cap.get('criterion', ''),
  48. cap.get('description', ''),
  49. json.dumps(cap.get('requirements', [])),
  50. json.dumps(cap.get('implements', {})),
  51. json.dumps(cap.get('tools', [])),
  52. json.dumps(cap.get('source_knowledge', [])),
  53. cap.get('embedding'),
  54. ))
  55. self.conn.commit()
  56. finally:
  57. cursor.close()
  58. def get_by_id(self, cap_id: str) -> Optional[Dict]:
  59. """根据 ID 获取原子能力"""
  60. cursor = self._get_cursor()
  61. try:
  62. cursor.execute("""
  63. SELECT id, name, criterion, description, requirements,
  64. implements, tools, source_knowledge
  65. FROM atomic_capability WHERE id = %s
  66. """, (cap_id,))
  67. result = cursor.fetchone()
  68. return self._format_result(result) if result else None
  69. finally:
  70. cursor.close()
  71. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  72. """向量检索原子能力"""
  73. cursor = self._get_cursor()
  74. try:
  75. cursor.execute("""
  76. SELECT id, name, criterion, description, requirements,
  77. implements, tools, source_knowledge,
  78. 1 - (embedding <=> %s::real[]) as score
  79. FROM atomic_capability
  80. WHERE embedding IS NOT NULL
  81. ORDER BY embedding <=> %s::real[]
  82. LIMIT %s
  83. """, (query_embedding, query_embedding, limit))
  84. results = cursor.fetchall()
  85. return [self._format_result(r) for r in results]
  86. finally:
  87. cursor.close()
  88. def list_all(self, limit: int = 100, offset: int = 0) -> List[Dict]:
  89. """列出原子能力"""
  90. cursor = self._get_cursor()
  91. try:
  92. cursor.execute("""
  93. SELECT id, name, criterion, description, requirements,
  94. implements, tools, source_knowledge
  95. FROM atomic_capability
  96. ORDER BY id
  97. LIMIT %s OFFSET %s
  98. """, (limit, offset))
  99. results = cursor.fetchall()
  100. return [self._format_result(r) for r in results]
  101. finally:
  102. cursor.close()
  103. def update(self, cap_id: str, updates: Dict):
  104. """更新原子能力字段"""
  105. cursor = self._get_cursor()
  106. try:
  107. set_parts = []
  108. params = []
  109. json_fields = ('requirements', 'implements', 'tools', 'source_knowledge')
  110. for key, value in updates.items():
  111. set_parts.append(f"{key} = %s")
  112. if key in json_fields:
  113. params.append(json.dumps(value))
  114. else:
  115. params.append(value)
  116. params.append(cap_id)
  117. cursor.execute(
  118. f"UPDATE atomic_capability SET {', '.join(set_parts)} WHERE id = %s",
  119. params
  120. )
  121. self.conn.commit()
  122. finally:
  123. cursor.close()
  124. def delete(self, cap_id: str):
  125. """删除原子能力"""
  126. cursor = self._get_cursor()
  127. try:
  128. cursor.execute("DELETE FROM atomic_capability WHERE id = %s", (cap_id,))
  129. self.conn.commit()
  130. finally:
  131. cursor.close()
  132. def count(self) -> int:
  133. """统计原子能力总数"""
  134. cursor = self._get_cursor()
  135. try:
  136. cursor.execute("SELECT COUNT(*) as count FROM atomic_capability")
  137. return cursor.fetchone()['count']
  138. finally:
  139. cursor.close()
  140. def _format_result(self, row: Dict) -> Dict:
  141. """格式化查询结果"""
  142. if not row:
  143. return None
  144. result = dict(row)
  145. for field in ('requirements', 'implements', 'tools', 'source_knowledge'):
  146. if field in result and isinstance(result[field], str):
  147. result[field] = json.loads(result[field])
  148. return result
  149. def close(self):
  150. if self.conn:
  151. self.conn.close()