embeddings.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. """
  2. Embedding 生成模块
  3. 使用 OpenRouter 的 openai/text-embedding-3-small 模型生成向量。
  4. 支持单条和批量处理。
  5. """
  6. import os
  7. import asyncio
  8. from typing import List, Union
  9. import httpx
  10. OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
  11. OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
  12. EMBEDDING_MODEL = "openai/text-embedding-3-small"
  13. EMBEDDING_DIM = 1536
  14. async def get_embedding(text: str) -> List[float]:
  15. """
  16. 生成单条文本的向量
  17. Args:
  18. text: 输入文本
  19. Returns:
  20. 1536 维向量
  21. """
  22. embeddings = await get_embeddings_batch([text])
  23. return embeddings[0]
  24. async def get_embeddings_batch(texts: List[str], batch_size: int = 100) -> List[List[float]]:
  25. """
  26. 批量生成文本向量
  27. Args:
  28. texts: 文本列表
  29. batch_size: 每批处理数量(OpenAI 限制 2048)
  30. Returns:
  31. 向量列表
  32. """
  33. if not texts:
  34. return []
  35. # 分批处理
  36. all_embeddings = []
  37. for i in range(0, len(texts), batch_size):
  38. batch = texts[i:i + batch_size]
  39. embeddings = await _call_embedding_api(batch)
  40. all_embeddings.extend(embeddings)
  41. return all_embeddings
  42. async def _call_embedding_api(texts: List[str]) -> List[List[float]]:
  43. """
  44. 调用 OpenRouter embedding API
  45. Args:
  46. texts: 文本列表(单批)
  47. Returns:
  48. 向量列表
  49. """
  50. if not OPENROUTER_API_KEY:
  51. raise ValueError("OPENROUTER_API_KEY not set in environment")
  52. async with httpx.AsyncClient(timeout=30.0) as client:
  53. response = await client.post(
  54. f"{OPENROUTER_BASE_URL}/embeddings",
  55. headers={
  56. "Authorization": f"Bearer {OPENROUTER_API_KEY}",
  57. "Content-Type": "application/json",
  58. },
  59. json={
  60. "model": EMBEDDING_MODEL,
  61. "input": texts,
  62. }
  63. )
  64. response.raise_for_status()
  65. data = response.json()
  66. # 按 index 排序(API 可能乱序返回)
  67. embeddings_data = sorted(data["data"], key=lambda x: x["index"])
  68. return [item["embedding"] for item in embeddings_data]