| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- """
- PostgreSQL requirement_table 存储封装
- 用于存储和检索需求数据,支持向量检索
- """
- import os
- import json
- import psycopg2
- from psycopg2.extras import RealDictCursor
- from typing import List, Dict, Optional
- from dotenv import load_dotenv
- load_dotenv()
class PostgreSQLRequirementStore:
    """Storage wrapper around the PostgreSQL ``requirement_table``.

    Holds a single psycopg2 connection (``autocommit`` disabled) and exposes
    upsert, lookup, vector search, listing and counting of requirement rows.
    Every statement runs in its own short transaction: it is committed on
    success and rolled back on failure, so one failed statement does not
    leave the shared connection in an aborted-transaction state.

    The store can be used as a context manager; the connection is closed on
    exit.
    """

    def __init__(self):
        """Open the connection using the ``KNOWHUB_*`` environment variables.

        Reads ``KNOWHUB_DB`` (host), ``KNOWHUB_PORT`` (defaults to 5432),
        ``KNOWHUB_USER``, ``KNOWHUB_PASSWORD`` and ``KNOWHUB_DB_NAME``.
        """
        self.conn = psycopg2.connect(
            host=os.getenv('KNOWHUB_DB'),
            # ``or`` also covers a variable that is set but empty, which
            # would otherwise make ``int('')`` raise ValueError.
            port=int(os.getenv('KNOWHUB_PORT') or 5432),
            user=os.getenv('KNOWHUB_USER'),
            password=os.getenv('KNOWHUB_PASSWORD'),
            database=os.getenv('KNOWHUB_DB_NAME')
        )
        self.conn.autocommit = False
        print(f"[PostgreSQL Requirement] 已连接到远程数据库: {os.getenv('KNOWHUB_DB')}")

    def __enter__(self):
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection when leaving a ``with`` block."""
        self.close()

    def _get_cursor(self):
        """Return a new cursor that yields rows as dict-like objects."""
        return self.conn.cursor(cursor_factory=RealDictCursor)

    def _run(self, sql: str, params: Optional[tuple] = None,
             fetch: Optional[str] = None):
        """Execute one statement inside its own transaction.

        Args:
            sql: SQL text using ``%s`` placeholders.
            params: Values bound to the placeholders, or ``None``.
            fetch: ``"one"`` to return ``cursor.fetchone()``, ``"all"`` to
                return ``cursor.fetchall()``, ``None`` to return nothing.

        Returns:
            The fetched row(s), or ``None`` when ``fetch`` is ``None``.

        Raises:
            Exception: Whatever the driver raised; the transaction is rolled
                back before the exception propagates.
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(sql, params)
            if fetch == "one":
                rows = cursor.fetchone()
            elif fetch == "all":
                rows = cursor.fetchall()
            else:
                rows = None
            # End the transaction even for reads so the session does not sit
            # "idle in transaction" between calls.
            self.conn.commit()
            return rows
        except Exception:
            # Without a rollback the connection stays in the aborted state
            # and every later statement on it fails as well.
            self.conn.rollback()
            raise
        finally:
            cursor.close()

    def insert_or_update(self, requirement: Dict):
        """Insert a requirement row, or update it when the id already exists.

        ``requirement`` must contain ``id``, ``task`` and ``embedding``; the
        remaining fields fall back to defaults. List/dict fields are stored
        as JSON text.

        Raises:
            KeyError: If ``id``, ``task`` or ``embedding`` is missing.
        """
        # Build the parameters first so a missing key fails before any
        # database work is started.
        params = (
            requirement['id'],
            requirement['task'],
            requirement.get('type', '制作'),
            requirement.get('source_type', 'itemset'),
            requirement.get('source_itemset_id', ''),
            json.dumps(requirement.get('source_items', [])),
            json.dumps(requirement.get('tools', [])),
            json.dumps(requirement.get('knowledge', [])),
            json.dumps(requirement.get('case_knowledge', [])),
            json.dumps(requirement.get('process_knowledge', [])),
            json.dumps(requirement.get('trace', {})),
            requirement.get('body', ''),
            requirement['embedding'],
        )
        self._run("""
            INSERT INTO requirement_table (
                id, task, type, source_type, source_itemset_id,
                source_items, tools, knowledge, case_knowledge,
                process_knowledge, trace, body, embedding
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON CONFLICT (id) DO UPDATE SET
                task = EXCLUDED.task,
                type = EXCLUDED.type,
                source_type = EXCLUDED.source_type,
                source_itemset_id = EXCLUDED.source_itemset_id,
                source_items = EXCLUDED.source_items,
                tools = EXCLUDED.tools,
                knowledge = EXCLUDED.knowledge,
                case_knowledge = EXCLUDED.case_knowledge,
                process_knowledge = EXCLUDED.process_knowledge,
                trace = EXCLUDED.trace,
                body = EXCLUDED.body,
                embedding = EXCLUDED.embedding
        """, params)

    def get_by_id(self, req_id: str) -> Optional[Dict]:
        """Return the requirement with the given id, or ``None`` if absent."""
        row = self._run("""
            SELECT * FROM requirement_table WHERE id = %s
        """, (req_id,), fetch="one")
        return dict(row) if row else None

    def search(self, query_embedding: List[float], limit: int = 10) -> List[Dict]:
        """Return up to ``limit`` requirements nearest to ``query_embedding``.

        Rows are ordered by ``embedding <=> query`` ascending and each carries
        a ``score`` column equal to ``1 - (embedding <=> query)``. The
        ``embedding`` column itself is not returned.

        NOTE(review): the original docstring states this is served by a
        fastann ANN index; the index and the exact distance metric behind
        ``<=>`` are defined in the database, not in this file.
        """
        sql = """
            SELECT id, task, type, source_type, source_itemset_id,
                   source_items, tools, knowledge, case_knowledge,
                   process_knowledge, trace, body,
                   1 - (embedding <=> %s::real[]) as score
            FROM requirement_table
            ORDER BY embedding <=> %s::real[]
            LIMIT %s
        """
        rows = self._run(sql, (query_embedding, query_embedding, limit), fetch="all")
        return [dict(r) for r in rows]

    def list_all(self, limit: int = 100) -> List[Dict]:
        """Return up to ``limit`` requirement rows (no ordering is requested)."""
        rows = self._run("""
            SELECT * FROM requirement_table LIMIT %s
        """, (limit,), fetch="all")
        return [dict(r) for r in rows]

    def count(self) -> int:
        """Return the total number of rows in ``requirement_table``."""
        row = self._run("SELECT COUNT(*) as count FROM requirement_table", fetch="one")
        return row['count'] if row else 0

    def close(self):
        """Close the underlying connection if one is held."""
        if self.conn:
            self.conn.close()
|