# pg_requirement_store.py
"""
Storage wrapper for the PostgreSQL ``requirement_table``.

Used to store and retrieve requirement records; supports vector search.
"""
import os
import json
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Dict, Optional
from dotenv import load_dotenv
# Runs at import time, before PostgreSQLRequirementStore reads the KNOWHUB_*
# settings through os.getenv (presumably sourced from a local .env file).
load_dotenv()
  12. class PostgreSQLRequirementStore:
  13. def __init__(self):
  14. """初始化 PostgreSQL 连接"""
  15. self.conn = psycopg2.connect(
  16. host=os.getenv('KNOWHUB_DB'),
  17. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  18. user=os.getenv('KNOWHUB_USER'),
  19. password=os.getenv('KNOWHUB_PASSWORD'),
  20. database=os.getenv('KNOWHUB_DB_NAME')
  21. )
  22. self.conn.autocommit = False
  23. print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")
  24. def _get_cursor(self):
  25. """获取游标"""
  26. return self.conn.cursor(cursor_factory=RealDictCursor)
  27. def insert_or_update(self, requirement: Dict):
  28. """插入或更新需求记录"""
  29. cursor = self._get_cursor()
  30. try:
  31. cursor.execute("""
  32. INSERT INTO requirement_table (
  33. id, task, type, source_type, source_itemset_id,
  34. source_items, tools, knowledge, case_knowledge,
  35. process_knowledge, trace, body, embedding
  36. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  37. ON CONFLICT (id) DO UPDATE SET
  38. task = EXCLUDED.task,
  39. type = EXCLUDED.type,
  40. source_type = EXCLUDED.source_type,
  41. source_itemset_id = EXCLUDED.source_itemset_id,
  42. source_items = EXCLUDED.source_items,
  43. tools = EXCLUDED.tools,
  44. knowledge = EXCLUDED.knowledge,
  45. case_knowledge = EXCLUDED.case_knowledge,
  46. process_knowledge = EXCLUDED.process_knowledge,
  47. trace = EXCLUDED.trace,
  48. body = EXCLUDED.body,
  49. embedding = EXCLUDED.embedding
  50. """, (
  51. requirement['id'],
  52. requirement['task'],
  53. requirement.get('type', '制作'),
  54. requirement.get('source_type', 'itemset'),
  55. requirement.get('source_itemset_id', ''),
  56. json.dumps(requirement.get('source_items', [])),
  57. json.dumps(requirement.get('tools', [])),
  58. json.dumps(requirement.get('knowledge', [])),
  59. json.dumps(requirement.get('case_knowledge', [])),
  60. json.dumps(requirement.get('process_knowledge', [])),
  61. json.dumps(requirement.get('trace', {})),
  62. requirement.get('body', ''),
  63. requirement['embedding']
  64. ))
  65. self.conn.commit()
  66. finally:
  67. cursor.close()
  68. def get_by_id(self, req_id: str) -> Optional[Dict]:
  69. """根据 ID 获取需求"""
  70. cursor = self._get_cursor()
  71. try:
  72. cursor.execute("""
  73. SELECT * FROM requirement_table WHERE id = %s
  74. """, (req_id,))
  75. result = cursor.fetchone()
  76. return dict(result) if result else None
  77. finally:
  78. cursor.close()
  79. def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
  80. """向量检索需求(使用 fastann ANN 索引)"""
  81. cursor = self._get_cursor()
  82. try:
  83. sql = """
  84. SELECT id, task, type, source_type, source_itemset_id,
  85. source_items, tools, knowledge, case_knowledge,
  86. process_knowledge, trace, body,
  87. 1 - (embedding <=> %s::real[]) as score
  88. FROM requirement_table
  89. ORDER BY embedding <=> %s::real[]
  90. LIMIT %s
  91. """
  92. cursor.execute(sql, (query_embedding, query_embedding, limit))
  93. results = cursor.fetchall()
  94. return [dict(r) for r in results]
  95. finally:
  96. cursor.close()
  97. def list_all(self, limit: int = 100) -> List[Dict]:
  98. """列出所有需求"""
  99. cursor = self._get_cursor()
  100. try:
  101. cursor.execute("""
  102. SELECT * FROM requirement_table LIMIT %s
  103. """, (limit,))
  104. results = cursor.fetchall()
  105. return [dict(r) for r in results]
  106. finally:
  107. cursor.close()
  108. def count(self) -> int:
  109. """统计需求总数"""
  110. cursor = self._get_cursor()
  111. try:
  112. cursor.execute("SELECT COUNT(*) as count FROM requirement_table")
  113. result = cursor.fetchone()
  114. return result['count'] if result else 0
  115. finally:
  116. cursor.close()
  117. def close(self):
  118. """关闭连接"""
  119. if self.conn:
  120. self.conn.close()