| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
"""
PostgreSQL tool storage wrapper.

Stores and retrieves tool records, with support for vector search.
Table name: tool (migrated from tool_table).
"""
- import os
- import json
- import psycopg2
- from psycopg2.extras import RealDictCursor
- from typing import List, Dict, Optional
- from dotenv import load_dotenv
- from knowhub.knowhub_db.cascade import cascade_delete
- load_dotenv()
- # 关联字段子查询
- _REL_SUBQUERIES = """
- (SELECT COALESCE(json_agg(ct.capability_id), '[]'::json)
- FROM capability_tool ct WHERE ct.tool_id = tool.id) AS capability_ids,
- (SELECT COALESCE(json_agg(tk.knowledge_id), '[]'::json)
- FROM tool_knowledge tk WHERE tk.tool_id = tool.id) AS knowledge_ids,
- (SELECT COALESCE(json_agg(tp.provider_id), '[]'::json)
- FROM tool_provider tp WHERE tp.tool_id = tool.id) AS provider_ids
- """
- _BASE_FIELDS = "id, name, version, introduction, tutorial, input, output, updated_time, status"
- _SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
class PostgreSQLToolStore:
    """PostgreSQL-backed store for tool records and their relation links.

    Tool rows live in the ``tool`` table; links to capabilities, knowledge
    entries and providers live in the ``capability_tool``, ``tool_knowledge``
    and ``tool_provider`` junction tables.

    One psycopg2 connection is held with autocommit disabled, so every public
    method ends the transaction it started: writes commit on success and roll
    back on failure, reads always roll back.  This keeps the connection usable
    after an error and avoids leaving the session idle inside a transaction.
    """

    # Columns whose values are serialized with json.dumps before being stored.
    _JSON_FIELDS = ('input', 'output')

    # Relation key -> (junction table, column holding the related id).
    # Every junction table also has a ``tool_id`` column.
    _LINK_TABLES = {
        'capability_ids': ('capability_tool', 'capability_id'),
        'knowledge_ids': ('tool_knowledge', 'knowledge_id'),
        'provider_ids': ('tool_provider', 'provider_id'),
    }

    def __init__(self):
        """Open the PostgreSQL connection described by the KNOWHUB_* env vars."""
        self.conn = self._connect()
        print(f"[PostgreSQL Tool] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    @staticmethod
    def _connect():
        """Return a new connection with autocommit disabled.

        Reads KNOWHUB_DB (host), KNOWHUB_PORT (default 5432), KNOWHUB_USER,
        KNOWHUB_PASSWORD and KNOWHUB_DB_NAME from the environment.
        """
        conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME'),
        )
        conn.autocommit = False
        return conn

    def _reconnect(self):
        """Replace ``self.conn`` with a fresh connection."""
        try:
            # Best effort: release the old connection's resources first.
            self.conn.close()
        except psycopg2.Error:
            pass
        self.conn = self._connect()

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` probe."""
        if self.conn.closed != 0:
            self._reconnect()
            return
        try:
            probe = self.conn.cursor()
            try:
                probe.execute("SELECT 1")
            finally:
                probe.close()
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a RealDictCursor on a connection that has just been checked."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def _rollback(self):
        """End the current transaction, discarding any uncommitted work."""
        try:
            self.conn.rollback()
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            # The connection itself is unusable; the next call's
            # _ensure_connection() will replace it.
            pass

    def _replace_links(self, cursor, tool_id, relation, linked_ids):
        """Replace all rows of one junction table for ``tool_id``.

        Args:
            cursor: Open cursor taking part in the caller's transaction.
            tool_id: Id of the tool whose links are rewritten.
            relation: One of the keys of ``_LINK_TABLES``.
            linked_ids: Iterable of related ids to link; may be empty.
        """
        # Table and column names come from the class constant, never from
        # caller input, so interpolating them into the SQL text is safe.
        table, column = self._LINK_TABLES[relation]
        cursor.execute(f"DELETE FROM {table} WHERE tool_id = %s", (tool_id,))
        insert_sql = (
            f"INSERT INTO {table} (tool_id, {column}) VALUES (%s, %s) "
            "ON CONFLICT DO NOTHING"
        )
        for linked_id in linked_ids:
            cursor.execute(insert_sql, (tool_id, linked_id))

    def insert_or_update(self, tool: Dict):
        """Insert a tool row, or overwrite the existing row with the same id.

        Args:
            tool: Tool data.  ``id`` is required; other columns fall back to
                the defaults used below.  If ``capability_ids``,
                ``knowledge_ids`` or ``provider_ids`` is present and not
                None, that relation's links are replaced by the given ids.

        Raises:
            KeyError: If ``tool`` has no ``id``.
            psycopg2.Error: On database failure; the transaction is rolled back.
        """
        cursor = self._get_cursor()
        try:
            tool_id = tool['id']
            cursor.execute("""
                INSERT INTO tool (
                    id, name, version, introduction, tutorial, input, output,
                    updated_time, status, embedding
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (id) DO UPDATE SET
                    name = EXCLUDED.name,
                    version = EXCLUDED.version,
                    introduction = EXCLUDED.introduction,
                    tutorial = EXCLUDED.tutorial,
                    input = EXCLUDED.input,
                    output = EXCLUDED.output,
                    updated_time = EXCLUDED.updated_time,
                    status = EXCLUDED.status,
                    embedding = EXCLUDED.embedding
            """, (
                tool_id,
                tool.get('name', ''),
                tool.get('version'),
                tool.get('introduction', ''),
                tool.get('tutorial', ''),
                json.dumps(tool.get('input', '')),
                json.dumps(tool.get('output', '')),
                tool.get('updated_time', 0),
                tool.get('status', '未接入'),
                tool.get('embedding'),
            ))
            # Rewrite junction rows only for relations the caller supplied.
            for relation in self._LINK_TABLES:
                linked_ids = tool.get(relation)
                if linked_ids is not None:
                    self._replace_links(cursor, tool_id, relation, linked_ids)
            self.conn.commit()
        except Exception:
            # Without this the connection stays in an aborted transaction.
            self._rollback()
            raise
        finally:
            cursor.close()

    def get_by_id(self, tool_id: str) -> Optional[Dict]:
        """Return the formatted tool with the given id, or None if absent."""
        cursor = self._get_cursor()
        try:
            cursor.execute(f"""
                SELECT {_SELECT_FIELDS}
                FROM tool WHERE id = %s
            """, (tool_id,))
            return self._format_result(cursor.fetchone())
        finally:
            cursor.close()
            self._rollback()  # close the implicit read transaction

    def search(self, query_embedding: List[float], limit: int = 10, status: Optional[str] = None) -> List[Dict]:
        """Return the tools whose embedding is nearest to ``query_embedding``.

        Rows without an embedding are skipped.  Results are ordered by
        ``embedding <=> query`` ascending and each carries a ``score`` equal
        to ``1 - (embedding <=> query)``.

        Args:
            query_embedding: Query vector, sent to the database as ``real[]``.
            limit: Maximum number of rows to return.
            status: If truthy, only tools with this status are considered.
        """
        conditions = ["embedding IS NOT NULL"]
        # Parameter order follows placeholder order in the SQL text below.
        params = [query_embedding]
        if status:
            conditions.append("status = %s")
            params.append(status)
        params.extend([query_embedding, limit])
        where_clause = " AND ".join(conditions)
        sql = f"""
            SELECT {_SELECT_FIELDS},
                   1 - (embedding <=> %s::real[]) as score
            FROM tool
            WHERE {where_clause}
            ORDER BY embedding <=> %s::real[]
            LIMIT %s
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(sql, params)
            return [self._format_result(row) for row in cursor.fetchall()]
        finally:
            cursor.close()
            self._rollback()  # close the implicit read transaction

    def list_all(self, limit: int = 100, offset: int = 0, status: Optional[str] = None) -> List[Dict]:
        """Return tools ordered by ``updated_time`` descending.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
            status: If truthy, only tools with this status are returned.
        """
        where_clause = ""
        params = []
        if status:
            where_clause = "WHERE status = %s"
            params.append(status)
        params.extend([limit, offset])
        sql = f"""
            SELECT {_SELECT_FIELDS}
            FROM tool
            {where_clause}
            ORDER BY updated_time DESC
            LIMIT %s OFFSET %s
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(sql, params)
            return [self._format_result(row) for row in cursor.fetchall()]
        finally:
            cursor.close()
            self._rollback()  # close the implicit read transaction

    def update(self, tool_id: str, updates: Dict):
        """Update selected columns and/or relation links of one tool.

        Args:
            tool_id: Id of the tool to modify.
            updates: Mapping of column name to new value.  The keys
                ``capability_ids``, ``knowledge_ids`` and ``provider_ids``
                are not columns: a non-None value replaces that relation's
                links.  ``input`` and ``output`` values are JSON-encoded.
                The mapping is not modified.

        Raises:
            ValueError: If a column key is not a plain identifier.  Keys are
                interpolated into the SQL text, so anything else is refused.
            psycopg2.Error: On database failure; the transaction is rolled back.
        """
        # Work on a copy so the caller's dict is left intact.
        columns = dict(updates)
        links = {relation: columns.pop(relation, None) for relation in self._LINK_TABLES}
        for column in columns:
            if not (isinstance(column, str) and column.isidentifier()):
                raise ValueError(f"invalid column name: {column!r}")

        cursor = self._get_cursor()
        try:
            if columns:
                assignments = ', '.join(f"{column} = %s" for column in columns)
                params = [
                    json.dumps(value) if column in self._JSON_FIELDS else value
                    for column, value in columns.items()
                ]
                params.append(tool_id)
                cursor.execute(
                    f"UPDATE tool SET {assignments} WHERE id = %s",
                    params
                )
            for relation, linked_ids in links.items():
                if linked_ids is not None:
                    self._replace_links(cursor, tool_id, relation, linked_ids)
            self.conn.commit()
        except Exception:
            self._rollback()
            raise
        finally:
            cursor.close()

    def add_knowledge(self, tool_id: str, knowledge_id: str):
        """Link one knowledge entry to a tool, keeping existing links."""
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO tool_knowledge (tool_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (tool_id, knowledge_id))
            self.conn.commit()
        except Exception:
            self._rollback()
            raise
        finally:
            cursor.close()

    def delete(self, tool_id: str):
        """Delete a tool through ``cascade_delete`` and commit.

        NOTE(review): removal of the junction-table rows is delegated to
        ``cascade_delete``; confirm its behavior in knowhub_db.cascade.
        """
        cursor = self._get_cursor()
        try:
            cascade_delete(cursor, 'tool', tool_id)
            self.conn.commit()
        except Exception:
            self._rollback()
            raise
        finally:
            cursor.close()

    def count(self, status: Optional[str] = None) -> int:
        """Return the number of tools, optionally restricted to one status."""
        cursor = self._get_cursor()
        try:
            if status:
                cursor.execute("SELECT COUNT(*) as count FROM tool WHERE status = %s", (status,))
            else:
                cursor.execute("SELECT COUNT(*) as count FROM tool")
            return cursor.fetchone()['count']
        finally:
            cursor.close()
            self._rollback()  # close the implicit read transaction

    def _format_result(self, row: Optional[Dict]) -> Optional[Dict]:
        """Convert a database row into a plain dict with decoded JSON fields.

        ``input`` / ``output`` strings are parsed as JSON; blank or invalid
        text becomes None.  Relation id fields given as strings are parsed as
        JSON, and None is replaced by an empty list.

        Returns:
            A new dict, or None when ``row`` is empty or None.
        """
        if not row:
            return None
        result = dict(row)
        for field in self._JSON_FIELDS:
            value = result.get(field)
            if isinstance(value, str):
                try:
                    result[field] = json.loads(value) if value.strip() else None
                except json.JSONDecodeError:
                    result[field] = None
        # Relation fields come from the junction-table subqueries.
        for field in self._LINK_TABLES:
            if field not in result:
                continue
            value = result[field]
            if isinstance(value, str):
                result[field] = json.loads(value)
            elif value is None:
                result[field] = []
        return result

    def close(self):
        """Close the database connection if one is held."""
        if self.conn:
            self.conn.close()
|