RecModel.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from torch import nn
  2. from .RecCTCHead import CTCHead
  3. from .RecMv1_enhance import MobileNetV1Enhance
  4. from .RNN import Im2Im, Im2Seq, SequenceEncoder
  5. backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
  6. neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
  7. head_dict = {"CTCHead": CTCHead}
  8. class RecModel(nn.Module):
  9. def __init__(self, config):
  10. super().__init__()
  11. assert "in_channels" in config, "in_channels must in model config"
  12. backbone_type = config.backbone.pop("type")
  13. assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
  14. self.backbone = backbone_dict[backbone_type](
  15. config.in_channels, **config.backbone
  16. )
  17. neck_type = config.neck.pop("type")
  18. assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
  19. self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
  20. head_type = config.head.pop("type")
  21. assert head_type in head_dict, f"head.type must in {head_dict}"
  22. self.head = head_dict[head_type](self.neck.out_channels, **config.head)
  23. self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
  24. def load_3rd_state_dict(self, _3rd_name, _state):
  25. self.backbone.load_3rd_state_dict(_3rd_name, _state)
  26. self.neck.load_3rd_state_dict(_3rd_name, _state)
  27. self.head.load_3rd_state_dict(_3rd_name, _state)
  28. def forward(self, x):
  29. x = self.backbone(x)
  30. x = self.neck(x)
  31. x = self.head(x)
  32. return x
  33. def encode(self, x):
  34. x = self.backbone(x)
  35. x = self.neck(x)
  36. x = self.head.ctc_encoder(x)
  37. return x