@@ -292,13 +292,20 @@ def encode_tokens(
     string = f"[INST] {string} [/INST]"

-    tokens = tokenizer.encode(
-        string,
-        max_length=10**6,
-        add_special_tokens=bos,
-        truncation=False,
-    )
-    tokens = torch.tensor([tokens], dtype=torch.int, device=device)
+    # Handle English less frequent words
+    # TODO: update tokenizer to handle this
+    sub_strings = string.split(" ")
+    new_tokens = []
+    for i, string in enumerate(sub_strings):
+        tokens = tokenizer.encode(
+            string,
+            add_special_tokens=i == 0 and bos,
+            max_length=10**6,
+            truncation=False,
+        )
+        new_tokens.extend(tokens)
+
+    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

     # Codebooks
     zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
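For reference, a minimal, self-contained sketch of what the new encoding path does, assuming a Hugging Face tokenizer (the checkpoint name below is a placeholder, not the project's actual tokenizer, and `bos`/`device` stand in for `encode_tokens` arguments): the prompt is split on spaces and each piece is encoded on its own, while `add_special_tokens=i == 0 and bos` prepends the BOS token only before the first piece.

```python
# Sketch only: mirrors the per-word encoding introduced in the hunk above.
# The tokenizer name is a placeholder; swap in the project's own checkpoint.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
device = "cpu"
bos = True

string = "[INST] Hello sesquipedalian world [/INST]"

sub_strings = string.split(" ")
new_tokens = []
for i, piece in enumerate(sub_strings):
    tokens = tokenizer.encode(
        piece,
        add_special_tokens=i == 0 and bos,  # BOS only once, on the first piece
        max_length=10**6,
        truncation=False,
    )
    new_tokens.extend(tokens)

tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
print(tokens.shape)  # shape (1, N), same layout the codebook rows are stacked onto
```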