@@ -122,7 +122,13 @@ def decode_one_token(
)[0]
)
- return torch.stack(codebooks, dim=0)
+ codebooks = torch.stack(codebooks, dim=0)
+ if codebooks[0] == 2:
+ codebooks[1] = 1
+ else:
+ codebooks[1] = codebooks[0] - 32311 + 2
+
+ return codebooks
def prefill(