| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515 |
"""
PostgreSQL storage wrapper (replaces the former Milvus store).

Knowledge data is kept in a remote PostgreSQL database using the
pgvector/fastann extension.
"""
import os
import json
from typing import List, Dict, Optional

import psycopg2
from psycopg2.extras import RealDictCursor, execute_batch
from dotenv import load_dotenv

from knowhub.knowhub_db.cascade import cascade_delete

# Pull KNOWHUB_* connection settings from a .env file into the environment.
load_dotenv()
- # 关联字段的子查询(从 junction table 读取)
- # 对于带 relation_type 的 *_knowledge 边,同时暴露两种视图:
- # - *_ids : 扁平 ID 列表(向后兼容,不含 type)
- # - *_links : [{id, relation_type}](含 type)
- _REL_SUBQUERIES = """
- (SELECT COALESCE(json_agg(rk.requirement_id), '[]'::json)
- FROM requirement_knowledge rk WHERE rk.knowledge_id = knowledge.id) AS requirement_ids,
- (SELECT COALESCE(json_agg(json_build_object(
- 'id', rk2.requirement_id, 'relation_type', rk2.relation_type
- )), '[]'::json)
- FROM requirement_knowledge rk2 WHERE rk2.knowledge_id = knowledge.id) AS requirement_links,
- (SELECT COALESCE(json_agg(ck.capability_id), '[]'::json)
- FROM capability_knowledge ck WHERE ck.knowledge_id = knowledge.id) AS capability_ids,
- (SELECT COALESCE(json_agg(json_build_object(
- 'id', ck2.capability_id, 'relation_type', ck2.relation_type
- )), '[]'::json)
- FROM capability_knowledge ck2 WHERE ck2.knowledge_id = knowledge.id) AS capability_links,
- (SELECT COALESCE(json_agg(tk.tool_id), '[]'::json)
- FROM tool_knowledge tk WHERE tk.knowledge_id = knowledge.id) AS tool_ids,
- (SELECT COALESCE(json_agg(json_build_object(
- 'id', tk2.tool_id, 'relation_type', tk2.relation_type
- )), '[]'::json)
- FROM tool_knowledge tk2 WHERE tk2.knowledge_id = knowledge.id) AS tool_links,
- (SELECT COALESCE(json_agg(kr.resource_id), '[]'::json)
- FROM knowledge_resource kr WHERE kr.knowledge_id = knowledge.id) AS resource_ids,
- (SELECT COALESCE(json_agg(json_build_object(
- 'target_id', krel.target_id, 'relation_type', krel.relation_type
- )), '[]'::json)
- FROM knowledge_relation krel WHERE krel.source_id = knowledge.id) AS relations
- """
- # 基础字段(不含 embedding)
- _BASE_FIELDS = (
- "id, message_id, task, content, types, tags, tag_keys, "
- "scopes, owner, source, eval, "
- "created_at, updated_at, status"
- )
- # 完整 SELECT(含关联子查询)
- _SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
- # 含 embedding 的 SELECT
- _SELECT_FIELDS_WITH_EMB = f"task_embedding, content_embedding, {_SELECT_FIELDS}"
- def _normalize_links(data: Dict, links_key: str, ids_key: str, default_type: str):
- """
- 统一两种输入格式:
- - {links_key: [{id, relation_type}, ...]} → 使用指定 type
- - {ids_key: [id1, id2, ...]} → 使用 default_type
- 两个 key 都没有返回 None(不更新)
- """
- if links_key in data and data[links_key] is not None:
- out = []
- for item in data[links_key]:
- if isinstance(item, dict):
- out.append((item['id'], item.get('relation_type', default_type)))
- else:
- out.append((item, default_type))
- return out
- if ids_key in data and data[ids_key] is not None:
- return [(i, default_type) for i in data[ids_key]]
- return None
class PostgreSQLStore:
    """Knowledge store backed by a remote PostgreSQL database (replaces Milvus).

    Scalar fields and embeddings live in the ``knowledge`` table. Links to
    requirements, capabilities, tools, resources and other knowledge live in
    junction tables and are read back through the ``_SELECT_FIELDS`` sub-queries.

    NOTE(review): the connection runs with ``autocommit = True``, so every
    statement is committed on its own and the ``self.conn.commit()`` calls have
    no open transaction to finish. Multi-statement writes (``insert``,
    ``update``, ``insert_batch``) are therefore not atomic. For example,
    ``update`` can delete a knowledge item's links and then fail before
    re-inserting them. All methods share one connection, so confirm the store
    is not used from several threads at once before switching these writes to
    explicit transactions.
    """

    # Entity columns serialised with json.dumps on write and json-decoded on read.
    _JSON_COLUMNS = ('tags', 'source', 'eval')

    # Payload keys describing junction-table rows rather than knowledge columns.
    _REL_KEYS = ('requirement_ids', 'requirement_links',
                 'capability_ids', 'capability_links',
                 'tool_ids', 'tool_links', 'resource_ids')

    # Typed edges: (payload links key, payload ids key, junction table, id column).
    _TYPED_LINK_TABLES = (
        ('requirement_links', 'requirement_ids', 'requirement_knowledge', 'requirement_id'),
        ('capability_links', 'capability_ids', 'capability_knowledge', 'capability_id'),
        ('tool_links', 'tool_ids', 'tool_knowledge', 'tool_id'),
    )

    # Result columns holding JSON arrays produced by the relation sub-queries.
    _LIST_COLUMNS = ('requirement_ids', 'capability_ids', 'tool_ids', 'resource_ids',
                     'requirement_links', 'capability_links', 'tool_links', 'relations')

    # EXISTS clauses for search()/query() relation_filters, keyed by filter name.
    _RELATION_FILTER_SQL = {
        'requirement_id': "EXISTS (SELECT 1 FROM requirement_knowledge rk WHERE rk.knowledge_id = knowledge.id AND rk.requirement_id = %s)",
        'capability_id': "EXISTS (SELECT 1 FROM capability_knowledge ck WHERE ck.knowledge_id = knowledge.id AND ck.capability_id = %s)",
        'tool_id': "EXISTS (SELECT 1 FROM tool_knowledge tk WHERE tk.knowledge_id = knowledge.id AND tk.tool_id = %s)",
    }

    # Single-row INSERT; its parameter tuple is produced by _knowledge_row().
    _INSERT_SQL = """
        INSERT INTO knowledge (
            id, task_embedding, content_embedding, message_id, task, content, types, tags,
            tag_keys, scopes, owner, source, eval,
            created_at, updated_at, status
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    """

    # ------------------------------------------------------------------ #
    # Connection handling
    # ------------------------------------------------------------------ #
    def __init__(self):
        """Connect to the database configured by the KNOWHUB_* environment variables."""
        self.conn = self._connect()
        print(f"[PostgreSQL] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    @staticmethod
    def _connect():
        """Open a new autocommit connection from the KNOWHUB_* environment variables."""
        conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME'),
        )
        conn.autocommit = True
        return conn

    def _reconnect(self):
        """Replace ``self.conn`` with a fresh connection, closing the old one first."""
        old = getattr(self, 'conn', None)
        if old is not None and not old.closed:
            try:
                old.close()
            except psycopg2.Error:
                # Best effort: the old connection is being discarded anyway.
                pass
        self.conn = self._connect()

    def _ensure_connection(self):
        """Reconnect if the connection is closed or fails a ``SELECT 1`` ping."""
        if self.conn.closed:
            self._reconnect()
            return
        try:
            # The with-block closes the probe cursor even when execute() raises.
            with self.conn.cursor() as probe:
                probe.execute("SELECT 1")
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            self._reconnect()

    def _get_cursor(self):
        """Return a ``RealDictCursor`` (rows as dicts) on a live connection."""
        self._ensure_connection()
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def _execute_write(self, sql: str, params: tuple) -> None:
        """Run one write statement on a fresh cursor, commit, and close the cursor."""
        cursor = self._get_cursor()
        try:
            cursor.execute(sql, params)
            self.conn.commit()
        finally:
            cursor.close()

    # ------------------------------------------------------------------ #
    # Write helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _knowledge_row(knowledge: Dict) -> tuple:
        """Build the ``_INSERT_SQL`` parameter tuple for one knowledge dict.

        ``task_embedding`` falls back to the ``embedding`` key. The JSON
        columns are serialised with ``json.dumps``, and ``status`` defaults to
        ``'approved'``. Raises ``KeyError`` if a required key (id, message_id,
        task, content, owner, created_at, updated_at) is missing.
        """
        return (
            knowledge['id'],
            knowledge.get('task_embedding') or knowledge.get('embedding'),
            knowledge.get('content_embedding'),
            knowledge['message_id'],
            knowledge['task'],
            knowledge['content'],
            knowledge.get('types', []),
            json.dumps(knowledge.get('tags', {})),
            knowledge.get('tag_keys', []),
            knowledge.get('scopes', []),
            knowledge['owner'],
            json.dumps(knowledge.get('source', {})),
            json.dumps(knowledge.get('eval', {})),
            knowledge['created_at'],
            knowledge['updated_at'],
            knowledge.get('status', 'approved'),
        )

    def _write_links(self, cursor, knowledge_id: str, data: Dict, replace: bool = False) -> None:
        """Write the junction-table rows described by *data* for one knowledge item.

        Typed edges accept either ``*_links`` (``{id, relation_type}`` dicts or
        bare ids) or ``*_ids``. Untyped entries get relation_type ``'related'``.
        Resources come from ``resource_ids``. A key group that is absent or
        ``None`` is left untouched. With ``replace=True`` (used by ``update``),
        each present group first has its existing rows for this knowledge item
        deleted, so the new list fully replaces the old one.
        """
        for links_key, ids_key, table, column in self._TYPED_LINK_TABLES:
            links = _normalize_links(data, links_key, ids_key, 'related')
            if links is None:
                continue
            if replace:
                cursor.execute(f"DELETE FROM {table} WHERE knowledge_id = %s", (knowledge_id,))
            rows = [(target_id, knowledge_id, rtype) for target_id, rtype in links]
            if rows:
                execute_batch(
                    cursor,
                    f"INSERT INTO {table} ({column}, knowledge_id, relation_type) "
                    "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
                    rows)

        resource_ids = data.get('resource_ids')
        if resource_ids is None:
            return
        if replace:
            cursor.execute("DELETE FROM knowledge_resource WHERE knowledge_id = %s", (knowledge_id,))
        rows = [(knowledge_id, res_id) for res_id in resource_ids]
        if rows:
            execute_batch(
                cursor,
                "INSERT INTO knowledge_resource (knowledge_id, resource_id) "
                "VALUES (%s, %s) ON CONFLICT DO NOTHING",
                rows)

    # ------------------------------------------------------------------ #
    # CRUD
    # ------------------------------------------------------------------ #
    def insert(self, knowledge: Dict):
        """Insert one knowledge row followed by its junction-table links."""
        cursor = self._get_cursor()
        try:
            cursor.execute(self._INSERT_SQL, self._knowledge_row(knowledge))
            self._write_links(cursor, knowledge['id'], knowledge)
            self.conn.commit()
        finally:
            cursor.close()

    def _apply_relation_filters(self, where_clause: str, relation_filters: Optional[Dict[str, str]], params: list) -> str:
        """AND relation EXISTS-filters onto *where_clause*, appending their values to *params*.

        Supported keys are ``requirement_id``, ``capability_id`` and ``tool_id``.
        Unknown keys and falsy values are ignored. Returns *where_clause*
        unchanged when no filter applies.
        """
        clauses = []
        for name, value in (relation_filters or {}).items():
            clause = self._RELATION_FILTER_SQL.get(name)
            if clause is None or not value:
                continue
            clauses.append(clause)
            params.append(value)
        if not clauses:
            return where_clause
        joined = " AND ".join(clauses)
        return f"{where_clause} AND {joined}" if where_clause.strip() else f"WHERE {joined}"

    def _param_safe_where(self, filters: Optional[str]) -> str:
        """Return the WHERE clause for *filters* with every literal ``%`` doubled.

        ``search``/``query`` always execute with bound parameters, and psycopg2
        then reads each ``%`` in the SQL text as a placeholder marker. Without
        this, a LIKE pattern such as ``"ab%"`` would break the query.
        """
        if not filters:
            return ""
        return self._build_where_clause(filters).replace('%', '%%')

    def search(self, query_embedding: List[float], filters: Optional[str] = None, limit: int = 10, relation_filters: Optional[Dict[str, str]] = None) -> List[Dict]:
        """Vector search on ``task_embedding``, nearest first (cosine similarity).

        Each result carries ``score = 1 - (task_embedding <=> query)``.
        *filters* is a Milvus-style expression (see ``_build_where_clause``).
        *relation_filters* restricts results to knowledge linked to the given
        requirement/capability/tool id.
        """
        cursor = self._get_cursor()
        try:
            params = []
            where_clause = self._apply_relation_filters(
                self._param_safe_where(filters), relation_filters, params)
            sql = f"""
                SELECT {_SELECT_FIELDS},
                       1 - (task_embedding <=> %s::real[]) as score
                FROM knowledge
                {where_clause}
                ORDER BY task_embedding <=> %s::real[]
                LIMIT %s
            """
            # Placeholder order: score expression, WHERE params, ORDER BY, LIMIT.
            cursor.execute(sql, (query_embedding, *params, query_embedding, limit))
            return [self._format_result(r) for r in cursor.fetchall()]
        finally:
            cursor.close()

    def query(self, filters: str, limit: int = 100, relation_filters: Optional[Dict[str, str]] = None) -> List[Dict]:
        """Scalar (non-vector) query returning at most *limit* rows.

        There is no ORDER BY, so row order is unspecified. *filters* and
        *relation_filters* behave as in ``search``.
        """
        cursor = self._get_cursor()
        try:
            params = []
            where_clause = self._apply_relation_filters(
                self._param_safe_where(filters), relation_filters, params)
            sql = f"""
                SELECT {_SELECT_FIELDS}
                FROM knowledge
                {where_clause}
                LIMIT %s
            """
            cursor.execute(sql, (*params, limit))
            return [self._format_result(r) for r in cursor.fetchall()]
        finally:
            cursor.close()

    def get_by_id(self, knowledge_id: str, include_embedding: bool = False) -> Optional[Dict]:
        """Fetch one knowledge item by id, or ``None`` if no row matches.

        Embedding columns are omitted unless *include_embedding* is true, which
        keeps the default payload small.
        """
        cursor = self._get_cursor()
        try:
            fields = _SELECT_FIELDS_WITH_EMB if include_embedding else _SELECT_FIELDS
            cursor.execute(f"SELECT {fields} FROM knowledge WHERE id = %s", (knowledge_id,))
            row = cursor.fetchone()
            return self._format_result(row) if row else None
        finally:
            cursor.close()

    def update(self, knowledge_id: str, updates: Dict):
        """Update scalar columns and/or replace the relation links of one knowledge item.

        Keys in ``_REL_KEYS`` are routed to the junction tables, and each group
        present fully replaces the old links. All other keys are written as
        ``knowledge`` columns, with ``tags``/``source``/``eval`` JSON-encoded.
        The caller's *updates* dict is not modified.

        NOTE(review): column names come straight from the *updates* keys and
        are interpolated into the SQL; only pass trusted keys.
        """
        # Work on a copy: popping the relation keys must not mutate the caller's dict.
        fields = dict(updates)
        rel_data = {key: fields.pop(key) for key in self._REL_KEYS if key in fields}
        cursor = self._get_cursor()
        try:
            if fields:
                assignments = ", ".join(f"{column} = %s" for column in fields)
                params = [json.dumps(value) if column in self._JSON_COLUMNS else value
                          for column, value in fields.items()]
                params.append(knowledge_id)
                cursor.execute(f"UPDATE knowledge SET {assignments} WHERE id = %s", params)
            self._write_links(cursor, knowledge_id, rel_data, replace=True)
            self.conn.commit()
        finally:
            cursor.close()

    def delete(self, knowledge_id: str):
        """Delete a knowledge item and its junction-table rows via ``cascade_delete``."""
        cursor = self._get_cursor()
        try:
            cascade_delete(cursor, 'knowledge', knowledge_id)
            self.conn.commit()
        finally:
            cursor.close()

    # ------------------------------------------------------------------ #
    # Incremental link helpers (never remove existing rows)
    # ------------------------------------------------------------------ #
    def add_relation(self, source_id: str, target_id: str, relation_type: str):
        """Add one knowledge -> knowledge relation; existing relations are kept."""
        self._execute_write(
            "INSERT INTO knowledge_relation (source_id, target_id, relation_type) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
            (source_id, target_id, relation_type))

    def add_resource(self, knowledge_id: str, resource_id: str):
        """Add one knowledge -> resource link; existing links are kept."""
        self._execute_write(
            "INSERT INTO knowledge_resource (knowledge_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
            (knowledge_id, resource_id))

    def add_requirement(self, knowledge_id: str, requirement_id: str,
                        relation_type: str = 'related'):
        """Add one requirement -> knowledge edge; existing edges are kept."""
        self._execute_write(
            "INSERT INTO requirement_knowledge (requirement_id, knowledge_id, relation_type) "
            "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
            (requirement_id, knowledge_id, relation_type))

    def add_capability(self, knowledge_id: str, capability_id: str,
                       relation_type: str = 'related'):
        """Add one capability -> knowledge edge; existing edges are kept."""
        self._execute_write(
            "INSERT INTO capability_knowledge (capability_id, knowledge_id, relation_type) "
            "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
            (capability_id, knowledge_id, relation_type))

    def add_tool(self, knowledge_id: str, tool_id: str,
                 relation_type: str = 'related'):
        """Add one tool -> knowledge edge; existing edges are kept."""
        self._execute_write(
            "INSERT INTO tool_knowledge (tool_id, knowledge_id, relation_type) "
            "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
            (tool_id, knowledge_id, relation_type))

    def count(self) -> int:
        """Return the number of rows in ``knowledge``."""
        cursor = self._get_cursor()
        try:
            cursor.execute("SELECT COUNT(*) as count FROM knowledge")
            return cursor.fetchone()['count']
        finally:
            cursor.close()

    # ------------------------------------------------------------------ #
    # Translation / formatting
    # ------------------------------------------------------------------ #
    def _build_where_clause(self, filters: str) -> str:
        """Translate a Milvus-style filter expression into a PostgreSQL WHERE clause.

        Rewrites, applied only outside quoted values:
          * `` == `` -> `` = ``, `` or `` / `` and `` -> `` OR `` / `` AND ``
          * ``array_contains(field, "v")`` -> ``field @> ARRAY['v']``
          * ``eval["score"]`` -> ``(eval->>'score')::int``
          * ``"text"`` -> ``'text'``, with embedded single quotes doubled

        Returns ``""`` for an empty expression. A literal ``%`` is left as-is;
        callers that bind parameters must double it (``_param_safe_where``).

        NOTE(review): apart from these rewrites, the expression is spliced into
        SQL verbatim. It must come from trusted code, not directly from end users.
        """
        if not filters:
            return ""
        import re

        def sql_literal(value: str) -> str:
            # Double embedded single quotes so values like O'Brien stay valid SQL.
            return "'" + value.replace("'", "''") + "'"

        def rewrite_operators(segment: str) -> str:
            return (segment.replace(' == ', ' = ')
                           .replace(' or ', ' OR ')
                           .replace(' and ', ' AND '))

        # One left-to-right pass: special tokens and quoted strings are emitted
        # verbatim (after conversion); only the text between them gets operator
        # rewriting, so values like "this or that" are not altered.
        token_re = re.compile(
            r'array_contains\((\w+),\s*"([^"]+)"\)'  # array_contains(field, "v")
            r'|(eval\["score"\])'                     # eval["score"]
            r'|"([^"]*)"'                             # any other double-quoted string
        )
        pieces = []
        position = 0
        for match in token_re.finditer(filters):
            pieces.append(rewrite_operators(filters[position:match.start()]))
            field, element, eval_score, text = match.groups()
            if field is not None:
                pieces.append(f"{field} @> ARRAY[{sql_literal(element)}]")
            elif eval_score is not None:
                pieces.append("(eval->>'score')::int")
            else:
                pieces.append(sql_literal(text))
            position = match.end()
        pieces.append(rewrite_operators(filters[position:]))
        return "WHERE " + "".join(pieces)

    def _format_result(self, row: Optional[Dict]) -> Optional[Dict]:
        """Convert a database row into the API dict shape, or ``None`` for an empty row.

        JSON columns and relation arrays that arrive as strings are decoded;
        ``None`` relation arrays become ``[]``. Truthy ``created_at``/``updated_at``
        values are multiplied by 1000 (presumably seconds -> milliseconds;
        confirm against the column type).
        """
        if not row:
            return None
        result = dict(row)
        for field in self._JSON_COLUMNS:
            if isinstance(result.get(field), str):
                result[field] = json.loads(result[field])
        # Relation arrays come from the junction-table sub-queries.
        for field in self._LIST_COLUMNS:
            if field not in result:
                continue
            value = result[field]
            if isinstance(value, str):
                result[field] = json.loads(value)
            elif value is None:
                result[field] = []
        for field in ('created_at', 'updated_at'):
            if result.get(field):
                result[field] = result[field] * 1000
        return result

    def close(self):
        """Close the database connection."""
        if self.conn:
            self.conn.close()

    def insert_batch(self, knowledge_list: List[Dict]):
        """Insert many knowledge items via ``execute_batch``, then write their links.

        Does nothing for empty (or ``None``) input. The input is materialised
        first, so one-shot iterables are not exhausted before the link pass.
        """
        items = list(knowledge_list or [])
        if not items:
            return
        cursor = self._get_cursor()
        try:
            execute_batch(cursor, self._INSERT_SQL, [self._knowledge_row(k) for k in items])
            for knowledge in items:
                self._write_links(cursor, knowledge['id'], knowledge)
            self.conn.commit()
        finally:
            cursor.close()
|