# RecSVTR.py
# SVTR text-recognition backbone in PyTorch (ported from a PaddlePaddle
# implementation; a few comments below note what the Paddle original did).

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional
from torch.nn.init import ones_, trunc_normal_, zeros_


def drop_path(x, drop_prob=0.0, training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    The original name is misleading, as 'Drop Connect' is a different form of
    dropout from a separate paper. See the discussion at
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    """
    if drop_prob is None or drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # Draw one random value per sample and broadcast it over all remaining dims.
    shape = (x.size(0),) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor = torch.floor(random_tensor)  # binarize to 0 or 1
    output = x.div(keep_prob) * random_tensor
    return output
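
# Example: with drop_prob=0.1, each sample's residual branch is zeroed with
# probability 0.1 and the survivors are scaled by 1 / 0.9, so the expected
# value of the output equals the input.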


class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
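
# Note: Swish(x) = x * sigmoid(x) is the same function as PyTorch's built-in
# nn.SiLU; it is kept as an explicit module here.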


class ConvBNLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        bias_attr=False,
        groups=1,
        act=nn.GELU,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            # The Paddle original initializes this weight with KaimingUniform.
            bias=bias_attr,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = act()

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.act(out)
        return out


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        if isinstance(act_layer, str):
            # A string activation name falls back to Swish.
            self.act = Swish()
        else:
            self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvMixer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        HW=(8, 25),
        local_k=(3, 3),
    ):
        super().__init__()
        self.HW = HW
        self.dim = dim
        # Grouped conv over the restored 2D token grid; the Paddle original
        # initializes this weight with KaimingNormal.
        self.local_mixer = nn.Conv2d(
            dim,
            dim,
            local_k,
            1,
            (local_k[0] // 2, local_k[1] // 2),
            groups=num_heads,
        )

    def forward(self, x):
        h = self.HW[0]
        w = self.HW[1]
        # (B, N, C) -> (B, C, H, W): restore the 2D layout for the conv.
        x = x.permute(0, 2, 1).reshape(-1, self.dim, h, w)
        x = self.local_mixer(x)
        # (B, C, H, W) -> (B, N, C)
        x = x.flatten(2).permute(0, 2, 1)
        return x
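
# Shape check: ConvMixer(dim=64, num_heads=2, HW=(8, 25)) maps a (B, 200, 64)
# token sequence onto an 8x25 grid, mixes it with a grouped 3x3 conv, and
# returns the same (B, 200, 64) shape.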


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        mixer="Global",
        HW=(8, 25),
        local_k=(7, 11),
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.HW = HW
        if HW is not None:
            H = HW[0]
            W = HW[1]
            self.N = H * W
            self.C = dim
        if mixer == "Local" and HW is not None:
            hk = local_k[0]
            wk = local_k[1]
            # Build an additive attention mask: 0 inside each token's hk x wk
            # neighbourhood, -inf outside, so softmax restricts attention locally.
            mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
            for h in range(0, H):
                for w in range(0, W):
                    mask[h * W + w, h : h + hk, w : w + wk] = 0.0
            mask_flatten = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
            mask_inf = torch.full([H * W, H * W], fill_value=float("-inf"))
            mask = torch.where(mask_flatten < 1, mask_flatten, mask_inf)
            # Non-persistent buffer: the mask follows the module across devices
            # without appearing in the state_dict.
            self.register_buffer("mask", mask[None, None, :, :], persistent=False)
        self.mixer = mixer

    def forward(self, x):
        if self.HW is not None:
            N = self.N
            C = self.C
        else:
            _, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape((-1, N, 3, self.num_heads, C // self.num_heads))
            .permute((2, 0, 3, 1, 4))
        )
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q.matmul(k.permute((0, 1, 3, 2)))
        if self.mixer == "Local":
            attn += self.mask
        attn = functional.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
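
# With dim=64, num_heads=2, HW=(8, 25): q, k and v each have shape
# (B, 2, 200, 32) and the attention map is (B, 2, 200, 200). In "Local" mode
# the -inf mask suppresses (after softmax) every pair of tokens farther apart
# than the 7x11 window.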


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mixer="Global",
        local_mixer=(7, 11),
        HW=(8, 25),
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="nn.LayerNorm",
        epsilon=1e-6,
        prenorm=True,
    ):
        super().__init__()
        if isinstance(norm_layer, str):
            self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        else:
            self.norm1 = norm_layer(dim)
        if mixer == "Global" or mixer == "Local":
            self.mixer = Attention(
                dim,
                num_heads=num_heads,
                mixer=mixer,
                HW=HW,
                local_k=local_mixer,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )
        elif mixer == "Conv":
            self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
        else:
            raise TypeError("The mixer must be one of [Global, Local, Conv]")
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        if isinstance(norm_layer, str):
            self.norm2 = eval(norm_layer)(dim, eps=epsilon)
        else:
            self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.prenorm = prenorm

    def forward(self, x):
        if self.prenorm:
            x = self.norm1(x + self.drop_path(self.mixer(x)))
            x = self.norm2(x + self.drop_path(self.mlp(x)))
        else:
            x = x + self.drop_path(self.mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
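
# Naming caveat: with prenorm=True the LayerNorm runs *after* the residual add
# (post-norm in the usual transformer terminology); prenorm=False gives the
# conventional pre-norm ordering. The flag is kept as-is to match the source.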


class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
        super().__init__()
        num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
        self.img_size = img_size
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.norm = None
        if sub_num == 2:
            # Two stride-2 stages: 4x spatial downsampling.
            self.proj = nn.Sequential(
                ConvBNLayer(
                    in_channels=in_channels,
                    out_channels=embed_dim // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    act=nn.GELU,
                    bias_attr=False,
                ),
                ConvBNLayer(
                    in_channels=embed_dim // 2,
                    out_channels=embed_dim,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    act=nn.GELU,
                    bias_attr=False,
                ),
            )
        if sub_num == 3:
            # Three stride-2 stages: 8x spatial downsampling.
            self.proj = nn.Sequential(
                ConvBNLayer(
                    in_channels=in_channels,
                    out_channels=embed_dim // 4,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    act=nn.GELU,
                    bias_attr=False,
                ),
                ConvBNLayer(
                    in_channels=embed_dim // 4,
                    out_channels=embed_dim // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    act=nn.GELU,
                    bias_attr=False,
                ),
                ConvBNLayer(
                    in_channels=embed_dim // 2,
                    out_channels=embed_dim,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    act=nn.GELU,
                    bias_attr=False,
                ),
            )

    def forward(self, x):
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).permute(0, 2, 1)
        return x
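
# Shape check with the defaults: a (B, 3, 32, 100) image becomes a
# (B, 768, 8, 25) feature map, and flatten(2) + permute yields (B, 200, 768)
# patch tokens.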


class SubSample(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        types="Pool",
        stride=(2, 1),
        sub_norm="nn.LayerNorm",
        act=None,
    ):
        super().__init__()
        self.types = types
        if types == "Pool":
            self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
            self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
            self.proj = nn.Linear(in_channels, out_channels)
        else:
            # The Paddle original initializes this weight with KaimingNormal.
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
            )
        self.norm = eval(sub_norm)(out_channels)
        if act is not None:
            self.act = act()
        else:
            self.act = None

    def forward(self, x):
        if self.types == "Pool":
            # Blend average and max pooling, then project channels linearly.
            x1 = self.avgpool(x)
            x2 = self.maxpool(x)
            x = (x1 + x2) * 0.5
            out = self.proj(x.flatten(2).permute((0, 2, 1)))
        else:
            x = self.conv(x)
            out = x.flatten(2).permute((0, 2, 1))
        out = self.norm(out)
        if self.act is not None:
            out = self.act(out)
        return out
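
# With the default stride (2, 1), SubSample halves the height of the token grid
# and keeps its width: e.g. (B, C_in, 8, 25) -> (B, 4 * 25, C_out).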


class SVTRNet(nn.Module):
    def __init__(
        self,
        img_size=[48, 100],
        in_channels=3,
        embed_dim=[64, 128, 256],
        depth=[3, 6, 3],
        num_heads=[2, 4, 8],
        mixer=["Local"] * 6 + ["Global"] * 6,  # Local attn, Global attn, Conv
        local_mixer=[[7, 11], [7, 11], [7, 11]],
        patch_merging="Conv",  # Conv, Pool, None
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        last_drop=0.1,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer="nn.LayerNorm",
        sub_norm="nn.LayerNorm",
        epsilon=1e-6,
        out_channels=192,
        out_char_num=25,
        block_unit="Block",
        act="nn.GELU",
        last_stage=True,
        sub_num=2,
        prenorm=True,
        use_lenhead=False,
        **kwargs,
    ):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = out_channels
        self.prenorm = prenorm
        patch_merging = (
            None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging
        )
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            in_channels=in_channels,
            embed_dim=embed_dim[0],
            sub_num=sub_num,
        )
        num_patches = self.patch_embed.num_patches
        self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)
        Block_unit = eval(block_unit)
        # Stochastic-depth rate grows linearly from 0 to drop_path_rate over all blocks.
        dpr = np.linspace(0, drop_path_rate, sum(depth))
        self.blocks1 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[0],
                    num_heads=num_heads[0],
                    mixer=mixer[0 : depth[0]][i],
                    HW=self.HW,
                    local_mixer=local_mixer[0],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=eval(act),
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[0 : depth[0]][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[0])
            ]
        )
        if patch_merging is not None:
            self.sub_sample1 = SubSample(
                embed_dim[0],
                embed_dim[1],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging,
            )
            HW = [self.HW[0] // 2, self.HW[1]]
        else:
            HW = self.HW
        self.patch_merging = patch_merging
        self.blocks2 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[1],
                    num_heads=num_heads[1],
                    mixer=mixer[depth[0] : depth[0] + depth[1]][i],
                    HW=HW,
                    local_mixer=local_mixer[1],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=eval(act),
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[1])
            ]
        )
        if patch_merging is not None:
            self.sub_sample2 = SubSample(
                embed_dim[1],
                embed_dim[2],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging,
            )
            HW = [self.HW[0] // 4, self.HW[1]]
        else:
            HW = self.HW
        self.blocks3 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[2],
                    num_heads=num_heads[2],
                    mixer=mixer[depth[0] + depth[1] :][i],
                    HW=HW,
                    local_mixer=local_mixer[2],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=eval(act),
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[depth[0] + depth[1] :][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[2])
            ]
        )
        self.last_stage = last_stage
        if last_stage:
            # Pool the final feature map to a 1 x out_char_num sequence.
            self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
            self.last_conv = nn.Conv2d(
                in_channels=embed_dim[2],
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.hardswish = nn.Hardswish()
            self.dropout = nn.Dropout(p=last_drop)
        if not prenorm:
            self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)
        self.use_lenhead = use_lenhead
        if use_lenhead:
            self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
            self.hardswish_len = nn.Hardswish()
            self.dropout_len = nn.Dropout(p=last_drop)
        trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample1(
                x.permute(0, 2, 1).reshape(-1, self.embed_dim[0], self.HW[0], self.HW[1])
            )
        for blk in self.blocks2:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample2(
                x.permute(0, 2, 1).reshape(-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1])
            )
        for blk in self.blocks3:
            x = blk(x)
        if not self.prenorm:
            x = self.norm(x)
        return x
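
    # Token-grid shapes through the three stages with the defaults
    # (img_size=[48, 100], sub_num=2, patch_merging="Conv"):
    # (B, 12*25, 64) -> (B, 6*25, 128) -> (B, 3*25, 256).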

    def forward(self, x):
        x = self.forward_features(x)
        if self.use_lenhead:
            len_x = self.len_conv(x.mean(1))
            len_x = self.dropout_len(self.hardswish_len(len_x))
        if self.last_stage:
            if self.patch_merging is not None:
                h = self.HW[0] // 4
            else:
                h = self.HW[0]
            x = self.avg_pool(
                x.permute(0, 2, 1).reshape(-1, self.embed_dim[2], h, self.HW[1])
            )
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
        if self.use_lenhead:
            return x, len_x
        return x


if __name__ == "__main__":
    a = torch.rand(1, 3, 48, 100)
    svtr = SVTRNet()
    out = svtr(a)
    print(svtr)
    print(out.size())
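    # With the default config this should print torch.Size([1, 192, 1, 25]):
    # out_channels=192 feature channels over out_char_num=25 character positions.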