| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- """
- 内容平台注册表
- 定义所有支持的内容平台及其搜索参数 schema。
- 供 content_platforms / content_search / content_detail 路由使用。
- """
- from dataclasses import dataclass, field
- from typing import Any, Callable, Coroutine, Dict, List, Optional
- from agent.tools.models import ToolResult
- # ── 类型定义 ──
@dataclass
class ParamSpec:
    """Description of one platform-specific parameter."""

    values: Optional[List[str]] = None  # allowed enum values (None means free text)
    default: Optional[str] = None
    note: str = ""  # extra explanation

    def to_dict(self) -> dict:
        """Serialize the spec to a plain dict, omitting unset fields.

        Returns:
            A dict that may contain:
              * ``"values"``  - a copy of the enum values, when the parameter
                is an enum (``values`` is not None);
              * ``"default"`` - the default value, when one is set;
              * ``"note"``    - the extra explanation, when non-empty.
        """
        d: dict = {}
        if self.values is not None:
            # Copy so callers cannot mutate the spec through the result.
            d["values"] = list(self.values)
        # The default is reported independently of `values`: a free-text
        # parameter keeps its default, and an unset default is never
        # serialized as None.
        if self.default is not None:
            d["default"] = self.default
        if self.note:
            d["note"] = self.note
        return d
# Signatures of the platform implementation functions: each hook is an
# async callable (any arguments) whose coroutine resolves to a ToolResult.
_ImplCoroutine = Coroutine[Any, Any, ToolResult]
SearchFunc = Callable[..., _ImplCoroutine]
DetailFunc = Callable[..., _ImplCoroutine]
SuggestFunc = Callable[..., _ImplCoroutine]
@dataclass
class PlatformDef:
    """Complete definition of a single content platform."""

    id: str  # unique identifier, e.g. "xhs"
    name: str  # display name, e.g. "小红书"
    aliases: List[str] = field(default_factory=list)  # fuzzy-match aliases, e.g. ["小红书", "RED"]
    search_params: Dict[str, ParamSpec] = field(default_factory=dict)
    detail_extras: Dict[str, ParamSpec] = field(default_factory=dict)
    supports_suggest: bool = False
    suggest_channels: Optional[List[str]] = None  # channel values for the suggest API (may differ from id)

    # Platform implementation functions (set at runtime by the platforms/ modules).
    search_impl: Optional[SearchFunc] = None
    detail_impl: Optional[DetailFunc] = None
    suggest_impl: Optional[SuggestFunc] = None

    def summary(self) -> dict:
        """Return an overview of the platform without parameter details.

        Always contains ``id`` and ``name``; each capability flag is added
        (with value True) only when the corresponding attribute is truthy.
        """
        info = {"id": self.id, "name": self.name}
        capability_flags = (
            ("has_search_params", self.search_params),
            ("has_detail_extras", self.detail_extras),
            ("supports_suggest", self.supports_suggest),
        )
        for flag_name, present in capability_flags:
            if present:
                info[flag_name] = True
        return info

    def detail(self) -> dict:
        """Return the summary extended with the full parameter descriptions.

        Non-empty ``search_params`` / ``detail_extras`` are serialized with
        ``ParamSpec.to_dict`` under keys of the same name.
        """
        info = self.summary()
        spec_groups = (
            ("search_params", self.search_params),
            ("detail_extras", self.detail_extras),
        )
        for group_name, specs in spec_groups:
            if specs:
                info[group_name] = {
                    param: spec.to_dict() for param, spec in specs.items()
                }
        return info
# ── Platform registry ──
# Module-level registry mapping platform id -> PlatformDef.
_PLATFORMS: Dict[str, PlatformDef] = {}
def register_platform(p: PlatformDef) -> None:
    """Store *p* in the registry under its ``id``.

    An existing entry with the same id is replaced.
    """
    platform_id = p.id
    _PLATFORMS[platform_id] = p
def get_platform(platform_id: str) -> Optional[PlatformDef]:
    """Look up a platform by its exact id; return None when it is not registered."""
    if platform_id in _PLATFORMS:
        return _PLATFORMS[platform_id]
    return None
def all_platforms() -> List[PlatformDef]:
    """Return a new list of every registered platform, in registration order."""
    return [*_PLATFORMS.values()]
def match_platforms(query: str) -> List[PlatformDef]:
    """Fuzzy-match registered platforms against *query*.

    Matching is case-insensitive. Strategies are tried in priority order and
    the first one that yields any result wins:

    1. exact platform id;
    2. the query is a substring of the display name or of any alias;
    3. token overlap - the query is split on whitespace, ``_`` and ``-``,
       and platforms are ranked by how many tokens occur in their id, name
       or aliases (most hits first).

    Args:
        query: Free-form platform name, alias or id.

    Returns:
        The matching platforms. An empty, None or whitespace-only query
        returns every registered platform; no match returns an empty list.
    """
    # Normalize before the emptiness check so that a whitespace-only query
    # is handled explicitly instead of falling through to substring matching.
    q = (query or "").strip().lower()
    if not q:
        return all_platforms()

    # 1) Exact id match. Try the direct lookup first, then compare
    #    lower-cased ids so platforms registered with a mixed-case id can
    #    still be found (the query has already been lower-cased).
    if q in _PLATFORMS:
        return [_PLATFORMS[q]]
    id_hits = [p for key, p in _PLATFORMS.items() if key.lower() == q]
    if id_hits:
        return id_hits

    # 2) Query contained in the display name or in an alias.
    alias_hits = [
        p for p in _PLATFORMS.values()
        if q in p.name.lower() or any(q in a.lower() for a in p.aliases)
    ]
    if alias_hits:
        return alias_hits

    # 3) Token overlap: count the query tokens that occur in any of the
    #    platform's lower-cased id / name / aliases.
    q_tokens = set(q.replace("_", " ").replace("-", " ").split())
    scored = []
    for p in _PLATFORMS.values():
        pool = {p.id.lower(), p.name.lower()} | {a.lower() for a in p.aliases}
        hits = sum(1 for t in q_tokens if any(t in text for text in pool))
        if hits > 0:
            scored.append((hits, p))
    # Stable sort: ties keep registration order.
    scored.sort(key=lambda item: -item[0])
    return [p for _, p in scored]
|