""" PostgreSQL 存储封装(替代 Milvus) 使用远程 PostgreSQL + pgvector/fastann 存储知识数据 """ import os import json import psycopg2 from psycopg2.extras import RealDictCursor, execute_batch from typing import List, Dict, Optional from dotenv import load_dotenv load_dotenv() class PostgreSQLStore: def __init__(self): """初始化 PostgreSQL 连接""" self.conn = psycopg2.connect( host=os.getenv('KNOWHUB_DB'), port=int(os.getenv('KNOWHUB_PORT', 5432)), user=os.getenv('KNOWHUB_USER'), password=os.getenv('KNOWHUB_PASSWORD'), database=os.getenv('KNOWHUB_DB_NAME') ) self.conn.autocommit = False print(f"[PostgreSQL] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}") def _get_cursor(self): """获取游标""" return self.conn.cursor(cursor_factory=RealDictCursor) def insert(self, knowledge: Dict): """插入单条知识""" cursor = self._get_cursor() try: cursor.execute(""" INSERT INTO knowledge ( id, task_embedding, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """, ( knowledge['id'], knowledge.get('task_embedding') or knowledge.get('embedding'), knowledge['message_id'], knowledge['task'], knowledge['content'], knowledge.get('types', []), json.dumps(knowledge.get('tags', {})), knowledge.get('tag_keys', []), knowledge.get('scopes', []), knowledge['owner'], knowledge.get('resource_ids', []), json.dumps(knowledge.get('source', {})), json.dumps(knowledge.get('eval', {})), knowledge['created_at'], knowledge['updated_at'], knowledge.get('status', 'approved'), json.dumps(knowledge.get('relationships', [])), json.dumps(knowledge.get('support_capability', [])), json.dumps(knowledge.get('tools', [])), )) self.conn.commit() finally: cursor.close() def search(self, query_embedding: List[float], filters: Optional[str] = None, limit: int = 10) -> List[Dict]: """向量检索(使用余弦相似度)""" cursor = self._get_cursor() try: where_clause = self._build_where_clause(filters) if filters else "" sql = f""" SELECT id, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools, 1 - (task_embedding <=> %s::real[]) as score FROM knowledge {where_clause} ORDER BY task_embedding <=> %s::real[] LIMIT %s """ cursor.execute(sql, (query_embedding, query_embedding, limit)) results = cursor.fetchall() return [self._format_result(r) for r in results] finally: cursor.close() def query(self, filters: str, limit: int = 100) -> List[Dict]: """纯标量查询""" cursor = self._get_cursor() try: where_clause = self._build_where_clause(filters) sql = f""" SELECT id, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools FROM knowledge {where_clause} LIMIT %s """ cursor.execute(sql, (limit,)) results = cursor.fetchall() return [self._format_result(r) for r in results] finally: cursor.close() def get_by_id(self, knowledge_id: str, include_embedding: bool = False) -> Optional[Dict]: """根据ID获取知识(默认不返回embedding以提升性能)""" cursor = self._get_cursor() try: # 默认不返回embedding(1536维向量太大,详情页不需要) if include_embedding: fields = "id, task_embedding, content_embedding, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools" else: fields = "id, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools" cursor.execute(f""" SELECT {fields} FROM knowledge WHERE id = %s """, (knowledge_id,)) result = cursor.fetchone() return self._format_result(result) if result else None finally: cursor.close() def update(self, knowledge_id: str, updates: Dict): """更新知识""" cursor = self._get_cursor() try: set_parts = [] params = [] for key, value in updates.items(): if key in ('tags', 'source', 'eval', 'support_capability', 'tools'): set_parts.append(f"{key} = %s") params.append(json.dumps(value)) elif key == 'relationships': set_parts.append(f"{key} = %s") params.append(json.dumps(value) if isinstance(value, list) else value) else: set_parts.append(f"{key} = %s") params.append(value) params.append(knowledge_id) sql = f"UPDATE knowledge SET {', '.join(set_parts)} WHERE id = %s" cursor.execute(sql, params) self.conn.commit() finally: cursor.close() def delete(self, knowledge_id: str): """删除知识""" cursor = self._get_cursor() try: cursor.execute("DELETE FROM knowledge WHERE id = %s", (knowledge_id,)) self.conn.commit() finally: cursor.close() def count(self) -> int: """返回知识总数""" cursor = self._get_cursor() try: cursor.execute("SELECT COUNT(*) as count FROM knowledge") return cursor.fetchone()['count'] finally: cursor.close() def _build_where_clause(self, filters: str) -> str: """将Milvus风格的过滤表达式转换为PostgreSQL WHERE子句""" if not filters: return "" where = filters import re # 替换操作符 where = where.replace(' == ', ' = ') where = where.replace(' or ', ' OR ') where = where.replace(' and ', ' AND ') # 处理数组包含操作 where = re.sub(r'array_contains\((\w+),\s*"([^"]+)"\)', r"\1 @> ARRAY['\2']", where) # 处理 eval["score"] 语法 where = where.replace('eval["score"]', "(eval->>'score')::int") # 把所有剩余的双引号字符串值替换为单引号(PostgreSQL标准) where = re.sub(r'"([^"]*)"', r"'\1'", where) return f"WHERE {where}" def _format_result(self, row: Dict) -> Dict: """格式化查询结果""" if not row: return None result = dict(row) if 'tags' in result and isinstance(result['tags'], str): result['tags'] = json.loads(result['tags']) if 'source' in result and isinstance(result['source'], str): result['source'] = json.loads(result['source']) if 'eval' in result and isinstance(result['eval'], str): result['eval'] = json.loads(result['eval']) if 'relationships' in result and isinstance(result['relationships'], str): result['relationships'] = json.loads(result['relationships']) if 'support_capability' in result and isinstance(result['support_capability'], str): result['support_capability'] = json.loads(result['support_capability']) if 'tools' in result and isinstance(result['tools'], str): result['tools'] = json.loads(result['tools']) if 'created_at' in result and result['created_at']: result['created_at'] = result['created_at'] * 1000 if 'updated_at' in result and result['updated_at']: result['updated_at'] = result['updated_at'] * 1000 return result def close(self): """关闭连接""" if self.conn: self.conn.close() def insert_batch(self, knowledge_list: List[Dict]): """批量插入知识""" if not knowledge_list: return cursor = self._get_cursor() try: data = [] for k in knowledge_list: data.append(( k['id'], k.get('task_embedding') or k.get('embedding'), k['message_id'], k['task'], k['content'], k.get('types', []), json.dumps(k.get('tags', {})), k.get('tag_keys', []), k.get('scopes', []), k['owner'], k.get('resource_ids', []), json.dumps(k.get('source', {})), json.dumps(k.get('eval', {})), k['created_at'], k['updated_at'], k.get('status', 'approved'), json.dumps(k.get('relationships', [])), json.dumps(k.get('support_capability', [])), json.dumps(k.get('tools', [])), )) execute_batch(cursor, """ INSERT INTO knowledge ( id, task_embedding, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """, data) self.conn.commit() finally: cursor.close()