| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- """
- Embedding 生成模块
- 使用 OpenRouter 的 openai/text-embedding-3-small 模型生成向量。
- 支持单条和批量处理。
- """
- import os
- import asyncio
- from typing import List, Union
- import httpx
- OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")  # read once at import time; None when the variable is unset
- OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"  # base URL; the "/embeddings" path is appended when the request is made
- EMBEDDING_MODEL = "openai/text-embedding-3-small"  # sent as the "model" field of the embeddings request
- EMBEDDING_DIM = 1536  # presumably the vector length produced by EMBEDDING_MODEL; not referenced in the visible code - TODO confirm
async def get_embedding(text: str) -> List[float]:
    """
    Generate the embedding vector for a single piece of text.

    Wraps ``text`` in a one-element list, delegates to
    ``get_embeddings_batch`` and returns the first vector of the result.

    Args:
        text: Input text to embed.

    Returns:
        The embedding vector for ``text`` (expected to be 1536-dimensional
        for the configured model; the length is not checked here).
    """
    return (await get_embeddings_batch([text]))[0]
async def get_embeddings_batch(texts: List[str], batch_size: int = 100) -> List[List[float]]:
    """
    Generate embedding vectors for a list of texts, in batches.

    The texts are split into consecutive slices of at most ``batch_size``
    items. Each slice is sent through ``_call_embedding_api`` one after the
    other (sequentially, not concurrently) and the returned vectors are
    concatenated in slice order.

    Args:
        texts: Texts to embed.
        batch_size: Maximum number of texts per API request. Must be at
            least 1. (The upstream OpenAI limit is noted as 2048 inputs
            per request.)

    Returns:
        The vectors returned for every batch, concatenated in input order;
        an empty list when ``texts`` is empty.

    Raises:
        ValueError: If ``texts`` is non-empty and ``batch_size`` is less
            than 1.
    """
    if not texts:
        return []

    # A negative step makes range() yield nothing, which would silently
    # return [] for non-empty input; a zero step fails with an unrelated
    # "range() arg 3 must not be zero". Reject both explicitly.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    all_embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        all_embeddings.extend(await _call_embedding_api(chunk))
    return all_embeddings
async def _call_embedding_api(texts: List[str]) -> List[List[float]]:
    """
    Send one batch of texts to the OpenRouter embeddings endpoint.

    Args:
        texts: The texts of a single batch, sent as the request's "input".

    Returns:
        The "embedding" value of every item in the response's "data" list,
        ordered by each item's "index" field.

    Raises:
        ValueError: If ``OPENROUTER_API_KEY`` is empty or unset.
        httpx.HTTPStatusError: If the response has an error status code
            (raised by ``raise_for_status``).
    """
    if not OPENROUTER_API_KEY:
        raise ValueError("OPENROUTER_API_KEY not set in environment")

    endpoint = f"{OPENROUTER_BASE_URL}/embeddings"
    request_headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": EMBEDDING_MODEL,
        "input": texts,
    }

    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(endpoint, headers=request_headers, json=payload)
        response.raise_for_status()
        items = list(response.json()["data"])

    # The API may return items out of order, so restore input order by index.
    items.sort(key=lambda entry: entry["index"])
    return [entry["embedding"] for entry in items]
|