registry.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. """
  2. 内容平台注册表
  3. 定义所有支持的内容平台及其搜索参数 schema。
  4. 供 content_platforms / content_search / content_detail 路由使用。
  5. """
  6. from dataclasses import dataclass, field
  7. from typing import Any, Callable, Coroutine, Dict, List, Optional
  8. from agent.tools.models import ToolResult
  9. # ── 类型定义 ──
  10. @dataclass
  11. class ParamSpec:
  12. """平台专属参数的描述"""
  13. values: Optional[List[str]] = None # 枚举值(None 表示自由文本)
  14. default: Optional[str] = None
  15. note: str = "" # 额外说明
  16. def to_dict(self) -> dict:
  17. d: dict = {}
  18. if self.values is not None:
  19. d["values"] = self.values
  20. d["default"] = self.default
  21. if self.note:
  22. d["note"] = self.note
  23. return d
  24. # 平台实现函数的签名
  25. SearchFunc = Callable[..., Coroutine[Any, Any, ToolResult]]
  26. DetailFunc = Callable[..., Coroutine[Any, Any, ToolResult]]
  27. SuggestFunc = Callable[..., Coroutine[Any, Any, ToolResult]]
  28. @dataclass
  29. class PlatformDef:
  30. """一个内容平台的完整定义"""
  31. id: str # 唯一标识,如 "xhs"
  32. name: str # 显示名,如 "小红书"
  33. aliases: List[str] = field(default_factory=list) # 模糊匹配别名,如 ["小红书", "RED"]
  34. search_params: Dict[str, ParamSpec] = field(default_factory=dict)
  35. detail_extras: Dict[str, ParamSpec] = field(default_factory=dict)
  36. supports_suggest: bool = False
  37. suggest_channels: Optional[List[str]] = None # suggest API 的 channel 值(可能与 id 不同)
  38. # 平台实现函数(运行时由 platforms/ 模块设置)
  39. search_impl: Optional[SearchFunc] = None
  40. detail_impl: Optional[DetailFunc] = None
  41. suggest_impl: Optional[SuggestFunc] = None
  42. def summary(self) -> dict:
  43. """概要信息(不含参数细节)"""
  44. d = {"id": self.id, "name": self.name}
  45. if self.search_params:
  46. d["has_search_params"] = True
  47. if self.detail_extras:
  48. d["has_detail_extras"] = True
  49. if self.supports_suggest:
  50. d["supports_suggest"] = True
  51. return d
  52. def detail(self) -> dict:
  53. """完整参数说明"""
  54. d = self.summary()
  55. if self.search_params:
  56. d["search_params"] = {k: v.to_dict() for k, v in self.search_params.items()}
  57. if self.detail_extras:
  58. d["detail_extras"] = {k: v.to_dict() for k, v in self.detail_extras.items()}
  59. return d
  60. # ── 平台注册表 ──
  61. _PLATFORMS: Dict[str, PlatformDef] = {}
  62. def register_platform(p: PlatformDef) -> None:
  63. _PLATFORMS[p.id] = p
  64. def get_platform(platform_id: str) -> Optional[PlatformDef]:
  65. return _PLATFORMS.get(platform_id)
  66. def all_platforms() -> List[PlatformDef]:
  67. return list(_PLATFORMS.values())
  68. def match_platforms(query: str) -> List[PlatformDef]:
  69. """
  70. 模糊匹配平台:精确 ID > 别名包含 > token 交集。
  71. 空 query 返回全部。
  72. """
  73. if not query:
  74. return all_platforms()
  75. q = query.strip().lower()
  76. # 1) 精确 ID 匹配
  77. if q in _PLATFORMS:
  78. return [_PLATFORMS[q]]
  79. # 2) 别名 / 名称包含匹配
  80. alias_hits = [
  81. p for p in _PLATFORMS.values()
  82. if q in p.name.lower() or any(q in a.lower() for a in p.aliases)
  83. ]
  84. if alias_hits:
  85. return alias_hits
  86. # 3) token 交集(把 query 拆成字符/词,看命中率)
  87. q_tokens = set(q.replace("_", " ").replace("-", " ").split())
  88. scored = []
  89. for p in _PLATFORMS.values():
  90. pool = {p.id, p.name.lower()} | {a.lower() for a in p.aliases}
  91. pool_text = " ".join(pool)
  92. hits = sum(1 for t in q_tokens if t in pool_text)
  93. if hits > 0:
  94. scored.append((hits, p))
  95. scored.sort(key=lambda x: -x[0])
  96. return [p for _, p in scored]