import math
import asyncio
import numpy as np
from typing import Dict, Any, List, Tuple, Union, Optional

from applications.config import DEFAULT_MODEL
from applications.api import get_basic_embedding


class KGClassifier:
    """Greedily classify text into a root-to-leaf path of a knowledge-graph tree.

    ``kg_spec["root"]`` is a node dict with a ``"name"`` key and an optional
    ``"children"`` list of nodes of the same shape. Node names are embedded via
    the embedding service and cached in ``_embed_cache``; at each level the
    child whose name embedding is most cosine-similar to the text is chosen.
    """

    # Steepness of the logistic curve mapping a best-vs-second similarity
    # margin (>= 0) to a per-level purity in [0.5, 1).
    _PURITY_SHARPNESS = 5.0

    def __init__(self, kg_spec: Dict[str, Any]):
        self.root = kg_spec["root"]
        # node name -> embedding vector
        self._embed_cache: Dict[str, np.ndarray] = {}

    @staticmethod
    def _collect_names(root: Dict[str, Any]) -> List[str]:
        """Return the unique node names of the tree at ``root`` in DFS pre-order."""
        # dict preserves insertion order, giving a deterministic de-duplication.
        ordered: Dict[str, None] = {}
        stack = [root]
        while stack:
            node = stack.pop()
            ordered.setdefault(node["name"], None)
            # Push reversed so children are visited left-to-right.
            stack.extend(reversed(node.get("children", [])))
        return list(ordered)

    async def init_cache(self, max_concurrency: Optional[int] = None) -> None:
        """Concurrently populate the embedding cache for every KG node name.

        Args:
            max_concurrency: Maximum number of embedding requests in flight at
                once. ``None`` (default) issues all requests simultaneously.

        Raises:
            ValueError: If ``max_concurrency`` is given and is less than 1.
        """
        if max_concurrency is not None and max_concurrency < 1:
            raise ValueError("max_concurrency must be >= 1 or None")

        missing = [
            name for name in self._collect_names(self.root)
            if name not in self._embed_cache
        ]
        if not missing:
            return

        semaphore = (
            asyncio.Semaphore(max_concurrency) if max_concurrency is not None else None
        )

        async def fetch(name: str) -> None:
            if semaphore is None:
                self._embed_cache[name] = await self._get_embedding(name)
                return
            async with semaphore:
                self._embed_cache[name] = await self._get_embedding(name)

        await asyncio.gather(*(fetch(name) for name in missing))

    @staticmethod
    async def _get_embedding(text: str) -> np.ndarray:
        """Call the HTTP embedding service and return the vector as float32."""
        embedding = await get_basic_embedding(text=text, model=DEFAULT_MODEL, dev=True)
        return np.array(embedding, dtype=np.float32)

    async def _node_embedding(self, name: str) -> np.ndarray:
        """Return the cached embedding for ``name``, fetching and caching it on a miss."""
        vec = self._embed_cache.get(name)
        if vec is None:
            vec = await self._get_embedding(name)
            self._embed_cache[name] = vec
        return vec

    @staticmethod
    def _cosine(a: np.ndarray, a_norm: float, b: np.ndarray) -> float:
        """Cosine similarity of ``a`` and ``b`` given ``a``'s precomputed norm.

        Returns 0.0 when either vector has zero norm, instead of producing NaN
        (which would make the subsequent sort order meaningless).
        """
        denom = a_norm * float(np.linalg.norm(b))
        if denom == 0.0:
            return 0.0
        return float(np.dot(a, b) / denom)

    async def classify(
        self, text: Union[str, np.ndarray], topk: int = 3
    ) -> Tuple[List[str], float]:
        """Walk the KG from the root, picking the most similar child at each level.

        Args:
            text: Raw text (embedded via the service) or a precomputed embedding.
            topk: Currently unused; retained for API compatibility.

        Returns:
            ``(topic_path, purity)`` where ``topic_path`` lists the chosen node
            names from the first level down to a leaf, and ``purity`` is the mean
            over levels of ``sigmoid(5 * margin)`` with ``margin`` the gap between
            the best and second-best child similarity (a lone child is compared
            against -1.0). ``purity`` is 1.0 when the root has no children.
        """
        if isinstance(text, str):
            text_emb = await self._get_embedding(text)
        else:
            text_emb = text
        # Invariant across the whole walk; compute once.
        text_norm = float(np.linalg.norm(text_emb))

        path: List[str] = []
        purities: List[float] = []
        node = self.root
        while True:
            children = node.get("children", [])
            if not children:
                break

            scored = []
            for ch in children:
                vec = await self._node_embedding(ch["name"])
                scored.append((ch, self._cosine(text_emb, text_norm, vec)))
            # Stable sort: on ties the earlier child wins.
            scored.sort(key=lambda item: item[1], reverse=True)

            best_node, best_score = scored[0]
            second_score = scored[1][1] if len(scored) > 1 else -1.0
            margin = max(0.0, best_score - max(second_score, -1.0))
            purities.append(1 / (1 + math.exp(-self._PURITY_SHARPNESS * margin)))

            path.append(best_node["name"])
            node = best_node

        purity = float(np.mean(purities)) if purities else 1.0
        return path, purity