# pg_store.py (9.4 KB)

  1. """
  2. PostgreSQL 存储封装(替代 Milvus)
  3. 使用远程 PostgreSQL + pgvector/fastann 存储知识数据
  4. """
  5. import os
  6. import json
  7. import psycopg2
  8. from psycopg2.extras import RealDictCursor, execute_batch
  9. from typing import List, Dict, Optional
  10. from dotenv import load_dotenv
  11. load_dotenv()
  12. class PostgreSQLStore:
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _get_cursor(self):
  25. """获取游标"""
  26. return self.conn.cursor(cursor_factory=RealDictCursor)
  27. def insert(self, knowledge: Dict):
  28. """插入单条知识"""
  29. cursor = self._get_cursor()
  30. try:
  31. cursor.execute("""
  32. INSERT INTO knowledge (
  33. id, embedding, message_id, task, content, types, tags,
  34. tag_keys, scopes, owner, resource_ids, source, eval,
  35. created_at, updated_at, status, relationships
  36. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  37. """, (
  38. knowledge['id'],
  39. knowledge['embedding'],
  40. knowledge['message_id'],
  41. knowledge['task'],
  42. knowledge['content'],
  43. knowledge.get('types', []),
  44. json.dumps(knowledge.get('tags', {})),
  45. knowledge.get('tag_keys', []),
  46. knowledge.get('scopes', []),
  47. knowledge['owner'],
  48. knowledge.get('resource_ids', []),
  49. json.dumps(knowledge.get('source', {})),
  50. json.dumps(knowledge.get('eval', {})),
  51. knowledge['created_at'],
  52. knowledge['updated_at'],
  53. knowledge.get('status', 'approved'),
  54. json.dumps(knowledge.get('relationships', []))
  55. ))
  56. self.conn.commit()
  57. finally:
  58. cursor.close()
  59. def search(self, query_embedding: List[float], filters: Optional[str] = None, limit: int = 10) -> List[Dict]:
  60. """向量检索(使用余弦相似度)"""
  61. cursor = self._get_cursor()
  62. try:
  63. where_clause = self._build_where_clause(filters) if filters else ""
  64. sql = f"""
  65. SELECT id, message_id, task, content, types, tags, tag_keys,
  66. scopes, owner, resource_ids, source, eval, created_at,
  67. updated_at, status, relationships,
  68. 1 - (embedding <=> %s::real[]) as score
  69. FROM knowledge
  70. {where_clause}
  71. ORDER BY embedding <=> %s::real[]
  72. LIMIT %s
  73. """
  74. cursor.execute(sql, (query_embedding, query_embedding, limit))
  75. results = cursor.fetchall()
  76. return [self._format_result(r) for r in results]
  77. finally:
  78. cursor.close()
  79. def query(self, filters: str, limit: int = 100) -> List[Dict]:
  80. """纯标量查询"""
  81. cursor = self._get_cursor()
  82. try:
  83. where_clause = self._build_where_clause(filters)
  84. sql = f"""
  85. SELECT id, message_id, task, content, types, tags, tag_keys,
  86. scopes, owner, resource_ids, source, eval, created_at,
  87. updated_at, status, relationships
  88. FROM knowledge
  89. {where_clause}
  90. LIMIT %s
  91. """
  92. cursor.execute(sql, (limit,))
  93. results = cursor.fetchall()
  94. return [self._format_result(r) for r in results]
  95. finally:
  96. cursor.close()
  97. def get_by_id(self, knowledge_id: str, include_embedding: bool = False) -> Optional[Dict]:
  98. """根据ID获取知识(默认不返回embedding以提升性能)"""
  99. cursor = self._get_cursor()
  100. try:
  101. # 默认不返回embedding(1536维向量太大,详情页不需要)
  102. if include_embedding:
  103. fields = "id, embedding, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships"
  104. else:
  105. fields = "id, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships"
  106. cursor.execute(f"""
  107. SELECT {fields}
  108. FROM knowledge WHERE id = %s
  109. """, (knowledge_id,))
  110. result = cursor.fetchone()
  111. return self._format_result(result) if result else None
  112. finally:
  113. cursor.close()
  114. def update(self, knowledge_id: str, updates: Dict):
  115. """更新知识"""
  116. cursor = self._get_cursor()
  117. try:
  118. set_parts = []
  119. params = []
  120. for key, value in updates.items():
  121. if key in ('tags', 'source', 'eval'):
  122. set_parts.append(f"{key} = %s")
  123. params.append(json.dumps(value))
  124. elif key == 'relationships':
  125. set_parts.append(f"{key} = %s")
  126. params.append(json.dumps(value) if isinstance(value, list) else value)
  127. else:
  128. set_parts.append(f"{key} = %s")
  129. params.append(value)
  130. params.append(knowledge_id)
  131. sql = f"UPDATE knowledge SET {', '.join(set_parts)} WHERE id = %s"
  132. cursor.execute(sql, params)
  133. self.conn.commit()
  134. finally:
  135. cursor.close()
  136. def delete(self, knowledge_id: str):
  137. """删除知识"""
  138. cursor = self._get_cursor()
  139. try:
  140. cursor.execute("DELETE FROM knowledge WHERE id = %s", (knowledge_id,))
  141. self.conn.commit()
  142. finally:
  143. cursor.close()
  144. def count(self) -> int:
  145. """返回知识总数"""
  146. cursor = self._get_cursor()
  147. try:
  148. cursor.execute("SELECT COUNT(*) as count FROM knowledge")
  149. return cursor.fetchone()['count']
  150. finally:
  151. cursor.close()
  152. def _build_where_clause(self, filters: str) -> str:
  153. """将Milvus风格的过滤表达式转换为PostgreSQL WHERE子句"""
  154. if not filters:
  155. return ""
  156. where = filters
  157. import re
  158. # 替换操作符
  159. where = where.replace(' == ', ' = ')
  160. where = where.replace(' or ', ' OR ')
  161. where = where.replace(' and ', ' AND ')
  162. # 处理数组包含操作
  163. where = re.sub(r'array_contains\((\w+),\s*"([^"]+)"\)', r"\1 @> ARRAY['\2']", where)
  164. # 处理 eval["score"] 语法
  165. where = where.replace('eval["score"]', "(eval->>'score')::int")
  166. # 把所有剩余的双引号字符串值替换为单引号(PostgreSQL标准)
  167. where = re.sub(r'"([^"]*)"', r"'\1'", where)
  168. return f"WHERE {where}"
  169. def _format_result(self, row: Dict) -> Dict:
  170. """格式化查询结果"""
  171. if not row:
  172. return None
  173. result = dict(row)
  174. if 'tags' in result and isinstance(result['tags'], str):
  175. result['tags'] = json.loads(result['tags'])
  176. if 'source' in result and isinstance(result['source'], str):
  177. result['source'] = json.loads(result['source'])
  178. if 'eval' in result and isinstance(result['eval'], str):
  179. result['eval'] = json.loads(result['eval'])
  180. if 'relationships' in result and isinstance(result['relationships'], str):
  181. result['relationships'] = json.loads(result['relationships'])
  182. if 'created_at' in result and result['created_at']:
  183. result['created_at'] = result['created_at'] * 1000
  184. if 'updated_at' in result and result['updated_at']:
  185. result['updated_at'] = result['updated_at'] * 1000
  186. return result
  187. def close(self):
  188. """关闭连接"""
  189. if self.conn:
  190. self.conn.close()
  191. def insert_batch(self, knowledge_list: List[Dict]):
  192. """批量插入知识"""
  193. if not knowledge_list:
  194. return
  195. cursor = self._get_cursor()
  196. try:
  197. data = []
  198. for k in knowledge_list:
  199. data.append((
  200. k['id'], k['embedding'], k['message_id'], k['task'],
  201. k['content'], k.get('types', []),
  202. json.dumps(k.get('tags', {})), k.get('tag_keys', []),
  203. k.get('scopes', []), k['owner'], k.get('resource_ids', []),
  204. json.dumps(k.get('source', {})), json.dumps(k.get('eval', {})),
  205. k['created_at'], k['updated_at'], k.get('status', 'approved'),
  206. json.dumps(k.get('relationships', []))
  207. ))
  208. execute_batch(cursor, """
  209. INSERT INTO knowledge (
  210. id, embedding, message_id, task, content, types, tags,
  211. tag_keys, scopes, owner, resource_ids, source, eval,
  212. created_at, updated_at, status, relationships
  213. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  214. """, data)
  215. self.conn.commit()
  216. finally:
  217. cursor.close()