"""
PostgreSQL storage wrapper (replacement for Milvus).

Stores knowledge data in a remote PostgreSQL database using pgvector/fastann.
"""
import os
import re
import json
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional

import psycopg2
from psycopg2.extras import RealDictCursor, execute_batch
from dotenv import load_dotenv

from knowhub.knowhub_db.cascade import cascade_delete

load_dotenv()

# Relation columns, read from the junction tables via correlated subqueries.
# Typed *_knowledge edges (those carrying relation_type) are exposed two ways:
#   - *_ids   : flat list of IDs (backward compatible, no type)
#   - *_links : [{id, relation_type}] (includes the type)
_REL_SUBQUERIES = """
    (SELECT COALESCE(json_agg(rk.requirement_id), '[]'::json)
     FROM requirement_knowledge rk WHERE rk.knowledge_id = knowledge.id) AS requirement_ids,
    (SELECT COALESCE(json_agg(json_build_object(
         'id', rk2.requirement_id, 'relation_type', rk2.relation_type
     )), '[]'::json)
     FROM requirement_knowledge rk2 WHERE rk2.knowledge_id = knowledge.id) AS requirement_links,
    (SELECT COALESCE(json_agg(ck.capability_id), '[]'::json)
     FROM capability_knowledge ck WHERE ck.knowledge_id = knowledge.id) AS capability_ids,
    (SELECT COALESCE(json_agg(json_build_object(
         'id', ck2.capability_id, 'relation_type', ck2.relation_type
     )), '[]'::json)
     FROM capability_knowledge ck2 WHERE ck2.knowledge_id = knowledge.id) AS capability_links,
    (SELECT COALESCE(json_agg(tk.tool_id), '[]'::json)
     FROM tool_knowledge tk WHERE tk.knowledge_id = knowledge.id) AS tool_ids,
    (SELECT COALESCE(json_agg(json_build_object(
         'id', tk2.tool_id, 'relation_type', tk2.relation_type
     )), '[]'::json)
     FROM tool_knowledge tk2 WHERE tk2.knowledge_id = knowledge.id) AS tool_links,
    (SELECT COALESCE(json_agg(kr.resource_id), '[]'::json)
     FROM knowledge_resource kr WHERE kr.knowledge_id = knowledge.id) AS resource_ids,
    (SELECT COALESCE(json_agg(json_build_object(
         'target_id', krel.target_id, 'relation_type', krel.relation_type
     )), '[]'::json)
     FROM knowledge_relation krel WHERE krel.source_id = knowledge.id) AS relations
"""

# Base columns (no embeddings).
_BASE_FIELDS = (
    "id, message_id, task, content, types, tags, tag_keys, "
    "scopes, owner, source, eval, "
    "created_at, updated_at, status"
)

# Full SELECT list (base columns + relation subqueries).
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"

# Full SELECT list including both embedding columns.
_SELECT_FIELDS_WITH_EMB = f"task_embedding, content_embedding, {_SELECT_FIELDS}"

