pg_store.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. """
  2. PostgreSQL 存储封装(替代 Milvus)
  3. 使用远程 PostgreSQL + pgvector/fastann 存储知识数据
  4. """
  5. import os
  6. import json
  7. import psycopg2
  8. from psycopg2.extras import RealDictCursor, execute_batch
  9. from typing import List, Dict, Optional
  10. from dotenv import load_dotenv
  11. load_dotenv()
  12. class PostgreSQLStore:
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _reconnect(self):
  25. self.conn = psycopg2.connect(
  26. host=os.getenv('KNOWHUB_DB'),
  27. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  28. user=os.getenv('KNOWHUB_USER'),
  29. password=os.getenv('KNOWHUB_PASSWORD'),
  30. database=os.getenv('KNOWHUB_DB_NAME')
  31. )
  32. self.conn.autocommit = False
  33. def _ensure_connection(self):
  34. if self.conn.closed != 0:
  35. self._reconnect()
  36. else:
  37. try:
  38. c = self.conn.cursor()
  39. c.execute("SELECT 1")
  40. c.close()
  41. except (psycopg2.OperationalError, psycopg2.InterfaceError):
  42. self._reconnect()
  43. def _get_cursor(self):
  44. """获取游标"""
  45. self._ensure_connection()
  46. return self.conn.cursor(cursor_factory=RealDictCursor)
  47. def insert(self, knowledge: Dict):
  48. """插入单条知识"""
  49. cursor = self._get_cursor()
  50. try:
  51. cursor.execute("""
  52. INSERT INTO knowledge (
  53. id, task_embedding, content_embedding, message_id, task, content, types, tags,
  54. tag_keys, scopes, owner, resource_ids, source, eval,
  55. created_at, updated_at, status, relationships,
  56. support_capability, tools
  57. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  58. """, (
  59. knowledge['id'],
  60. knowledge.get('task_embedding') or knowledge.get('embedding'),
  61. knowledge.get('content_embedding'),
  62. knowledge['message_id'],
  63. knowledge['task'],
  64. knowledge['content'],
  65. knowledge.get('types', []),
  66. json.dumps(knowledge.get('tags', {})),
  67. knowledge.get('tag_keys', []),
  68. knowledge.get('scopes', []),
  69. knowledge['owner'],
  70. knowledge.get('resource_ids', []),
  71. json.dumps(knowledge.get('source', {})),
  72. json.dumps(knowledge.get('eval', {})),
  73. knowledge['created_at'],
  74. knowledge['updated_at'],
  75. knowledge.get('status', 'approved'),
  76. json.dumps(knowledge.get('relationships', [])),
  77. json.dumps(knowledge.get('support_capability', [])),
  78. json.dumps(knowledge.get('tools', [])),
  79. ))
  80. self.conn.commit()
  81. finally:
  82. cursor.close()
  83. def search(self, query_embedding: List[float], filters: Optional[str] = None, limit: int = 10) -> List[Dict]:
  84. """向量检索(使用余弦相似度)"""
  85. cursor = self._get_cursor()
  86. try:
  87. where_clause = self._build_where_clause(filters) if filters else ""
  88. sql = f"""
  89. SELECT id, message_id, task, content, types, tags, tag_keys,
  90. scopes, owner, resource_ids, source, eval, created_at,
  91. updated_at, status, relationships, support_capability, tools,
  92. 1 - (task_embedding <=> %s::real[]) as score
  93. FROM knowledge
  94. {where_clause}
  95. ORDER BY task_embedding <=> %s::real[]
  96. LIMIT %s
  97. """
  98. cursor.execute(sql, (query_embedding, query_embedding, limit))
  99. results = cursor.fetchall()
  100. return [self._format_result(r) for r in results]
  101. finally:
  102. cursor.close()
  103. def query(self, filters: str, limit: int = 100) -> List[Dict]:
  104. """纯标量查询"""
  105. cursor = self._get_cursor()
  106. try:
  107. where_clause = self._build_where_clause(filters)
  108. sql = f"""
  109. SELECT id, message_id, task, content, types, tags, tag_keys,
  110. scopes, owner, resource_ids, source, eval, created_at,
  111. updated_at, status, relationships, support_capability, tools
  112. FROM knowledge
  113. {where_clause}
  114. LIMIT %s
  115. """
  116. cursor.execute(sql, (limit,))
  117. results = cursor.fetchall()
  118. return [self._format_result(r) for r in results]
  119. finally:
  120. cursor.close()
  121. def get_by_id(self, knowledge_id: str, include_embedding: bool = False) -> Optional[Dict]:
  122. """根据ID获取知识(默认不返回embedding以提升性能)"""
  123. cursor = self._get_cursor()
  124. try:
  125. # 默认不返回embedding(1536维向量太大,详情页不需要)
  126. if include_embedding:
  127. fields = "id, task_embedding, content_embedding, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools"
  128. else:
  129. fields = "id, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools"
  130. cursor.execute(f"""
  131. SELECT {fields}
  132. FROM knowledge WHERE id = %s
  133. """, (knowledge_id,))
  134. result = cursor.fetchone()
  135. return self._format_result(result) if result else None
  136. finally:
  137. cursor.close()
  138. def update(self, knowledge_id: str, updates: Dict):
  139. """更新知识"""
  140. cursor = self._get_cursor()
  141. try:
  142. set_parts = []
  143. params = []
  144. for key, value in updates.items():
  145. if key in ('tags', 'source', 'eval', 'support_capability', 'tools'):
  146. set_parts.append(f"{key} = %s")
  147. params.append(json.dumps(value))
  148. elif key == 'relationships':
  149. set_parts.append(f"{key} = %s")
  150. params.append(json.dumps(value) if isinstance(value, list) else value)
  151. else:
  152. set_parts.append(f"{key} = %s")
  153. params.append(value)
  154. params.append(knowledge_id)
  155. sql = f"UPDATE knowledge SET {', '.join(set_parts)} WHERE id = %s"
  156. cursor.execute(sql, params)
  157. self.conn.commit()
  158. finally:
  159. cursor.close()
  160. def delete(self, knowledge_id: str):
  161. """删除知识"""
  162. cursor = self._get_cursor()
  163. try:
  164. cursor.execute("DELETE FROM knowledge WHERE id = %s", (knowledge_id,))
  165. self.conn.commit()
  166. finally:
  167. cursor.close()
  168. def count(self) -> int:
  169. """返回知识总数"""
  170. cursor = self._get_cursor()
  171. try:
  172. cursor.execute("SELECT COUNT(*) as count FROM knowledge")
  173. return cursor.fetchone()['count']
  174. finally:
  175. cursor.close()
  176. def _build_where_clause(self, filters: str) -> str:
  177. """将Milvus风格的过滤表达式转换为PostgreSQL WHERE子句"""
  178. if not filters:
  179. return ""
  180. where = filters
  181. import re
  182. # 替换操作符
  183. where = where.replace(' == ', ' = ')
  184. where = where.replace(' or ', ' OR ')
  185. where = where.replace(' and ', ' AND ')
  186. # 处理数组包含操作
  187. where = re.sub(r'array_contains\((\w+),\s*"([^"]+)"\)', r"\1 @> ARRAY['\2']", where)
  188. # 处理 eval["score"] 语法
  189. where = where.replace('eval["score"]', "(eval->>'score')::int")
  190. # 把所有剩余的双引号字符串值替换为单引号(PostgreSQL标准)
  191. where = re.sub(r'"([^"]*)"', r"'\1'", where)
  192. return f"WHERE {where}"
  193. def _format_result(self, row: Dict) -> Dict:
  194. """格式化查询结果"""
  195. if not row:
  196. return None
  197. result = dict(row)
  198. if 'tags' in result and isinstance(result['tags'], str):
  199. result['tags'] = json.loads(result['tags'])
  200. if 'source' in result and isinstance(result['source'], str):
  201. result['source'] = json.loads(result['source'])
  202. if 'eval' in result and isinstance(result['eval'], str):
  203. result['eval'] = json.loads(result['eval'])
  204. if 'relationships' in result and isinstance(result['relationships'], str):
  205. result['relationships'] = json.loads(result['relationships'])
  206. if 'support_capability' in result and isinstance(result['support_capability'], str):
  207. result['support_capability'] = json.loads(result['support_capability'])
  208. if 'tools' in result and isinstance(result['tools'], str):
  209. result['tools'] = json.loads(result['tools'])
  210. if 'created_at' in result and result['created_at']:
  211. result['created_at'] = result['created_at'] * 1000
  212. if 'updated_at' in result and result['updated_at']:
  213. result['updated_at'] = result['updated_at'] * 1000
  214. return result
  215. def close(self):
  216. """关闭连接"""
  217. if self.conn:
  218. self.conn.close()
  219. def insert_batch(self, knowledge_list: List[Dict]):
  220. """批量插入知识"""
  221. if not knowledge_list:
  222. return
  223. cursor = self._get_cursor()
  224. try:
  225. data = []
  226. for k in knowledge_list:
  227. data.append((
  228. k['id'], k.get('task_embedding') or k.get('embedding'),
  229. k.get('content_embedding'),
  230. k['message_id'], k['task'],
  231. k['content'], k.get('types', []),
  232. json.dumps(k.get('tags', {})), k.get('tag_keys', []),
  233. k.get('scopes', []), k['owner'], k.get('resource_ids', []),
  234. json.dumps(k.get('source', {})), json.dumps(k.get('eval', {})),
  235. k['created_at'], k['updated_at'], k.get('status', 'approved'),
  236. json.dumps(k.get('relationships', [])),
  237. json.dumps(k.get('support_capability', [])),
  238. json.dumps(k.get('tools', [])),
  239. ))
  240. execute_batch(cursor, """
  241. INSERT INTO knowledge (
  242. id, task_embedding, content_embedding, message_id, task, content, types, tags,
  243. tag_keys, scopes, owner, resource_ids, source, eval,
  244. created_at, updated_at, status, relationships,
  245. support_capability, tools
  246. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  247. """, data)
  248. self.conn.commit()
  249. finally:
  250. cursor.close()