__init__.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. from copy import deepcopy
  2. import torch
  3. from ..utils import load_file_from_url
  4. from .retinaface import RetinaFace
  5. def init_detection_model(model_name, half=False, device="cuda", model_rootpath=None):
  6. if model_name == "retinaface_resnet50":
  7. model = RetinaFace(network_name="resnet50", half=half, device=device)
  8. model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth"
  9. elif model_name == "retinaface_mobile0.25":
  10. model = RetinaFace(network_name="mobile0.25", half=half, device=device)
  11. model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth"
  12. else:
  13. raise NotImplementedError(f"{model_name} is not implemented.")
  14. model_path = load_file_from_url(
  15. url=model_url,
  16. model_dir="facexlib/weights",
  17. progress=True,
  18. file_name=None,
  19. save_dir=model_rootpath,
  20. )
  21. # TODO: clean pretrained model
  22. load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
  23. # remove unnecessary 'module.'
  24. for k, v in deepcopy(load_net).items():
  25. if k.startswith("module."):
  26. load_net[k[7:]] = v
  27. load_net.pop(k)
  28. model.load_state_dict(load_net, strict=True)
  29. model.eval()
  30. model = model.to(device)
  31. return model