pg_store.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. """
  2. PostgreSQL 存储封装(替代 Milvus)
  3. 使用远程 PostgreSQL + pgvector/fastann 存储知识数据
  4. """
  5. import os
  6. import json
  7. import psycopg2
  8. from psycopg2.extras import RealDictCursor, execute_batch
  9. from typing import List, Dict, Optional
  10. from dotenv import load_dotenv
  11. load_dotenv()
  12. class PostgreSQLStore:
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _get_cursor(self):
  25. """获取游标"""
  26. return self.conn.cursor(cursor_factory=RealDictCursor)
  27. def insert(self, knowledge: Dict):
  28. """插入单条知识"""
  29. cursor = self._get_cursor()
  30. try:
  31. cursor.execute("""
  32. INSERT INTO knowledge (
  33. id, task_embedding, message_id, task, content, types, tags,
  34. tag_keys, scopes, owner, resource_ids, source, eval,
  35. created_at, updated_at, status, relationships,
  36. support_capability, tools
  37. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  38. """, (
  39. knowledge['id'],
  40. knowledge.get('task_embedding') or knowledge.get('embedding'),
  41. knowledge['message_id'],
  42. knowledge['task'],
  43. knowledge['content'],
  44. knowledge.get('types', []),
  45. json.dumps(knowledge.get('tags', {})),
  46. knowledge.get('tag_keys', []),
  47. knowledge.get('scopes', []),
  48. knowledge['owner'],
  49. knowledge.get('resource_ids', []),
  50. json.dumps(knowledge.get('source', {})),
  51. json.dumps(knowledge.get('eval', {})),
  52. knowledge['created_at'],
  53. knowledge['updated_at'],
  54. knowledge.get('status', 'approved'),
  55. json.dumps(knowledge.get('relationships', [])),
  56. json.dumps(knowledge.get('support_capability', [])),
  57. json.dumps(knowledge.get('tools', [])),
  58. ))
  59. self.conn.commit()
  60. finally:
  61. cursor.close()
  62. def search(self, query_embedding: List[float], filters: Optional[str] = None, limit: int = 10) -> List[Dict]:
  63. """向量检索(使用余弦相似度)"""
  64. cursor = self._get_cursor()
  65. try:
  66. where_clause = self._build_where_clause(filters) if filters else ""
  67. sql = f"""
  68. SELECT id, message_id, task, content, types, tags, tag_keys,
  69. scopes, owner, resource_ids, source, eval, created_at,
  70. updated_at, status, relationships, support_capability, tools,
  71. 1 - (task_embedding <=> %s::real[]) as score
  72. FROM knowledge
  73. {where_clause}
  74. ORDER BY task_embedding <=> %s::real[]
  75. LIMIT %s
  76. """
  77. cursor.execute(sql, (query_embedding, query_embedding, limit))
  78. results = cursor.fetchall()
  79. return [self._format_result(r) for r in results]
  80. finally:
  81. cursor.close()
  82. def query(self, filters: str, limit: int = 100) -> List[Dict]:
  83. """纯标量查询"""
  84. cursor = self._get_cursor()
  85. try:
  86. where_clause = self._build_where_clause(filters)
  87. sql = f"""
  88. SELECT id, message_id, task, content, types, tags, tag_keys,
  89. scopes, owner, resource_ids, source, eval, created_at,
  90. updated_at, status, relationships, support_capability, tools
  91. FROM knowledge
  92. {where_clause}
  93. LIMIT %s
  94. """
  95. cursor.execute(sql, (limit,))
  96. results = cursor.fetchall()
  97. return [self._format_result(r) for r in results]
  98. finally:
  99. cursor.close()
  100. def get_by_id(self, knowledge_id: str, include_embedding: bool = False) -> Optional[Dict]:
  101. """根据ID获取知识(默认不返回embedding以提升性能)"""
  102. cursor = self._get_cursor()
  103. try:
  104. # 默认不返回embedding(1536维向量太大,详情页不需要)
  105. if include_embedding:
  106. fields = "id, task_embedding, content_embedding, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools"
  107. else:
  108. fields = "id, message_id, task, content, types, tags, tag_keys, scopes, owner, resource_ids, source, eval, created_at, updated_at, status, relationships, support_capability, tools"
  109. cursor.execute(f"""
  110. SELECT {fields}
  111. FROM knowledge WHERE id = %s
  112. """, (knowledge_id,))
  113. result = cursor.fetchone()
  114. return self._format_result(result) if result else None
  115. finally:
  116. cursor.close()
  117. def update(self, knowledge_id: str, updates: Dict):
  118. """更新知识"""
  119. cursor = self._get_cursor()
  120. try:
  121. set_parts = []
  122. params = []
  123. for key, value in updates.items():
  124. if key in ('tags', 'source', 'eval', 'support_capability', 'tools'):
  125. set_parts.append(f"{key} = %s")
  126. params.append(json.dumps(value))
  127. elif key == 'relationships':
  128. set_parts.append(f"{key} = %s")
  129. params.append(json.dumps(value) if isinstance(value, list) else value)
  130. else:
  131. set_parts.append(f"{key} = %s")
  132. params.append(value)
  133. params.append(knowledge_id)
  134. sql = f"UPDATE knowledge SET {', '.join(set_parts)} WHERE id = %s"
  135. cursor.execute(sql, params)
  136. self.conn.commit()
  137. finally:
  138. cursor.close()
  139. def delete(self, knowledge_id: str):
  140. """删除知识"""
  141. cursor = self._get_cursor()
  142. try:
  143. cursor.execute("DELETE FROM knowledge WHERE id = %s", (knowledge_id,))
  144. self.conn.commit()
  145. finally:
  146. cursor.close()
  147. def count(self) -> int:
  148. """返回知识总数"""
  149. cursor = self._get_cursor()
  150. try:
  151. cursor.execute("SELECT COUNT(*) as count FROM knowledge")
  152. return cursor.fetchone()['count']
  153. finally:
  154. cursor.close()
  155. def _build_where_clause(self, filters: str) -> str:
  156. """将Milvus风格的过滤表达式转换为PostgreSQL WHERE子句"""
  157. if not filters:
  158. return ""
  159. where = filters
  160. import re
  161. # 替换操作符
  162. where = where.replace(' == ', ' = ')
  163. where = where.replace(' or ', ' OR ')
  164. where = where.replace(' and ', ' AND ')
  165. # 处理数组包含操作
  166. where = re.sub(r'array_contains\((\w+),\s*"([^"]+)"\)', r"\1 @> ARRAY['\2']", where)
  167. # 处理 eval["score"] 语法
  168. where = where.replace('eval["score"]', "(eval->>'score')::int")
  169. # 把所有剩余的双引号字符串值替换为单引号(PostgreSQL标准)
  170. where = re.sub(r'"([^"]*)"', r"'\1'", where)
  171. return f"WHERE {where}"
  172. def _format_result(self, row: Dict) -> Dict:
  173. """格式化查询结果"""
  174. if not row:
  175. return None
  176. result = dict(row)
  177. if 'tags' in result and isinstance(result['tags'], str):
  178. result['tags'] = json.loads(result['tags'])
  179. if 'source' in result and isinstance(result['source'], str):
  180. result['source'] = json.loads(result['source'])
  181. if 'eval' in result and isinstance(result['eval'], str):
  182. result['eval'] = json.loads(result['eval'])
  183. if 'relationships' in result and isinstance(result['relationships'], str):
  184. result['relationships'] = json.loads(result['relationships'])
  185. if 'support_capability' in result and isinstance(result['support_capability'], str):
  186. result['support_capability'] = json.loads(result['support_capability'])
  187. if 'tools' in result and isinstance(result['tools'], str):
  188. result['tools'] = json.loads(result['tools'])
  189. if 'created_at' in result and result['created_at']:
  190. result['created_at'] = result['created_at'] * 1000
  191. if 'updated_at' in result and result['updated_at']:
  192. result['updated_at'] = result['updated_at'] * 1000
  193. return result
  194. def close(self):
  195. """关闭连接"""
  196. if self.conn:
  197. self.conn.close()
  198. def insert_batch(self, knowledge_list: List[Dict]):
  199. """批量插入知识"""
  200. if not knowledge_list:
  201. return
  202. cursor = self._get_cursor()
  203. try:
  204. data = []
  205. for k in knowledge_list:
  206. data.append((
  207. k['id'], k.get('task_embedding') or k.get('embedding'), k['message_id'], k['task'],
  208. k['content'], k.get('types', []),
  209. json.dumps(k.get('tags', {})), k.get('tag_keys', []),
  210. k.get('scopes', []), k['owner'], k.get('resource_ids', []),
  211. json.dumps(k.get('source', {})), json.dumps(k.get('eval', {})),
  212. k['created_at'], k['updated_at'], k.get('status', 'approved'),
  213. json.dumps(k.get('relationships', [])),
  214. json.dumps(k.get('support_capability', [])),
  215. json.dumps(k.get('tools', [])),
  216. ))
  217. execute_batch(cursor, """
  218. INSERT INTO knowledge (
  219. id, task_embedding, message_id, task, content, types, tags,
  220. tag_keys, scopes, owner, resource_ids, source, eval,
  221. created_at, updated_at, status, relationships,
  222. support_capability, tools
  223. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  224. """, data)
  225. self.conn.commit()
  226. finally:
  227. cursor.close()