"""
PostgreSQL ``requirement_table`` storage wrapper.

Used to store and retrieve requirement records; supports vector search.
"""

import json
import os
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional

import psycopg2
from dotenv import load_dotenv
from psycopg2.extras import RealDictCursor

load_dotenv()


class PostgreSQLRequirementStore:
    """Thin data-access layer over the ``requirement_table`` table.

    One instance owns one psycopg2 connection, configured from the
    ``KNOWHUB_*`` environment variables. The connection runs with
    ``autocommit`` disabled, so every public method ends the transaction it
    started: writes are committed, reads are rolled back, and any failure
    triggers a rollback so the connection stays usable for the next call.

    The store can be used as a context manager; the connection is closed on
    exit.
    """

    def __init__(self):
        """Open the PostgreSQL connection.

        Reads ``KNOWHUB_DB`` (host), ``KNOWHUB_PORT`` (defaults to 5432),
        ``KNOWHUB_USER``, ``KNOWHUB_PASSWORD`` and ``KNOWHUB_DB_NAME`` from
        the environment.
        """
        self.conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            port=int(os.getenv('KNOWHUB_PORT', 5432)),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME')
        )
        self.conn.autocommit = False
        print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    def __enter__(self) -> "PostgreSQLRequirementStore":
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection when leaving a ``with`` block."""
        self.close()

    def _get_cursor(self):
        """Return a new cursor whose rows are dict-like (``RealDictCursor``)."""
        return self.conn.cursor(cursor_factory=RealDictCursor)

    @contextmanager
    def _transaction(self, commit: bool = False) -> Iterator[RealDictCursor]:
        """Yield a cursor and always terminate the surrounding transaction.

        Args:
            commit: When True, commit on success (write path). When False,
                roll back on success (read path) so the session does not
                linger "idle in transaction".

        Any exception raised inside the ``with`` body rolls the transaction
        back before being re-raised; without that, a failed statement would
        leave the connection in an aborted state and break every later call.
        The cursor is closed in all cases.
        """
        cursor = self._get_cursor()
        try:
            yield cursor
            if commit:
                self.conn.commit()
            else:
                self.conn.rollback()
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()

    def insert_or_update(self, requirement: Dict):
        """Insert a requirement record, or update it if the id already exists.

        Args:
            requirement: Record to persist. ``id``, ``task`` and ``embedding``
                are required keys. Optional keys and their defaults:
                ``type`` ('制作'), ``source_type`` ('itemset'),
                ``source_itemset_id`` (''), ``source_items`` / ``tools`` /
                ``knowledge`` / ``case_knowledge`` / ``process_knowledge``
                ([]), ``trace`` ({}), ``body`` (''). List/dict fields are
                stored as JSON text.

        Raises:
            KeyError: If a required key is missing.
            psycopg2.Error: If the statement fails; the transaction is rolled
                back before the error propagates.
        """
        # Build the parameters first so a malformed record fails before any
        # database work is started.
        params = (
            requirement['id'],
            requirement['task'],
            requirement.get('type', '制作'),
            requirement.get('source_type', 'itemset'),
            requirement.get('source_itemset_id', ''),
            json.dumps(requirement.get('source_items', [])),
            json.dumps(requirement.get('tools', [])),
            json.dumps(requirement.get('knowledge', [])),
            json.dumps(requirement.get('case_knowledge', [])),
            json.dumps(requirement.get('process_knowledge', [])),
            json.dumps(requirement.get('trace', {})),
            requirement.get('body', ''),
            requirement['embedding'],
        )
        with self._transaction(commit=True) as cursor:
            cursor.execute("""
                INSERT INTO requirement_table (
                    id, task, type, source_type, source_itemset_id, source_items,
                    tools, knowledge, case_knowledge, process_knowledge,
                    trace, body, embedding
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (id) DO UPDATE SET
                    task = EXCLUDED.task,
                    type = EXCLUDED.type,
                    source_type = EXCLUDED.source_type,
                    source_itemset_id = EXCLUDED.source_itemset_id,
                    source_items = EXCLUDED.source_items,
                    tools = EXCLUDED.tools,
                    knowledge = EXCLUDED.knowledge,
                    case_knowledge = EXCLUDED.case_knowledge,
                    process_knowledge = EXCLUDED.process_knowledge,
                    trace = EXCLUDED.trace,
                    body = EXCLUDED.body,
                    embedding = EXCLUDED.embedding
            """, params)

    def get_by_id(self, req_id: str) -> Optional[Dict]:
        """Fetch a single requirement by id.

        Returns:
            The row as a plain dict, or None when no row matches.
        """
        with self._transaction() as cursor:
            cursor.execute("""
                SELECT * FROM requirement_table WHERE id = %s
            """, (req_id,))
            row = cursor.fetchone()
            return dict(row) if row else None

    def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
        """Return the requirements closest to ``query_embedding``.

        Rows are ordered by ascending ``embedding <=> query`` distance, and
        each row carries a ``score`` column computed as ``1 - distance``. The
        ``embedding`` column itself is not returned.

        NOTE(review): the original docstring states this query is served by
        a fastann ANN index; that depends on the database schema and cannot
        be confirmed from this file.

        Args:
            query_embedding: Query vector, passed to PostgreSQL as ``real[]``.
            limit: Maximum number of rows to return.
        """
        sql = """
            SELECT id, task, type, source_type, source_itemset_id, source_items,
                   tools, knowledge, case_knowledge, process_knowledge,
                   trace, body,
                   1 - (embedding <=> %s::real[]) as score
            FROM requirement_table
            ORDER BY embedding <=> %s::real[]
            LIMIT %s
        """
        with self._transaction() as cursor:
            # The vector is bound twice: once for the score, once for ordering.
            cursor.execute(sql, (query_embedding, query_embedding, limit))
            return [dict(row) for row in cursor.fetchall()]

    def list_all(self, limit: int = 100) -> List[Dict]:
        """Return up to ``limit`` requirement rows as plain dicts.

        The query has no ORDER BY, so the row order is whatever the database
        returns.
        """
        with self._transaction() as cursor:
            cursor.execute("""
                SELECT * FROM requirement_table LIMIT %s
            """, (limit,))
            return [dict(row) for row in cursor.fetchall()]

    def count(self) -> int:
        """Return the total number of rows in ``requirement_table``."""
        with self._transaction() as cursor:
            cursor.execute("SELECT COUNT(*) as count FROM requirement_table")
            row = cursor.fetchone()
            return row['count'] if row else 0

    def close(self):
        """Close the underlying database connection."""
        if self.conn:
            self.conn.close()