embeddings.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """
  2. Embedding 生成模块
  3. 使用 OpenRouter 的 openai/text-embedding-3-small 模型生成向量。
  4. 支持单条和批量处理。
  5. """
  6. import os
  7. import asyncio
  8. from typing import List, Union
  9. import httpx
  10. OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
  11. EMBEDDING_MODEL = "openai/text-embedding-3-small"
  12. EMBEDDING_DIM = 1536
  13. def _get_api_key() -> str:
  14. """获取 API key(延迟读取环境变量)"""
  15. key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")
  16. if not key:
  17. raise ValueError("OPENROUTER_API_KEY or OPEN_ROUTER_API_KEY not set in environment")
  18. return key
  19. async def get_embedding(text: str) -> List[float]:
  20. """
  21. 生成单条文本的向量
  22. Args:
  23. text: 输入文本
  24. Returns:
  25. 1536 维向量
  26. """
  27. embeddings = await get_embeddings_batch([text])
  28. return embeddings[0]
  29. async def get_embeddings_batch(texts: List[str], batch_size: int = 100) -> List[List[float]]:
  30. """
  31. 批量生成文本向量
  32. Args:
  33. texts: 文本列表
  34. batch_size: 每批处理数量(OpenAI 限制 2048)
  35. Returns:
  36. 向量列表
  37. """
  38. if not texts:
  39. return []
  40. # 分批处理
  41. all_embeddings = []
  42. for i in range(0, len(texts), batch_size):
  43. batch = texts[i:i + batch_size]
  44. embeddings = await _call_embedding_api(batch)
  45. all_embeddings.extend(embeddings)
  46. return all_embeddings
  47. async def _call_embedding_api(texts: List[str]) -> List[List[float]]:
  48. """
  49. 调用 OpenRouter embedding API
  50. Args:
  51. texts: 文本列表(单批)
  52. Returns:
  53. 向量列表
  54. """
  55. api_key = _get_api_key()
  56. async with httpx.AsyncClient(timeout=30.0) as client:
  57. response = await client.post(
  58. f"{OPENROUTER_BASE_URL}/embeddings",
  59. headers={
  60. "Authorization": f"Bearer {api_key}",
  61. "Content-Type": "application/json",
  62. },
  63. json={
  64. "model": EMBEDDING_MODEL,
  65. "input": texts,
  66. }
  67. )
  68. response.raise_for_status()
  69. data = response.json()
  70. # 按 index 排序(API 可能乱序返回)
  71. embeddings_data = sorted(data["data"], key=lambda x: x["index"])
  72. return [item["embedding"] for item in embeddings_data]