|
|
@@ -123,10 +123,9 @@ def decode_one_token(
|
|
|
)
|
|
|
|
|
|
codebooks = torch.stack(codebooks, dim=0)
|
|
|
- if codebooks[0] == 2:
|
|
|
- codebooks[1] = 1
|
|
|
- else:
|
|
|
- codebooks[1] = codebooks[0] - 32311 + 2
|
|
|
+ codebooks[1] = torch.where(
|
|
|
+ codebooks[0] <= 32311, codebooks[0], codebooks[0] - 32311 + 2
|
|
|
+ )
|
|
|
|
|
|
return codebooks
|
|
|
|