Explorar o código

support streaming

Lengyue hai 1 ano
pai
achega
b6e15cc1ad
Modificouse 1 ficheiro con 15 adicións e 5 borrados
  1. 15 5
      tools/llama/generate.py

+ 15 - 5
tools/llama/generate.py

@@ -456,6 +456,7 @@ def generate_long(
     speaker: Optional[str] = None,
     prompt_text: Optional[str] = None,
     prompt_tokens: Optional[torch.Tensor] = None,
+    is_streaming: bool = False,
 ):
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
@@ -580,13 +581,22 @@ def generate_long(
 
             # But for global encoding, we should keep the <im_end> token
             global_encoded.append(decoded)
-            all_codes.append(codes)
-            seg_idx += 1
 
-        codes = torch.cat(all_codes, dim=1)
-        assert (codes >= 0).all(), f"Negative code found: {codes}"
+            if is_streaming:
+                assert (codes >= 0).all(), f"Negative code found: {codes}"
+                yield codes
+            else:
+                all_codes.append(codes)
 
-        yield codes
+            seg_idx += 1
+
+        if is_streaming:
+            # This indicates the end of the current sample
+            yield None
+        else:
+            all_codes = torch.cat(all_codes, dim=1)
+            assert (all_codes >= 0).all(), f"Negative code found: {codes}"
+            yield all_codes
 
 
 @click.command()