kg_classifier.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import math
  2. import asyncio
  3. import numpy as np
  4. from typing import Dict, Any, List, Tuple, Union
  5. from applications.config import DEFAULT_MODEL
  6. from applications.api import get_basic_embedding
  7. class KGClassifier:
  8. def __init__(self, kg_spec: Dict[str, Any]):
  9. self.root = kg_spec["root"]
  10. self._embed_cache: Dict[str, np.ndarray] = {}
  11. async def init_cache(self):
  12. """
  13. 并发初始化知识图谱节点的 embedding 缓存
  14. """
  15. async def collect_nodes(node, nodes: List[str]):
  16. """收集所有节点名称"""
  17. nodes.append(node["name"])
  18. for ch in node.get("children", []):
  19. await collect_nodes(ch, nodes)
  20. nodes: List[str] = []
  21. await collect_nodes(self.root, nodes)
  22. # 去掉重复的
  23. unique_nodes = list(set(nodes))
  24. async def fetch(name: str):
  25. if name not in self._embed_cache:
  26. self._embed_cache[name] = await self._get_embedding(name)
  27. # 并发执行
  28. await asyncio.gather(*(fetch(name) for name in unique_nodes))
  29. @staticmethod
  30. async def _get_embedding(text: str) -> np.ndarray:
  31. """
  32. 调用 HTTP embedding 服务,返回向量
  33. """
  34. embedding = await get_basic_embedding(text=text, model=DEFAULT_MODEL, dev=True)
  35. return np.array(embedding, dtype=np.float32)
  36. async def classify(
  37. self, text: Union[str, np.ndarray], topk: int = 3
  38. ) -> Tuple[List[str], float]:
  39. """
  40. 支持输入原始文本或预先算好的 embedding。
  41. 返回 (topic_path, purity)。
  42. """
  43. if isinstance(text, str):
  44. text_emb = await self._get_embedding(text)
  45. else:
  46. text_emb = text
  47. path, purities = [], []
  48. node = self.root
  49. while True:
  50. children = node.get("children", [])
  51. if not children:
  52. break
  53. scores = []
  54. for ch in children:
  55. vec = self._embed_cache[ch["name"]]
  56. score = float(
  57. np.dot(text_emb, vec)
  58. / (np.linalg.norm(text_emb) * np.linalg.norm(vec))
  59. )
  60. scores.append((ch, score))
  61. scores.sort(key=lambda x: x[1], reverse=True)
  62. best, second = scores[0], (scores[1] if len(scores) > 1 else (None, -1.0))
  63. path.append(best[0]["name"])
  64. margin = max(0.0, (best[1] - max(second[1], -1.0)))
  65. purities.append(1 / (1 + math.exp(-5 * margin)))
  66. node = best[0]
  67. purity = float(np.mean(purities)) if purities else 1.0
  68. return path, purity