@@ -509,7 +509,7 @@ def main(
     use_prompt = prompt_text is not None and prompt_tokens is not None

     encoded = []
-    texts = split_text(text, 20) if iterative_prompt else [text]
+    texts = split_text(text, 30) if iterative_prompt else [text]
     for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
@@ -561,7 +561,26 @@ def main(
     while seg_idx < len(encoded):
         seg = encoded[seg_idx]
         global_encoded.append(seg)
-        cat_encoded = torch.cat(global_encoded, dim=1)
+
+        lengths = reversed([seg.size(1) for seg in global_encoded])
+        # Pick the last ~2000 tokens of accumulated context
+        count = 0
+        for i, length in enumerate(lengths):
+            count += length
+            if count >= 2000:
+                break
+
+        if i != 0 and i % 2 == 0:
+            i -= 1
+
+        if i < len(global_encoded) - 2:
+            partial_encoded = global_encoded[-i:]
+            print("Using partial encoded context")
+        else:
+            partial_encoded = global_encoded
+            print("Using full encoded context")
+
+        cat_encoded = torch.cat(partial_encoded, dim=1)
         prompt_length = cat_encoded.size(1)
 
         t0 = time.perf_counter()
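
For context: the second hunk replaces the unbounded `torch.cat(global_encoded, dim=1)` with a sliding window over the most recent segments. Below is a minimal standalone sketch of that selection logic, assuming each segment is a `(codebooks, length)` tensor as in the surrounding file; the helper name `pick_context_window` is illustrative, not part of the patch.

```python
import torch


def pick_context_window(segments, max_tokens=2000):
    """Sketch of the patch's trimming logic (name is illustrative).

    Keeps only the most recent segments whose combined length covers
    roughly `max_tokens`; each segment is a (codebooks, length) tensor.
    """
    count = 0
    i = 0
    # Walk segment lengths from newest to oldest until the budget is hit.
    for i, length in enumerate(reversed([s.size(1) for s in segments])):
        count += length
        if count >= max_tokens:
            break

    # Snap to an odd index, mirroring the patch's `i % 2` adjustment
    # (presumably so paired segments stay together).
    if i != 0 and i % 2 == 0:
        i -= 1

    # Trim only when that drops more than the last couple of segments;
    # note this keeps `i` (not `i + 1`) segments, as in the patch, and
    # `segments[-0:]` simply yields the full list when i == 0.
    if i < len(segments) - 2:
        return segments[-i:]
    return segments


# Dummy segments: six chunks of 600 tokens each with 4 codebooks.
segs = [torch.zeros(4, 600, dtype=torch.long) for _ in range(6)]
window = pick_context_window(segs)
print(len(window), sum(s.size(1) for s in window))  # -> 3 1800
```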