# RecCTCHead.py: CTC recognition head built from torch.nn linear layers.
import torch
from torch import nn
  2. class CTCHead(nn.Module):
  3. def __init__(
  4. self,
  5. in_channels,
  6. out_channels=6625,
  7. fc_decay=0.0004,
  8. mid_channels=None,
  9. return_feats=False,
  10. **kwargs,
  11. ):
  12. super(CTCHead, self).__init__()
  13. if mid_channels is None:
  14. self.fc = nn.Linear(
  15. in_channels,
  16. out_channels,
  17. bias=True,
  18. )
  19. else:
  20. self.fc1 = nn.Linear(
  21. in_channels,
  22. mid_channels,
  23. bias=True,
  24. )
  25. self.fc2 = nn.Linear(
  26. mid_channels,
  27. out_channels,
  28. bias=True,
  29. )
  30. self.out_channels = out_channels
  31. self.mid_channels = mid_channels
  32. self.return_feats = return_feats
  33. def forward(self, x, labels=None):
  34. if self.mid_channels is None:
  35. predicts = self.fc(x)
  36. else:
  37. x = self.fc1(x)
  38. predicts = self.fc2(x)
  39. if self.return_feats:
  40. result = dict()
  41. result["ctc"] = predicts
  42. result["ctc_neck"] = x
  43. else:
  44. result = predicts
  45. return result