| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- """
- PostgreSQL 存储封装(替代 Milvus)
- 使用远程 PostgreSQL + pgvector/fastann 存储知识数据
- """
- import os
- import json
- import psycopg2
- from psycopg2.extras import RealDictCursor, execute_batch
- from typing import List, Dict, Optional
- from dotenv import load_dotenv
- load_dotenv()
class PostgreSQLStore:
    """Store and query knowledge records in the PostgreSQL ``knowledge`` table.

    One psycopg2 connection is opened in ``__init__`` and shared by every
    method. The connection runs with ``autocommit = False``; every public
    operation finishes its own transaction (commit on success, rollback on
    error) so a failed statement cannot poison later calls.
    """

    # Column order used by INSERT statements and by ``_to_row``.
    _COLUMNS = (
        'id', 'embedding', 'message_id', 'task', 'content', 'types', 'tags',
        'tag_keys', 'scopes', 'owner', 'resource_ids', 'source', 'eval',
        'created_at', 'updated_at', 'status', 'relationships',
    )
    _FIELDS_WITH_EMBEDDING = ", ".join(_COLUMNS)
    _FIELDS_WITHOUT_EMBEDDING = ", ".join(c for c in _COLUMNS if c != 'embedding')
    _INSERT_SQL = (
        f"INSERT INTO knowledge ({', '.join(_COLUMNS)}) "
        f"VALUES ({', '.join(['%s'] * len(_COLUMNS))})"
    )
    # Columns that are JSON-encoded on write and decoded on read when the
    # driver hands them back as text.
    _JSON_COLUMNS = ('tags', 'source', 'eval', 'relationships')

    def __init__(self):
        """Open the PostgreSQL connection from ``KNOWHUB_*`` environment variables.

        Reads ``KNOWHUB_DB`` (host), ``KNOWHUB_PORT`` (default 5432),
        ``KNOWHUB_USER``, ``KNOWHUB_PASSWORD`` and ``KNOWHUB_DB_NAME``.
        """
        self.conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            # ``or`` also covers a variable that is set but empty, which
            # would otherwise make ``int('')`` raise.
            port=int(os.getenv('KNOWHUB_PORT') or 5432),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME')
        )
        self.conn.autocommit = False
        print(f"[PostgreSQL] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    def _get_cursor(self):
        """Return a new cursor whose rows are dict-like (``RealDictCursor``)."""
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def _run(self, work):
        """Call ``work(cursor)`` in its own transaction and return its result.

        Commits on success and rolls back on any exception before re-raising,
        so the shared connection is never left in an aborted or
        idle-in-transaction state. The cursor is always closed.
        """
        cursor = self._get_cursor()
        try:
            outcome = work(cursor)
            self.conn.commit()
            return outcome
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()

    def _fetch_all(self, sql: str, params=None) -> list:
        """Execute ``sql`` and return every row."""
        def work(cursor):
            cursor.execute(sql, params)
            return cursor.fetchall()
        return self._run(work)

    def _fetch_one(self, sql: str, params=None):
        """Execute ``sql`` and return the first row, or ``None``."""
        def work(cursor):
            cursor.execute(sql, params)
            return cursor.fetchone()
        return self._run(work)

    @staticmethod
    def _to_row(knowledge: Dict) -> tuple:
        """Turn a knowledge dict into an INSERT parameter tuple (``_COLUMNS`` order).

        Raises:
            KeyError: if a required key (``id``, ``embedding``, ``message_id``,
                ``task``, ``content``, ``owner``, ``created_at``,
                ``updated_at``) is missing.
        """
        return (
            knowledge['id'],
            knowledge['embedding'],
            knowledge['message_id'],
            knowledge['task'],
            knowledge['content'],
            knowledge.get('types', []),
            json.dumps(knowledge.get('tags', {})),
            knowledge.get('tag_keys', []),
            knowledge.get('scopes', []),
            knowledge['owner'],
            knowledge.get('resource_ids', []),
            json.dumps(knowledge.get('source', {})),
            json.dumps(knowledge.get('eval', {})),
            knowledge['created_at'],
            knowledge['updated_at'],
            knowledge.get('status', 'approved'),
            json.dumps(knowledge.get('relationships', [])),
        )

    def insert(self, knowledge: Dict):
        """Insert a single knowledge record and commit."""
        row = self._to_row(knowledge)
        self._run(lambda cursor: cursor.execute(self._INSERT_SQL, row))

    def search(self, query_embedding: List[float], filters: Optional[str] = None, limit: int = 10) -> List[Dict]:
        """Vector search ordered by ``embedding <=> query_embedding``.

        Args:
            query_embedding: query vector, sent to the database as ``real[]``.
            filters: optional Milvus-style filter expression
                (see ``_build_where_clause``).
            limit: maximum number of rows returned.

        Returns:
            Formatted rows without the embedding, each with an extra ``score``
            equal to ``1 - (embedding <=> query)``.
        """
        where_clause = self._build_where_clause(filters)
        sql = f"""
            SELECT {self._FIELDS_WITHOUT_EMBEDDING},
                   1 - (embedding <=> %s::real[]) as score
            FROM knowledge
            {where_clause}
            ORDER BY embedding <=> %s::real[]
            LIMIT %s
        """
        rows = self._fetch_all(sql, (query_embedding, query_embedding, limit))
        return [self._format_result(row) for row in rows]

    def query(self, filters: str, limit: int = 100) -> List[Dict]:
        """Scalar-only query: return up to ``limit`` rows matching ``filters``."""
        where_clause = self._build_where_clause(filters)
        sql = f"""
            SELECT {self._FIELDS_WITHOUT_EMBEDDING}
            FROM knowledge
            {where_clause}
            LIMIT %s
        """
        rows = self._fetch_all(sql, (limit,))
        return [self._format_result(row) for row in rows]

    def get_by_id(self, knowledge_id: str, include_embedding: bool = False) -> Optional[Dict]:
        """Fetch one record by id, or ``None`` when it does not exist.

        The embedding is omitted unless ``include_embedding`` is true: the
        vector is large (1536 dimensions per the original note) and detail
        views do not need it.
        """
        if include_embedding:
            fields = self._FIELDS_WITH_EMBEDDING
        else:
            fields = self._FIELDS_WITHOUT_EMBEDDING
        row = self._fetch_one(
            f"SELECT {fields} FROM knowledge WHERE id = %s",
            (knowledge_id,),
        )
        return self._format_result(row) if row else None

    def update(self, knowledge_id: str, updates: Dict):
        """Update the given columns of one record and commit.

        ``tags``, ``source`` and ``eval`` values are JSON-encoded;
        ``relationships`` is JSON-encoded only when it is a list. An empty
        ``updates`` mapping is a no-op.

        Raises:
            ValueError: if a key is not a plain SQL identifier. Keys become
                column names in the statement text, so anything else would
                allow SQL injection.
        """
        if not updates:
            # "UPDATE ... SET  WHERE ..." is not valid SQL; nothing to do.
            return
        import re

        assignments = []
        params = []
        for column, value in updates.items():
            if not isinstance(column, str) or not re.fullmatch(r'[A-Za-z_][A-Za-z0-9_]*', column):
                raise ValueError(f"Invalid column name in updates: {column!r}")
            if column in ('tags', 'source', 'eval'):
                value = json.dumps(value)
            elif column == 'relationships' and isinstance(value, list):
                value = json.dumps(value)
            assignments.append(f"{column} = %s")
            params.append(value)
        params.append(knowledge_id)
        sql = f"UPDATE knowledge SET {', '.join(assignments)} WHERE id = %s"
        self._run(lambda cursor: cursor.execute(sql, params))

    def delete(self, knowledge_id: str):
        """Delete the record with the given id and commit."""
        self._run(lambda cursor: cursor.execute(
            "DELETE FROM knowledge WHERE id = %s", (knowledge_id,)
        ))

    def count(self) -> int:
        """Return the total number of knowledge records."""
        return self._fetch_one("SELECT COUNT(*) as count FROM knowledge")['count']

    def _build_where_clause(self, filters: str) -> str:
        """Translate a Milvus-style filter expression into a PostgreSQL WHERE clause.

        Handled constructs:
          * ``array_contains(col, "v")`` -> ``col @> ARRAY['v']``
          * ``eval["score"]``            -> ``(eval->>'score')::int``
          * ``"text"``                   -> ``'text'`` (embedded ``'`` doubled)
          * `` == `` / `` or `` / `` and `` -> `` = `` / `` OR `` / `` AND ``

        Operator rewriting is applied only outside quoted values, so a value
        such as ``"this or that"`` is preserved verbatim. ``%`` is doubled
        because callers execute the result together with psycopg2 parameters,
        where a bare ``%`` would be read as a placeholder.

        NOTE(review): the expression is still spliced into SQL text; only
        string values are escaped. Do not pass untrusted filter expressions.

        Returns:
            ``""`` for an empty filter, otherwise ``"WHERE <condition>"``.
        """
        if not filters:
            return ""
        import re

        def sql_literal(text):
            return "'" + text.replace("'", "''") + "'"

        def fix_operators(fragment):
            return (fragment.replace(' == ', ' = ')
                    .replace(' or ', ' OR ')
                    .replace(' and ', ' AND '))

        def render(match):
            column, member, plain = match.group(1, 2, 3)
            if column is not None:
                return f"{column} @> ARRAY[{sql_literal(member)}]"
            if plain is not None:
                return sql_literal(plain)
            return "(eval->>'score')::int"

        token = re.compile(
            r'array_contains\((\w+),\s*"([^"]+)"\)'
            r'|eval\["score"\]'
            r'|"([^"]*)"'
        )
        pieces = []
        position = 0
        for match in token.finditer(filters):
            pieces.append(fix_operators(filters[position:match.start()]))
            pieces.append(render(match))
            position = match.end()
        pieces.append(fix_operators(filters[position:]))
        where = "".join(pieces).replace('%', '%%')
        return f"WHERE {where}"

    def _format_result(self, row: Dict) -> Optional[Dict]:
        """Return a plain-dict copy of ``row`` normalised for callers.

        JSON columns that arrive as strings are decoded, and truthy
        ``created_at`` / ``updated_at`` values are multiplied by 1000
        (presumably seconds -> milliseconds; confirm against the schema).
        Returns ``None`` for an empty or missing row. ``row`` is not mutated.
        """
        if not row:
            return None
        result = dict(row)
        for column in self._JSON_COLUMNS:
            if isinstance(result.get(column), str):
                result[column] = json.loads(result[column])
        for column in ('created_at', 'updated_at'):
            if result.get(column):
                result[column] = result[column] * 1000
        return result

    def close(self):
        """Close the database connection if there is one."""
        if self.conn:
            self.conn.close()

    def insert_batch(self, knowledge_list: List[Dict]):
        """Insert many knowledge records with ``execute_batch`` and commit once.

        An empty list is a no-op. On any error the whole batch is rolled back.
        """
        if not knowledge_list:
            return
        rows = [self._to_row(knowledge) for knowledge in knowledge_list]
        self._run(lambda cursor: execute_batch(cursor, self._INSERT_SQL, rows))
|