# embedding.py — CLIP image embedding helpers
import asyncio
import io
import urllib.request
from typing import List

import torch
from PIL import Image

from .clip_model import get_model
  7. # init model
  8. model, image_processor, tokenizer, DEVICE, DTYPE, MAX_BATCH = get_model()
  9. def _normalize(x: torch.Tensor) -> torch.Tensor:
  10. return x / (x.norm(dim=-1, keepdim=True) + 1e-12)
  11. def _to_list(x: torch.Tensor):
  12. return x.detach().cpu().tolist()
  13. async def embed_image_url(img_url_list: List[str]):
  14. images = []
  15. for u in img_url_list:
  16. with urllib.request.urlopen(u, timeout=15) as r:
  17. img = Image.open(io.BytesIO(r.read())).convert("RGB")
  18. images.append(img)
  19. outputs = []
  20. for chunk_start in range(0, len(images), MAX_BATCH):
  21. chunk = images[chunk_start:chunk_start + MAX_BATCH]
  22. # 只取 pixel_values,不传 text
  23. inputs = image_processor(images=chunk, return_tensors="pt")
  24. pixel_values = inputs["pixel_values"].to(DEVICE, dtype=DTYPE)
  25. with torch.no_grad():
  26. # ✅ 调用图像编码器
  27. image_features = model.encode_image(pixel_values)
  28. feats = _normalize(image_features)
  29. outputs.extend(_to_list(feats))
  30. return outputs