# embedding.py
import asyncio
import io
import urllib.request
from typing import List

import torch
from PIL import Image

from .clip_model import get_model
  7. # init model
  8. model, processor, DEVICE, DTYPE, MAX_BATCH = get_model()
  9. def _normalize(x: torch.Tensor) -> torch.Tensor:
  10. return x / (x.norm(dim=-1, keepdim=True) + 1e-12)
  11. def _to_list(x: torch.Tensor):
  12. return x.detach().cpu().tolist()
  13. async def embed_image_url(img_url_list: List[str]):
  14. images = []
  15. for u in img_url_list:
  16. with urllib.request.urlopen(u, timeout=15) as r:
  17. img = Image.open(io.BytesIO(r.read())).convert("RGB")
  18. images.append(img)
  19. outputs = []
  20. for chunk_start in range(0, len(images), MAX_BATCH):
  21. chunk = images[chunk_start:chunk_start + MAX_BATCH]
  22. inputs = processor(images=chunk, return_tensors="pt")
  23. inputs = {k: v.to(DEVICE, dtype=DTYPE) if hasattr(v, "to") else v for k, v in inputs.items()}
  24. feats = model.get_image_features(**inputs)
  25. feats = _normalize(feats)
  26. outputs.extend(_to_list(feats))
  27. return outputs