| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
"""
PostgreSQL storage wrapper for atomic_capability.

Stores and retrieves atomic-capability records, with vector-search support.
"""
import json
import os
from typing import Dict, List, Optional

import psycopg2
from dotenv import load_dotenv
from psycopg2.extras import RealDictCursor

# Load variables from a local .env file (if any) into the process environment
# at import time, before any KNOWHUB_* connection setting is read.
load_dotenv()
class PostgreSQLCapabilityStore:
    """Store for rows of the ``atomic_capability`` table.

    Holds a single psycopg2 connection with autocommit disabled. Every public
    method runs exactly one statement in its own transaction: the transaction
    is committed on success and rolled back on failure, so one failed call
    cannot leave the shared connection unusable for the next one.

    The store can be used as a context manager; the connection is closed when
    the ``with`` block exits.
    """

    # Columns that are serialised with json.dumps() on write and decoded with
    # json.loads() on read when the driver hands them back as strings.
    _JSON_FIELDS = ('requirements', 'implements', 'tools', 'source_knowledge')

    def __init__(self):
        """Open the connection using the KNOWHUB_* environment variables.

        Reads KNOWHUB_DB (host), KNOWHUB_PORT (defaults to 5432),
        KNOWHUB_USER, KNOWHUB_PASSWORD and KNOWHUB_DB_NAME (database).
        """
        self.conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME')
        )
        # Transactions are ended explicitly in _execute().
        self.conn.autocommit = False
        print(f"[PostgreSQL Capability] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    def __enter__(self):
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection; exceptions from the block are not suppressed."""
        self.close()

    def _get_cursor(self):
        """Return a new cursor that yields rows as dictionaries."""
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def _execute(self, sql, params=None, fetch=None):
        """Run one statement in its own transaction.

        Args:
            sql: Statement text using ``%s`` placeholders.
            params: Values bound to the placeholders, or None.
            fetch: ``'one'`` to return ``fetchone()``, ``'all'`` to return
                ``fetchall()``, None to return nothing.

        Returns:
            The fetched row(s), or None when ``fetch`` is None.

        Raises:
            Whatever the driver raises; the transaction is rolled back first.
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(sql, params)
            if fetch == 'one':
                rows = cursor.fetchone()
            elif fetch == 'all':
                rows = cursor.fetchall()
            else:
                rows = None
            # Commit after reads as well: with autocommit off even a SELECT
            # opens a transaction, which would otherwise stay open.
            self.conn.commit()
            return rows
        except Exception:
            # Without a rollback the connection stays in an aborted
            # transaction and rejects every later statement.
            self.conn.rollback()
            raise
        finally:
            cursor.close()

    def insert_or_update(self, cap: Dict):
        """Insert a capability, or overwrite the row with the same ``id``.

        Args:
            cap: Capability data. ``id`` is required (KeyError if missing).
                Missing text fields default to ``''``, ``requirements`` /
                ``tools`` / ``source_knowledge`` to ``[]``, ``implements`` to
                ``{}`` and ``embedding`` to None.
        """
        self._execute("""
            INSERT INTO atomic_capability (
                id, name, criterion, description, requirements,
                implements, tools, source_knowledge, embedding
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON CONFLICT (id) DO UPDATE SET
                name = EXCLUDED.name,
                criterion = EXCLUDED.criterion,
                description = EXCLUDED.description,
                requirements = EXCLUDED.requirements,
                implements = EXCLUDED.implements,
                tools = EXCLUDED.tools,
                source_knowledge = EXCLUDED.source_knowledge,
                embedding = EXCLUDED.embedding
        """, (
            cap['id'],
            cap.get('name', ''),
            cap.get('criterion', ''),
            cap.get('description', ''),
            json.dumps(cap.get('requirements', [])),
            json.dumps(cap.get('implements', {})),
            json.dumps(cap.get('tools', [])),
            json.dumps(cap.get('source_knowledge', [])),
            cap.get('embedding'),
        ))

    def get_by_id(self, cap_id: str) -> Optional[Dict]:
        """Return the capability with the given id, or None if there is none.

        The ``embedding`` column is not included in the result.
        """
        row = self._execute("""
            SELECT id, name, criterion, description, requirements,
                   implements, tools, source_knowledge
            FROM atomic_capability WHERE id = %s
        """, (cap_id,), fetch='one')
        return self._format_result(row)

    def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
        """Return the capabilities nearest to ``query_embedding``.

        Rows without an embedding are skipped. Results are ordered by
        ascending ``<=>`` distance and each carries a ``score`` equal to
        ``1 - distance``.
        NOTE(review): ``<=>`` is presumably a cosine-distance operator
        (pgvector style) -- confirm against the database schema.

        Args:
            query_embedding: Query vector, sent to the server as ``real[]``.
            limit: Maximum number of rows to return.
        """
        rows = self._execute("""
            SELECT id, name, criterion, description, requirements,
                   implements, tools, source_knowledge,
                   1 - (embedding <=> %s::real[]) as score
            FROM atomic_capability
            WHERE embedding IS NOT NULL
            ORDER BY embedding <=> %s::real[]
            LIMIT %s
        """, (query_embedding, query_embedding, limit), fetch='all')
        return [self._format_result(row) for row in rows]

    def list_all(self, limit: int = 100, offset: int = 0) -> List[Dict]:
        """Return one page of capabilities ordered by ``id``.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip before the first returned row.
        """
        rows = self._execute("""
            SELECT id, name, criterion, description, requirements,
                   implements, tools, source_knowledge
            FROM atomic_capability
            ORDER BY id
            LIMIT %s OFFSET %s
        """, (limit, offset), fetch='all')
        return [self._format_result(row) for row in rows]

    def _build_update(self, updates: Dict):
        """Build the UPDATE statement and its parameters for ``updates``.

        Returns:
            ``(sql, params)``; the caller appends the row id to ``params``
            for the trailing ``WHERE id = %s``.

        Raises:
            ValueError: If a key is not a plain identifier. Column names
                cannot be sent as bound parameters, so they are validated
                before being placed in the statement text.
        """
        assignments = []
        params = []
        for column, value in updates.items():
            if not (isinstance(column, str) and column.isidentifier()):
                raise ValueError(f"invalid column name: {column!r}")
            assignments.append(f"{column} = %s")
            params.append(json.dumps(value) if column in self._JSON_FIELDS else value)
        sql = f"UPDATE atomic_capability SET {', '.join(assignments)} WHERE id = %s"
        return sql, params

    def update(self, cap_id: str, updates: Dict):
        """Update selected columns of one capability.

        Args:
            cap_id: Id of the row to update.
            updates: Mapping of column name to new value. Values for the JSON
                columns are serialised with ``json.dumps``. An empty mapping
                is a no-op (it would otherwise produce invalid SQL).

        Raises:
            ValueError: If a key in ``updates`` is not a plain identifier.
        """
        if not updates:
            return
        sql, params = self._build_update(updates)
        params.append(cap_id)
        self._execute(sql, params)

    def delete(self, cap_id: str):
        """Delete the capability with the given id, if it exists."""
        self._execute("DELETE FROM atomic_capability WHERE id = %s", (cap_id,))

    def count(self) -> int:
        """Return the total number of capabilities."""
        row = self._execute("SELECT COUNT(*) as count FROM atomic_capability", fetch='one')
        return row['count']

    def _format_result(self, row: Optional[Dict]) -> Optional[Dict]:
        """Return a plain-dict copy of ``row`` with JSON columns decoded.

        A falsy ``row`` yields None. JSON columns are decoded only when they
        arrive as strings; values of any other type are left as they are.
        """
        if not row:
            return None
        result = dict(row)
        for field in self._JSON_FIELDS:
            value = result.get(field)
            if isinstance(value, str):
                result[field] = json.loads(value)
        return result

    def close(self):
        """Close the database connection if one is held."""
        if self.conn:
            self.conn.close()
|