pg_store.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. """
  2. PostgreSQL 存储封装(替代 Milvus)
  3. 使用远程 PostgreSQL + pgvector/fastann 存储知识数据
  4. """
import os
import re
import json
from contextlib import contextmanager
from typing import List, Dict, Optional

import psycopg2
from psycopg2.extras import RealDictCursor, execute_batch
from dotenv import load_dotenv

from knowhub.knowhub_db.cascade import cascade_delete
# Load a .env file into the process environment at import time (python-dotenv);
# PostgreSQLStore reads its KNOWHUB_* connection settings via os.getenv.
load_dotenv()
# Correlated subqueries that read each knowledge row's associations from the
# junction tables and return them as JSON arrays ('[]' when there are none).
# They reference the outer ``knowledge`` table, so they are only valid inside a
# "SELECT ... FROM knowledge" statement. ``relations`` is an array of
# {"target_id", "relation_type"} objects for links where this row is the source.
_REL_SUBQUERIES = """
(SELECT COALESCE(json_agg(rk.requirement_id), '[]'::json)
FROM requirement_knowledge rk WHERE rk.knowledge_id = knowledge.id) AS requirement_ids,
(SELECT COALESCE(json_agg(ck.capability_id), '[]'::json)
FROM capability_knowledge ck WHERE ck.knowledge_id = knowledge.id) AS capability_ids,
(SELECT COALESCE(json_agg(tk.tool_id), '[]'::json)
FROM tool_knowledge tk WHERE tk.knowledge_id = knowledge.id) AS tool_ids,
(SELECT COALESCE(json_agg(kr.resource_id), '[]'::json)
FROM knowledge_resource kr WHERE kr.knowledge_id = knowledge.id) AS resource_ids,
(SELECT COALESCE(json_agg(json_build_object(
'target_id', krel.target_id, 'relation_type', krel.relation_type
)), '[]'::json)
FROM knowledge_relation krel WHERE krel.source_id = knowledge.id) AS relations
"""
# Plain knowledge columns, excluding the embedding vectors.
_BASE_FIELDS = (
"id, message_id, task, content, types, tags, tag_keys, "
"scopes, owner, source, eval, "
"created_at, updated_at, status"
)
# Full SELECT list: plain columns plus the association subqueries above.
_SELECT_FIELDS = f"{_BASE_FIELDS}, {_REL_SUBQUERIES}"
# Same as _SELECT_FIELDS with task_embedding and content_embedding prepended
# (used when the caller explicitly asks for the vectors).
_SELECT_FIELDS_WITH_EMB = f"task_embedding, content_embedding, {_SELECT_FIELDS}"
  38. class PostgreSQLStore:
  39. def __init__(self):
  40. """初始化 PostgreSQL 连接"""
  41. self.conn = psycopg2.connect(
  42. host=os.getenv('KNOWHUB_DB'),
  43. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  44. user=os.getenv('KNOWHUB_USER'),
  45. password=os.getenv('KNOWHUB_PASSWORD'),
  46. database=os.getenv('KNOWHUB_DB_NAME')
  47. )
  48. self.conn.autocommit = False
  49. print(f"[PostgreSQL] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  50. def _reconnect(self):
  51. self.conn = psycopg2.connect(
  52. host=os.getenv('KNOWHUB_DB'),
  53. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  54. user=os.getenv('KNOWHUB_USER'),
  55. password=os.getenv('KNOWHUB_PASSWORD'),
  56. database=os.getenv('KNOWHUB_DB_NAME')
  57. )
  58. self.conn.autocommit = False
  59. def _ensure_connection(self):
  60. if self.conn.closed != 0:
  61. self._reconnect()
  62. else:
  63. try:
  64. c = self.conn.cursor()
  65. c.execute("SELECT 1")
  66. c.close()
  67. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  68. self._reconnect()
  69. def _get_cursor(self):
  70. """获取游标"""
  71. self._ensure_connection()
  72. return self.conn.cursor(cursor_factory=RealDictCursor)
  73. def insert(self, knowledge: Dict):
  74. """插入单条知识"""
  75. cursor = self._get_cursor()
  76. try:
  77. cursor.execute("""
  78. INSERT INTO knowledge (
  79. id, task_embedding, content_embedding, message_id, task, content, types, tags,
  80. tag_keys, scopes, owner, source, eval,
  81. created_at, updated_at, status
  82. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  83. """, (
  84. knowledge['id'],
  85. knowledge.get('task_embedding') or knowledge.get('embedding'),
  86. knowledge.get('content_embedding'),
  87. knowledge['message_id'],
  88. knowledge['task'],
  89. knowledge['content'],
  90. knowledge.get('types', []),
  91. json.dumps(knowledge.get('tags', {})),
  92. knowledge.get('tag_keys', []),
  93. knowledge.get('scopes', []),
  94. knowledge['owner'],
  95. json.dumps(knowledge.get('source', {})),
  96. json.dumps(knowledge.get('eval', {})),
  97. knowledge['created_at'],
  98. knowledge['updated_at'],
  99. knowledge.get('status', 'approved'),
  100. ))
  101. # 写入关联表
  102. kid = knowledge['id']
  103. for req_id in knowledge.get('requirement_ids', []):
  104. cursor.execute(
  105. "INSERT INTO requirement_knowledge (requirement_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  106. (req_id, kid))
  107. for cap_id in knowledge.get('capability_ids', []):
  108. cursor.execute(
  109. "INSERT INTO capability_knowledge (capability_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  110. (cap_id, kid))
  111. for tool_id in knowledge.get('tool_ids', []):
  112. cursor.execute(
  113. "INSERT INTO tool_knowledge (tool_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  114. (tool_id, kid))
  115. for res_id in knowledge.get('resource_ids', []):
  116. cursor.execute(
  117. "INSERT INTO knowledge_resource (knowledge_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  118. (kid, res_id))
  119. self.conn.commit()
  120. finally:
  121. cursor.close()
  122. def _apply_relation_filters(self, where_clause: str, relation_filters: Optional[Dict[str, str]], params: list) -> str:
  123. if not relation_filters:
  124. return where_clause
  125. rel_clauses = []
  126. for k, v in relation_filters.items():
  127. if not v: continue
  128. if k == 'requirement_id':
  129. rel_clauses.append("EXISTS (SELECT 1 FROM requirement_knowledge rk WHERE rk.knowledge_id = knowledge.id AND rk.requirement_id = %s)")
  130. params.append(v)
  131. elif k == 'capability_id':
  132. rel_clauses.append("EXISTS (SELECT 1 FROM capability_knowledge ck WHERE ck.knowledge_id = knowledge.id AND ck.capability_id = %s)")
  133. params.append(v)
  134. elif k == 'tool_id':
  135. rel_clauses.append("EXISTS (SELECT 1 FROM tool_knowledge tk WHERE tk.knowledge_id = knowledge.id AND tk.tool_id = %s)")
  136. params.append(v)
  137. if not rel_clauses:
  138. return where_clause
  139. rel_where = " AND ".join(rel_clauses)
  140. if where_clause.strip():
  141. return f"{where_clause} AND {rel_where}"
  142. else:
  143. return f"WHERE {rel_where}"
  144. def search(self, query_embedding: List[float], filters: Optional[str] = None, limit: int = 10, relation_filters: Optional[Dict[str, str]] = None) -> List[Dict]:
  145. """向量检索(使用余弦相似度)"""
  146. cursor = self._get_cursor()
  147. try:
  148. where_clause = self._build_where_clause(filters) if filters else ""
  149. params = []
  150. where_clause = self._apply_relation_filters(where_clause, relation_filters, params)
  151. sql = f"""
  152. SELECT {_SELECT_FIELDS},
  153. 1 - (task_embedding <=> %s::real[]) as score
  154. FROM knowledge
  155. {where_clause}
  156. ORDER BY task_embedding <=> %s::real[]
  157. LIMIT %s
  158. """
  159. final_params = [query_embedding] + params + [query_embedding, limit]
  160. cursor.execute(sql, tuple(final_params))
  161. results = cursor.fetchall()
  162. return [self._format_result(r) for r in results]
  163. finally:
  164. cursor.close()
  165. def query(self, filters: str, limit: int = 100, relation_filters: Optional[Dict[str, str]] = None) -> List[Dict]:
  166. """纯标量查询"""
  167. cursor = self._get_cursor()
  168. try:
  169. where_clause = self._build_where_clause(filters) if filters else ""
  170. params = []
  171. where_clause = self._apply_relation_filters(where_clause, relation_filters, params)
  172. sql = f"""
  173. SELECT {_SELECT_FIELDS}
  174. FROM knowledge
  175. {where_clause}
  176. LIMIT %s
  177. """
  178. final_params = params + [limit]
  179. cursor.execute(sql, tuple(final_params))
  180. results = cursor.fetchall()
  181. return [self._format_result(r) for r in results]
  182. finally:
  183. cursor.close()
  184. def get_by_id(self, knowledge_id: str, include_embedding: bool = False) -> Optional[Dict]:
  185. """根据ID获取知识(默认不返回embedding以提升性能)"""
  186. cursor = self._get_cursor()
  187. try:
  188. fields = _SELECT_FIELDS_WITH_EMB if include_embedding else _SELECT_FIELDS
  189. cursor.execute(f"""
  190. SELECT {fields}
  191. FROM knowledge WHERE id = %s
  192. """, (knowledge_id,))
  193. result = cursor.fetchone()
  194. return self._format_result(result) if result else None
  195. finally:
  196. cursor.close()
  197. def update(self, knowledge_id: str, updates: Dict):
  198. """更新知识"""
  199. cursor = self._get_cursor()
  200. try:
  201. # 分离关联字段和实体字段
  202. req_ids = updates.pop('requirement_ids', None)
  203. cap_ids = updates.pop('capability_ids', None)
  204. tool_ids = updates.pop('tool_ids', None)
  205. resource_ids = updates.pop('resource_ids', None)
  206. if updates:
  207. set_parts = []
  208. params = []
  209. for key, value in updates.items():
  210. if key in ('tags', 'source', 'eval'):
  211. set_parts.append(f"{key} = %s")
  212. params.append(json.dumps(value))
  213. else:
  214. set_parts.append(f"{key} = %s")
  215. params.append(value)
  216. params.append(knowledge_id)
  217. sql = f"UPDATE knowledge SET {', '.join(set_parts)} WHERE id = %s"
  218. cursor.execute(sql, params)
  219. # 更新关联表(全量替换)
  220. if req_ids is not None:
  221. cursor.execute("DELETE FROM requirement_knowledge WHERE knowledge_id = %s", (knowledge_id,))
  222. for req_id in req_ids:
  223. cursor.execute(
  224. "INSERT INTO requirement_knowledge (requirement_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  225. (req_id, knowledge_id))
  226. if cap_ids is not None:
  227. cursor.execute("DELETE FROM capability_knowledge WHERE knowledge_id = %s", (knowledge_id,))
  228. for cap_id in cap_ids:
  229. cursor.execute(
  230. "INSERT INTO capability_knowledge (capability_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  231. (cap_id, knowledge_id))
  232. if tool_ids is not None:
  233. cursor.execute("DELETE FROM tool_knowledge WHERE knowledge_id = %s", (knowledge_id,))
  234. for tool_id in tool_ids:
  235. cursor.execute(
  236. "INSERT INTO tool_knowledge (tool_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  237. (tool_id, knowledge_id))
  238. if resource_ids is not None:
  239. cursor.execute("DELETE FROM knowledge_resource WHERE knowledge_id = %s", (knowledge_id,))
  240. for res_id in resource_ids:
  241. cursor.execute(
  242. "INSERT INTO knowledge_resource (knowledge_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  243. (knowledge_id, res_id))
  244. self.conn.commit()
  245. finally:
  246. cursor.close()
  247. def delete(self, knowledge_id: str):
  248. """删除知识及其关联表记录"""
  249. cursor = self._get_cursor()
  250. try:
  251. cascade_delete(cursor, 'knowledge', knowledge_id)
  252. self.conn.commit()
  253. finally:
  254. cursor.close()
  255. def add_relation(self, source_id: str, target_id: str, relation_type: str):
  256. """添加一条知识间关系(不删除已有关系)"""
  257. cursor = self._get_cursor()
  258. try:
  259. cursor.execute(
  260. "INSERT INTO knowledge_relation (source_id, target_id, relation_type) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING",
  261. (source_id, target_id, relation_type))
  262. self.conn.commit()
  263. finally:
  264. cursor.close()
  265. def add_resource(self, knowledge_id: str, resource_id: str):
  266. """添加一条知识-资源关联(不删除已有关联)"""
  267. cursor = self._get_cursor()
  268. try:
  269. cursor.execute(
  270. "INSERT INTO knowledge_resource (knowledge_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  271. (knowledge_id, resource_id))
  272. self.conn.commit()
  273. finally:
  274. cursor.close()
  275. def count(self) -> int:
  276. """返回知识总数"""
  277. cursor = self._get_cursor()
  278. try:
  279. cursor.execute("SELECT COUNT(*) as count FROM knowledge")
  280. return cursor.fetchone()['count']
  281. finally:
  282. cursor.close()
  283. def _build_where_clause(self, filters: str) -> str:
  284. """将Milvus风格的过滤表达式转换为PostgreSQL WHERE子句"""
  285. if not filters:
  286. return ""
  287. where = filters
  288. import re
  289. # 替换操作符
  290. where = where.replace(' == ', ' = ')
  291. where = where.replace(' or ', ' OR ')
  292. where = where.replace(' and ', ' AND ')
  293. # 处理数组包含操作
  294. where = re.sub(r'array_contains\((\w+),\s*"([^"]+)"\)', r"\1 @> ARRAY['\2']", where)
  295. # 处理 eval["score"] 语法
  296. where = where.replace('eval["score"]', "(eval->>'score')::int")
  297. # 把所有剩余的双引号字符串值替换为单引号(PostgreSQL标准)
  298. where = re.sub(r'"([^"]*)"', r"'\1'", where)
  299. return f"WHERE {where}"
  300. def _format_result(self, row: Dict) -> Dict:
  301. """格式化查询结果"""
  302. if not row:
  303. return None
  304. result = dict(row)
  305. if 'tags' in result and isinstance(result['tags'], str):
  306. result['tags'] = json.loads(result['tags'])
  307. if 'source' in result and isinstance(result['source'], str):
  308. result['source'] = json.loads(result['source'])
  309. if 'eval' in result and isinstance(result['eval'], str):
  310. result['eval'] = json.loads(result['eval'])
  311. # 关联字段(来自 junction table 子查询,可能是 JSON 字符串或已解析的列表)
  312. for field in ('requirement_ids', 'capability_ids', 'tool_ids', 'resource_ids'):
  313. if field in result and isinstance(result[field], str):
  314. result[field] = json.loads(result[field])
  315. elif field in result and result[field] is None:
  316. result[field] = []
  317. if 'relations' in result and isinstance(result['relations'], str):
  318. result['relations'] = json.loads(result['relations'])
  319. elif 'relations' in result and result['relations'] is None:
  320. result['relations'] = []
  321. if 'created_at' in result and result['created_at']:
  322. result['created_at'] = result['created_at'] * 1000
  323. if 'updated_at' in result and result['updated_at']:
  324. result['updated_at'] = result['updated_at'] * 1000
  325. return result
  326. def close(self):
  327. """关闭连接"""
  328. if self.conn:
  329. self.conn.close()
  330. def insert_batch(self, knowledge_list: List[Dict]):
  331. """批量插入知识"""
  332. if not knowledge_list:
  333. return
  334. cursor = self._get_cursor()
  335. try:
  336. data = []
  337. for k in knowledge_list:
  338. data.append((
  339. k['id'], k.get('task_embedding') or k.get('embedding'),
  340. k.get('content_embedding'),
  341. k['message_id'], k['task'],
  342. k['content'], k.get('types', []),
  343. json.dumps(k.get('tags', {})), k.get('tag_keys', []),
  344. k.get('scopes', []), k['owner'],
  345. json.dumps(k.get('source', {})), json.dumps(k.get('eval', {})),
  346. k['created_at'], k['updated_at'], k.get('status', 'approved'),
  347. ))
  348. execute_batch(cursor, """
  349. INSERT INTO knowledge (
  350. id, task_embedding, content_embedding, message_id, task, content, types, tags,
  351. tag_keys, scopes, owner, source, eval,
  352. created_at, updated_at, status
  353. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  354. """, data)
  355. # 批量写入关联表
  356. for k in knowledge_list:
  357. kid = k['id']
  358. for req_id in k.get('requirement_ids', []):
  359. cursor.execute(
  360. "INSERT INTO requirement_knowledge (requirement_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  361. (req_id, kid))
  362. for cap_id in k.get('capability_ids', []):
  363. cursor.execute(
  364. "INSERT INTO capability_knowledge (capability_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  365. (cap_id, kid))
  366. for tool_id in k.get('tool_ids', []):
  367. cursor.execute(
  368. "INSERT INTO tool_knowledge (tool_id, knowledge_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  369. (tool_id, kid))
  370. for res_id in k.get('resource_ids', []):
  371. cursor.execute(
  372. "INSERT INTO knowledge_resource (knowledge_id, resource_id) VALUES (%s, %s) ON CONFLICT DO NOTHING",
  373. (kid, res_id))
  374. self.conn.commit()
  375. finally:
  376. cursor.close()