|
|
@@ -470,16 +470,14 @@ def generate_long(
|
|
|
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
|
|
|
|
|
if use_prompt:
|
|
|
- encoded.append(
|
|
|
- encode_tokens(
|
|
|
- tokenizer,
|
|
|
- prompt_text,
|
|
|
- prompt_tokens=prompt_tokens,
|
|
|
- bos=True,
|
|
|
- device=device,
|
|
|
- speaker=speaker,
|
|
|
- num_codebooks=model.config.num_codebooks,
|
|
|
- )
|
|
|
+ encoded_prompts = encode_tokens(
|
|
|
+ tokenizer,
|
|
|
+ prompt_text,
|
|
|
+ prompt_tokens=prompt_tokens,
|
|
|
+ bos=True,
|
|
|
+ device=device,
|
|
|
+ speaker=speaker,
|
|
|
+ num_codebooks=model.config.num_codebooks,
|
|
|
)
|
|
|
|
|
|
for idx, text in enumerate(texts):
|
|
|
@@ -501,10 +499,6 @@ def generate_long(
|
|
|
all_codes = []
|
|
|
seg_idx = 0
|
|
|
|
|
|
- if use_prompt:
|
|
|
- seg_idx = 1
|
|
|
- global_encoded.append(encoded[0])
|
|
|
-
|
|
|
while seg_idx < len(encoded):
|
|
|
logger.info(
|
|
|
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
|
|
|
@@ -531,6 +525,9 @@ def generate_long(
|
|
|
else:
|
|
|
partial_encoded = global_encoded
|
|
|
|
|
|
+ if use_prompt:
|
|
|
+ partial_encoded = [encoded_prompts] + partial_encoded
|
|
|
+
|
|
|
cat_encoded = torch.cat(partial_encoded, dim=1)
|
|
|
prompt_length = cat_encoded.size(1)
|
|
|
|
|
|
@@ -593,7 +590,7 @@ def generate_long(
|
|
|
|
|
|
if is_streaming:
|
|
|
# This indicates the end of the current sample
|
|
|
- yield None
|
|
|
+ yield "next"
|
|
|
else:
|
|
|
all_codes = torch.cat(all_codes, dim=1)
|
|
|
assert (all_codes >= 0).all(), f"Negative code found: {codes}"
|
|
|
@@ -623,20 +620,21 @@ def launch_thread_safe_queue(
|
|
|
break
|
|
|
|
|
|
kwargs = item["request"]
|
|
|
- event = item["event"]
|
|
|
+ response_queue = item["response_queue"]
|
|
|
|
|
|
try:
|
|
|
item["success"] = True
|
|
|
- item["response"] = list(
|
|
|
- generate_long(
|
|
|
- model=model, decode_one_token=decode_one_token, **kwargs
|
|
|
- )
|
|
|
- )
|
|
|
+ for chunk in generate_long(
|
|
|
+ model=model, decode_one_token=decode_one_token, **kwargs
|
|
|
+ ):
|
|
|
+ response_queue.put(chunk)
|
|
|
+
|
|
|
+ response_queue.put("done")
|
|
|
except Exception as e:
|
|
|
item["success"] = False
|
|
|
item["response"] = e
|
|
|
|
|
|
- event.set()
|
|
|
+ response_queue.put("done")
|
|
|
|
|
|
threading.Thread(target=worker, daemon=True).start()
|
|
|
init_event.wait()
|