# embedding.py
  1. import io
  2. import urllib.request
  3. from typing import List
  4. import torch
  5. from PIL import Image
  6. from .clip_model import get_model
# Initialize the shared CLIP components once at import time; all functions in
# this module read these globals. NOTE(review): get_model() is project-local —
# presumably MAX_BATCH caps images per forward pass and DTYPE/DEVICE describe
# where the model lives; confirm against clip_model.get_model().
model, image_processor, tokenizer, DEVICE, DTYPE, MAX_BATCH = get_model()
  9. def _normalize(x: torch.Tensor) -> torch.Tensor:
  10. return x / (x.norm(dim=-1, keepdim=True) + 1e-12)
  11. def _to_list(x: torch.Tensor):
  12. return x.detach().cpu().tolist()
  13. async def embed_image_url(img_url_list: List[str]):
  14. images = []
  15. for u in img_url_list:
  16. with urllib.request.urlopen(u, timeout=15) as r:
  17. img = Image.open(io.BytesIO(r.read())).convert("RGB")
  18. images.append(img)
  19. outputs = []
  20. for chunk_start in range(0, len(images), MAX_BATCH):
  21. chunk = images[chunk_start:chunk_start + MAX_BATCH]
  22. # ✅ 用 image_processor,不再用混合 processor
  23. inputs = image_processor(images=chunk, return_tensors="pt")
  24. inputs = {k: v.to(DEVICE, dtype=DTYPE) if hasattr(v, "to") else v for k, v in inputs.items()}
  25. outputs = model(**inputs)
  26. # 某些实现是 outputs.last_hidden_state,某些是 outputs.image_embeds
  27. feats = outputs.image_embeds if hasattr(outputs, "image_embeds") else outputs.last_hidden_state
  28. feats = _normalize(feats)
  29. outputs.extend(_to_list(feats))
  30. return outputs