# INSERT statement for one knowledge row; parameters come from _knowledge_row().
_INSERT_KNOWLEDGE_SQL = """
    INSERT INTO knowledge (
        id, task_embedding, content_embedding, message_id,
        task, content, types, tags, tag_keys, scopes, owner,
        source, eval, created_at, updated_at, status
    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""

# (links_key, ids_key, junction table, column holding the other entity's id)
# for every typed knowledge edge.
_TYPED_LINK_SPECS = (
    ('requirement_links', 'requirement_ids', 'requirement_knowledge', 'requirement_id'),
    ('capability_links', 'capability_ids', 'capability_knowledge', 'capability_id'),
    ('tool_links', 'tool_ids', 'tool_knowledge', 'tool_id'),
)

# Input keys that describe junction-table links rather than knowledge columns.
_REL_KEYS = (
    'requirement_ids', 'requirement_links',
    'capability_ids', 'capability_links',
    'tool_ids', 'tool_links',
    'resource_ids',
)

# Knowledge columns that are stored as serialized JSON.
_JSON_COLUMNS = ('tags', 'source', 'eval')

# Result fields that are JSON lists produced by _REL_SUBQUERIES.
_LIST_RESULT_FIELDS = (
    'requirement_ids', 'capability_ids', 'tool_ids', 'resource_ids',
    'requirement_links', 'capability_links', 'tool_links', 'relations',
)

# relation_filters key -> EXISTS clause restricting knowledge to linked rows.
_RELATION_FILTER_SQL = {
    'requirement_id': "EXISTS (SELECT 1 FROM requirement_knowledge rk WHERE rk.knowledge_id = knowledge.id AND rk.requirement_id = %s)",
    'capability_id': "EXISTS (SELECT 1 FROM capability_knowledge ck WHERE ck.knowledge_id = knowledge.id AND ck.capability_id = %s)",
    'tool_id': "EXISTS (SELECT 1 FROM tool_knowledge tk WHERE tk.knowledge_id = knowledge.id AND tk.tool_id = %s)",
}

# Column names accepted by update(); anything else is rejected to keep
# caller-supplied keys from being spliced into SQL.
_IDENTIFIER_RE = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')

# Helpers for translating Milvus-style filter expressions.
_DQ_SPLIT_RE = re.compile(r'("[^"]*")')
_DQ_LITERAL_RE = re.compile(r'"([^"]*)"')
_ARRAY_CONTAINS_RE = re.compile(r'array_contains\((\w+),\s*"([^"]+)"\)')


def _normalize_links(data: Dict, links_key: str, ids_key: str, default_type: str):
    """
    Normalize the two accepted input formats for typed knowledge edges.

    - ``{links_key: [{id, relation_type}, ...]}``: each item's ``relation_type``
      is used (falling back to ``default_type``); non-dict items are treated
      as bare IDs with ``default_type``.
    - ``{ids_key: [id1, id2, ...]}``: every ID gets ``default_type``.

    ``links_key`` wins when both keys are present and non-None.

    Returns:
        A list of ``(id, relation_type)`` tuples, or ``None`` when neither key
        is present with a non-None value (meaning "leave this relation alone").
    """
    if links_key in data and data[links_key] is not None:
        out = []
        for item in data[links_key]:
            if isinstance(item, dict):
                out.append((item['id'], item.get('relation_type', default_type)))
            else:
                out.append((item, default_type))
        return out
    if ids_key in data and data[ids_key] is not None:
        return [(i, default_type) for i in data[ids_key]]
    return None


def _sql_quote(text: str) -> str:
    """Return *text* as a PostgreSQL single-quoted literal (``'`` doubled)."""
    return "'" + text.replace("'", "''") + "'"


def _knowledge_row(k: Dict) -> tuple:
    """Build the positional parameters for ``_INSERT_KNOWLEDGE_SQL`` from a knowledge dict.

    Raises:
        KeyError: if a required key (id, message_id, task, content, owner,
            created_at, updated_at) is missing.
    """
    return (
        k['id'],
        # A legacy 'embedding' key is accepted as the task embedding.
        k.get('task_embedding') or k.get('embedding'),
        k.get('content_embedding'),
        k['message_id'],
        k['task'],
        k['content'],
        k.get('types', []),
        json.dumps(k.get('tags', {})),
        k.get('tag_keys', []),
        k.get('scopes', []),
        k['owner'],
        json.dumps(k.get('source', {})),
        json.dumps(k.get('eval', {})),
        k['created_at'],
        k['updated_at'],
        k.get('status', 'approved'),
    )


def _write_links(cursor, knowledge_id: str, data: Dict, replace: bool = False) -> None:
    """Write the junction-table links described by *data* for one knowledge row.

    Each relation (requirement / capability / tool / resource) is processed
    only if *data* supplies it (see ``_normalize_links``; ``resource_ids`` must
    be present and non-None). When *replace* is true, the existing rows of a
    supplied relation are deleted first (full replacement); relations absent
    from *data* are never touched. Duplicate links are ignored via
    ``ON CONFLICT DO NOTHING``.
    """
    for links_key, ids_key, table, column in _TYPED_LINK_SPECS:
        links = _normalize_links(data, links_key, ids_key, 'related')
        if links is None:
            continue
        if replace:
            cursor.execute(f"DELETE FROM {table} WHERE knowledge_id = %s", (knowledge_id,))
        for target_id, rtype in links:
            cursor.execute(
                f"INSERT INTO {table} ({column}, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (target_id, knowledge_id, rtype))

    resource_ids = data.get('resource_ids')
    if resource_ids is None:
        return
    if replace:
        cursor.execute("DELETE FROM knowledge_resource WHERE knowledge_id = %s", (knowledge_id,))
    for res_id in resource_ids:
        cursor.execute(
            "INSERT INTO knowledge_resource (knowledge_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
            (knowledge_id, res_id))


class PostgreSQLStore:
    """Knowledge store backed by a remote PostgreSQL database.

    The connection normally runs in autocommit mode (each read is its own
    statement); every write goes through ``_transaction`` so multi-statement
    writes are atomic.

    NOTE(review): one psycopg2 connection is shared by all methods and
    ``_transaction`` toggles ``autocommit`` on it, so an instance should not be
    shared across threads without external locking.
    """

    def __init__(self):
        """Connect to the database configured by the KNOWHUB_* environment variables."""
        self._connect()
        print(f"[PostgreSQL] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    def _connect(self):
        """Open a new autocommit connection from KNOWHUB_* env vars and store it in ``self.conn``."""
        self.conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME'),
        )
        self.conn.autocommit = True

    def _reconnect(self):
        """Close the current connection (best effort) and open a fresh one."""
        old = getattr(self, 'conn', None)
        if old is not None and not old.closed:
            try:
                old.close()
            except psycopg2.Error:
                pass  # the old connection is being discarded anyway
        self._connect()

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` ping."""
        if self.conn.closed != 0:
            self._reconnect()
            return
        try:
            # The with-block closes the ping cursor even if execute() raises.
            with self.conn.cursor() as ping:
                ping.execute("SELECT 1")
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a ``RealDictCursor`` on a live (autocommit) connection."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    @contextmanager
    def _transaction(self) -> Iterator:
        """Yield a ``RealDictCursor`` whose statements run in one transaction.

        Autocommit is switched off for the duration of the block (with
        autocommit on, ``conn.commit()`` is a no-op and every statement is
        committed on its own). The transaction commits when the block exits
        normally and rolls back on any exception, which is re-raised.
        Autocommit is restored afterwards.
        """
        self._ensure_connection()
        conn = self.conn
        conn.autocommit = False
        cursor = conn.cursor(cursor_factory=RealDictCursor)
        try:
            yield cursor
            conn.commit()
        except BaseException:
            if not conn.closed:
                try:
                    conn.rollback()
                except psycopg2.Error:
                    pass  # keep the original exception; the connection is broken anyway
            raise
        finally:
            cursor.close()
            if not conn.closed:
                conn.autocommit = True

    def insert(self, knowledge: Dict):
        """Insert one knowledge row and its junction-table links atomically.

        Links are taken from ``requirement_links``/``requirement_ids``,
        ``capability_links``/``capability_ids``, ``tool_links``/``tool_ids``
        (default relation_type ``'related'``) and ``resource_ids``.

        Raises:
            KeyError: if a required knowledge key is missing.
            psycopg2.Error: on database failure (nothing is written).
        """
        with self._transaction() as cursor:
            cursor.execute(_INSERT_KNOWLEDGE_SQL, _knowledge_row(knowledge))
            _write_links(cursor, knowledge['id'], knowledge)

    def _apply_relation_filters(self, where_clause: str, relation_filters: Optional[Dict[str, str]],
                                params: list) -> str:
        """AND relation EXISTS-filters onto *where_clause*.

        Supported keys are ``requirement_id``, ``capability_id`` and
        ``tool_id``; unknown keys and falsy values are ignored. One bind value
        per clause is appended to *params* (in place), in iteration order.

        Returns:
            The combined clause (starting with ``WHERE``), or *where_clause*
            unchanged when no filter applies.
        """
        if not relation_filters:
            return where_clause
        rel_clauses = []
        for key, value in relation_filters.items():
            clause = _RELATION_FILTER_SQL.get(key)
            if not value or clause is None:
                continue
            rel_clauses.append(clause)
            params.append(value)
        if not rel_clauses:
            return where_clause
        rel_where = " AND ".join(rel_clauses)
        if where_clause.strip():
            return f"{where_clause} AND {rel_where}"
        return f"WHERE {rel_where}"

    def search(self, query_embedding: List[float], filters: Optional[str] = None,
               limit: int = 10, relation_filters: Optional[Dict[str, str]] = None) -> List[Dict]:
        """Vector search ordered by the ``<=>`` distance on ``task_embedding``.

        Args:
            query_embedding: Query vector, bound as ``real[]``.
            filters: Optional Milvus-style filter expression (see ``_build_where_clause``).
            limit: Maximum number of rows to return.
            relation_filters: Optional relation filters (see ``_apply_relation_filters``).

        Returns:
            Formatted rows, each with an extra ``score`` = ``1 - distance``.
        """
        cursor = self._get_cursor()
        try:
            where_clause = self._build_where_clause(filters) if filters else ""
            params = []
            where_clause = self._apply_relation_filters(where_clause, relation_filters, params)
            sql = f"""
                SELECT {_SELECT_FIELDS},
                       1 - (task_embedding <=> %s::real[]) as score
                FROM knowledge
                {where_clause}
                ORDER BY task_embedding <=> %s::real[]
                LIMIT %s
            """
            # Bind order follows placeholder order: score, WHERE, ORDER BY, LIMIT.
            final_params = [query_embedding] + params + [query_embedding, limit]
            cursor.execute(sql, tuple(final_params))
            return [self._format_result(r) for r in cursor.fetchall()]
        finally:
            cursor.close()

    def query(self, filters: str, limit: int = 100,
              relation_filters: Optional[Dict[str, str]] = None) -> List[Dict]:
        """Scalar (non-vector) query using a Milvus-style filter and optional relation filters."""
        cursor = self._get_cursor()
        try:
            where_clause = self._build_where_clause(filters) if filters else ""
            params = []
            where_clause = self._apply_relation_filters(where_clause, relation_filters, params)
            sql = f"""
                SELECT {_SELECT_FIELDS}
                FROM knowledge
                {where_clause}
                LIMIT %s
            """
            cursor.execute(sql, tuple(params + [limit]))
            return [self._format_result(r) for r in cursor.fetchall()]
        finally:
            cursor.close()

    def get_by_id(self, knowledge_id: str, include_embedding: bool = False) -> Optional[Dict]:
        """Fetch one knowledge row by id, or ``None`` if absent.

        Embedding columns are omitted unless *include_embedding* is true, to
        keep the payload small.
        """
        cursor = self._get_cursor()
        try:
            fields = _SELECT_FIELDS_WITH_EMB if include_embedding else _SELECT_FIELDS
            cursor.execute(f"""
                SELECT {fields}
                FROM knowledge WHERE id = %s
            """, (knowledge_id,))
            result = cursor.fetchone()
            return self._format_result(result) if result else None
        finally:
            cursor.close()

    def update(self, knowledge_id: str, updates: Dict):
        """Update knowledge columns and/or replace its links, atomically.

        Column keys in *updates* are written as ``SET key = value``
        (``tags``/``source``/``eval`` are JSON-serialized). Relation keys
        (``*_ids``/``*_links``/``resource_ids``) fully replace the
        corresponding links; relations not mentioned are left untouched.
        The caller's *updates* dict is not modified.

        Raises:
            ValueError: if a column key is not a plain SQL identifier.
            psycopg2.Error: on database failure (nothing is changed).
        """
        rel_data = {k: updates[k] for k in _REL_KEYS if k in updates}
        fields = {k: v for k, v in updates.items() if k not in _REL_KEYS}
        for key in fields:
            # Column names cannot be bound as parameters; refuse anything that
            # is not a bare identifier so keys cannot inject SQL.
            if not isinstance(key, str) or not _IDENTIFIER_RE.fullmatch(key):
                raise ValueError(f"invalid column name in updates: {key!r}")

        with self._transaction() as cursor:
            if fields:
                set_sql = ', '.join(f"{key} = %s" for key in fields)
                params = [json.dumps(value) if key in _JSON_COLUMNS else value
                          for key, value in fields.items()]
                params.append(knowledge_id)
                cursor.execute(f"UPDATE knowledge SET {set_sql} WHERE id = %s", params)
            _write_links(cursor, knowledge_id, rel_data, replace=True)

    def delete(self, knowledge_id: str):
        """Delete a knowledge row and its related rows via ``cascade_delete``, atomically."""
        with self._transaction() as cursor:
            cascade_delete(cursor, 'knowledge', knowledge_id)

    def add_relation(self, source_id: str, target_id: str, relation_type: str):
        """Add one knowledge-to-knowledge relation (existing relations are kept)."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO knowledge_relation (source_id, target_id, relation_type) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (source_id, target_id, relation_type))

    def add_resource(self, knowledge_id: str, resource_id: str):
        """Add one knowledge-resource link (existing links are kept)."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO knowledge_resource (knowledge_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
                (knowledge_id, resource_id))

    def add_requirement(self, knowledge_id: str, requirement_id: str, relation_type: str = 'related'):
        """Incrementally attach a requirement-knowledge edge."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO requirement_knowledge (requirement_id, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (requirement_id, knowledge_id, relation_type))

    def add_capability(self, knowledge_id: str, capability_id: str, relation_type: str = 'related'):
        """Incrementally attach a capability-knowledge edge."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO capability_knowledge (capability_id, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (capability_id, knowledge_id, relation_type))

    def add_tool(self, knowledge_id: str, tool_id: str, relation_type: str = 'related'):
        """Incrementally attach a tool-knowledge edge."""
        with self._transaction() as cursor:
            cursor.execute(
                "INSERT INTO tool_knowledge (tool_id, knowledge_id, relation_type) "
                "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                (tool_id, knowledge_id, relation_type))

    def count(self) -> int:
        """Return the total number of knowledge rows."""
        cursor = self._get_cursor()
        try:
            cursor.execute("SELECT COUNT(*) as count FROM knowledge")
            return cursor.fetchone()['count']
        finally:
            cursor.close()

    def _build_where_clause(self, filters: str) -> str:
        """Translate a Milvus-style filter expression into a PostgreSQL WHERE clause.

        Rewrites ``==``/``or``/``and`` (outside string literals only),
        ``array_contains(col, "v")`` -> ``col @> ARRAY['v']``,
        ``eval["score"]`` -> ``(eval->>'score')::int``, and double-quoted
        literals -> single-quoted literals with ``'`` escaped.

        ``%`` is doubled because the result is always executed with bound
        parameters (so e.g. ``like "abc%"`` survives psycopg2 formatting).

        NOTE(review): apart from literal quoting this is plain text rewriting,
        not a parser; *filters* must come from trusted code.
        """
        if not filters:
            return ""

        # Rewrite operators only in the segments between "..." literals
        # (odd indices of the split are the literals themselves).
        parts = _DQ_SPLIT_RE.split(filters)
        for i in range(0, len(parts), 2):
            parts[i] = (parts[i]
                        .replace(' == ', ' = ')
                        .replace(' or ', ' OR ')
                        .replace(' and ', ' AND '))
        where = ''.join(parts)

        # Array containment.
        where = _ARRAY_CONTAINS_RE.sub(
            lambda m: f"{m.group(1)} @> ARRAY[{_sql_quote(m.group(2))}]", where)

        # eval["score"] JSON access.
        where = where.replace('eval["score"]', "(eval->>'score')::int")

        # Remaining double-quoted values -> standard SQL single-quoted literals.
        where = _DQ_LITERAL_RE.sub(lambda m: _sql_quote(m.group(1)), where)

        return f"WHERE {where.replace('%', '%%')}"

    def _format_result(self, row: Optional[Dict]) -> Optional[Dict]:
        """Convert a DB row into the public dict shape.

        JSON-string columns are decoded. Relation fields default to ``[]``
        when NULL. ``created_at``/``updated_at`` are multiplied by 1000
        (presumably seconds -> milliseconds; confirm against the schema).
        Returns ``None`` for a falsy *row*.
        """
        if not row:
            return None
        result = dict(row)

        for field in _JSON_COLUMNS:
            if isinstance(result.get(field), str):
                result[field] = json.loads(result[field])

        # Relation fields come from the junction-table subqueries and may
        # arrive either as JSON strings or already-parsed lists.
        for field in _LIST_RESULT_FIELDS:
            if field not in result:
                continue
            value = result[field]
            if isinstance(value, str):
                result[field] = json.loads(value)
            elif value is None:
                result[field] = []

        for field in ('created_at', 'updated_at'):
            if result.get(field):
                result[field] = result[field] * 1000

        return result

    def close(self):
        """Close the database connection."""
        if self.conn:
            self.conn.close()

    def insert_batch(self, knowledge_list: List[Dict]):
        """Insert many knowledge rows and their links in a single transaction.

        An empty list is a no-op. Each dict has the same shape as for ``insert``.

        Raises:
            KeyError: if a required knowledge key is missing (nothing is written).
            psycopg2.Error: on database failure (nothing is written).
        """
        if not knowledge_list:
            return
        rows = [_knowledge_row(k) for k in knowledge_list]
        with self._transaction() as cursor:
            execute_batch(cursor, _INSERT_KNOWLEDGE_SQL, rows)
            for k in knowledge_list:
                _write_links(cursor, k['id'], k)