1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- import math
- import asyncio
- import numpy as np
- from typing import Dict, Any, List, Tuple, Union
- from applications.config import DEFAULT_MODEL
- from applications.api import get_basic_embedding
class KGClassifier:
    """Classify text by greedily descending a tree of named topic nodes.

    The tree is given as nested dicts: each node has a ``"name"`` key and an
    optional ``"children"`` list of nodes of the same form.  At every level the
    child whose name embedding has the highest cosine similarity to the input
    embedding is selected.
    """

    def __init__(self, kg_spec: Dict[str, Any]):
        """Store the tree root taken from ``kg_spec["root"]``.

        Raises:
            KeyError: if ``kg_spec`` has no ``"root"`` entry.
        """
        self.root = kg_spec["root"]
        # Maps node name -> embedding vector of that name.
        self._embed_cache: Dict[str, np.ndarray] = {}

    async def init_cache(self):
        """Concurrently fill the embedding cache for every node name in the tree.

        Names that are already cached are not fetched again.
        """
        # Plain iterative traversal: walking the dict tree needs no awaiting.
        names: List[str] = []
        pending = [self.root]
        while pending:
            node = pending.pop()
            names.append(node["name"])
            pending.extend(node.get("children") or ())

        async def fetch(name: str):
            if name not in self._embed_cache:
                self._embed_cache[name] = await self._get_embedding(name)

        # dict.fromkeys drops duplicate names so each is requested only once.
        await asyncio.gather(*(fetch(name) for name in dict.fromkeys(names)))

    @staticmethod
    async def _get_embedding(text: str) -> np.ndarray:
        """Request an embedding for ``text`` and return it as a float32 array."""
        embedding = await get_basic_embedding(text=text, model=DEFAULT_MODEL, dev=True)
        return np.array(embedding, dtype=np.float32)

    async def classify(
        self, text: Union[str, np.ndarray], topk: int = 3
    ) -> Tuple[List[str], float]:
        """Return ``(topic_path, purity)`` for a raw text or a precomputed embedding.

        Args:
            text: Either the raw string (its embedding is fetched) or an
                already computed embedding vector.
            topk: Currently unused; kept so existing callers keep working.

        Returns:
            ``topic_path`` is the list of selected child names from the first
            level below the root down to a leaf.  ``purity`` is the mean, over
            all levels, of ``sigmoid(5 * margin)`` where ``margin`` is the gap
            between the best and second-best child similarity (a single child
            is compared against -1.0).  An empty path yields purity ``1.0``.
        """
        if isinstance(text, str):
            text_emb = await self._get_embedding(text)
        else:
            text_emb = np.asarray(text)

        # Loop-invariant: the input norm is the same for every comparison.
        text_norm = float(np.linalg.norm(text_emb))

        path: List[str] = []
        purities: List[float] = []
        node = self.root
        while True:
            children = node.get("children", [])
            if not children:
                break

            scores = []
            for child in children:
                name = child["name"]
                vec = self._embed_cache.get(name)
                if vec is None:
                    # Cache miss (init_cache not awaited, or node added later):
                    # fetch on demand instead of failing with KeyError.
                    vec = await self._get_embedding(name)
                    self._embed_cache[name] = vec

                denom = text_norm * float(np.linalg.norm(vec))
                # A zero-norm vector has no direction; score it 0.0 rather than
                # dividing by zero, which would give NaN and corrupt the sort.
                score = float(np.dot(text_emb, vec)) / denom if denom > 0.0 else 0.0
                scores.append((child, score))

            # Stable sort: children with equal scores keep their original order.
            scores.sort(key=lambda item: item[1], reverse=True)
            best_child, best_score = scores[0]
            runner_up = scores[1][1] if len(scores) > 1 else -1.0

            path.append(best_child["name"])
            margin = max(0.0, best_score - max(runner_up, -1.0))
            purities.append(1 / (1 + math.exp(-5 * margin)))
            node = best_child

        purity = float(np.mean(purities)) if purities else 1.0
        return path, purity
|