embedding.py

import io
import urllib.request
from typing import List

import torch
from PIL import Image

from .clip_model import get_model

# Initialize the shared model once at import time.
model, processor, DEVICE, DTYPE, MAX_BATCH = get_model()


def _normalize(x: torch.Tensor) -> torch.Tensor:
    # L2-normalize along the feature dimension; the epsilon guards against division by zero.
    return x / (x.norm(dim=-1, keepdim=True) + 1e-12)


def _to_list(x: torch.Tensor):
    return x.detach().cpu().tolist()


async def embed_image_url(img_url_list: List[str]):
    # Note: urllib.request is blocking; in a fully async service this download
    # loop should be moved to a thread pool (e.g. asyncio.to_thread).
    images = []
    for u in img_url_list:
        with urllib.request.urlopen(u, timeout=15) as r:
            img = Image.open(io.BytesIO(r.read())).convert("RGB")
            images.append(img)

    outputs = []
    # Process images in chunks of MAX_BATCH to bound peak memory usage.
    for chunk_start in range(0, len(images), MAX_BATCH):
        chunk = images[chunk_start:chunk_start + MAX_BATCH]
        # Handle both cases: an AutoProcessor, or a dict of components (fallback).
        if isinstance(processor, dict):
            inputs = processor["image_processor"](images=chunk, return_tensors="pt")
        else:
            inputs = processor(images=chunk, return_tensors="pt")
        # Move tensors to the target device/dtype; leave non-tensor entries untouched.
        inputs = {k: v.to(DEVICE, dtype=DTYPE) if hasattr(v, "to") else v
                  for k, v in inputs.items()}
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            feats = model.get_image_features(**inputs)
        feats = _normalize(feats)
        outputs.extend(_to_list(feats))
    return outputs
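
A minimal caller sketch for the coroutine above. The package name `mypackage` and the image URL are placeholders, not part of the original module; it only assumes `embedding.py` sits inside an importable package (it uses a relative import) and that no event loop is already running:

# usage_example.py (hypothetical caller; package name and URL are placeholders)
import asyncio

from mypackage.embedding import embed_image_url  # assumed package layout


async def main():
    vecs = await embed_image_url(["https://example.com/cat.jpg"])
    # One embedding per input URL; each is a unit-norm list of floats.
    print(len(vecs), len(vecs[0]))


if __name__ == "__main__":
    asyncio.run(main())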
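
The module imports `get_model` from a sibling `clip_model.py` that is not shown. A minimal sketch of what that factory might look like, assuming Hugging Face `transformers` CLIP; the checkpoint name, the fp16 choice, and the batch cap of 16 are illustrative assumptions, and the dict branch mirrors the fallback case `embed_image_url` checks for:

# clip_model.py sketch (assumed implementation; the real module is not shown)
import torch
from transformers import AutoProcessor, CLIPImageProcessor, CLIPModel, CLIPTokenizer

MODEL_NAME = "openai/clip-vit-base-patch32"  # placeholder checkpoint


def get_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = CLIPModel.from_pretrained(MODEL_NAME, torch_dtype=dtype).to(device).eval()
    try:
        processor = AutoProcessor.from_pretrained(MODEL_NAME)
    except Exception:
        # Fallback: assemble the components by hand, matching the dict branch
        # that embedding.py handles via processor["image_processor"].
        processor = {
            "image_processor": CLIPImageProcessor.from_pretrained(MODEL_NAME),
            "tokenizer": CLIPTokenizer.from_pretrained(MODEL_NAME),
        }
    return model, processor, device, dtype, 16  # 16 = assumed MAX_BATCH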