@@ -3,8 +3,9 @@ import queue
 import string
 import threading
 import time
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union

 import click
 import hydra
@@ -439,6 +440,13 @@ def split_text(text, min_length):
     return segments


+@dataclass
+class GenerateResponse:
+    action: Literal["sample", "next"]
+    codes: Optional[torch.Tensor] = None
+    text: Optional[str] = None
+
+
 def generate_long(
     *,
     model,
@@ -458,7 +466,6 @@ def generate_long(
     speaker: Optional[str] = None,
     prompt_text: Optional[str] = None,
     prompt_tokens: Optional[torch.Tensor] = None,
-    is_streaming: bool = False,
 ):
     assert 0 < top_p <= 1, "top_p must be in (0, 1]"
     assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
@@ -508,7 +515,6 @@ def generate_long(
         torch.cuda.synchronize()

     global_encoded = []
-    all_codes = []
     seg_idx = 0

     while seg_idx < len(encoded):
@@ -594,22 +600,24 @@ def generate_long(

             # But for global encoding, we should keep the <im_end> token
             global_encoded.append(decoded)

-            if is_streaming:
-                assert (codes >= 0).all(), f"Negative code found: {codes}"
-                yield codes
-            else:
-                all_codes.append(codes)
-
-            seg_idx += 1
+            assert (codes >= 0).all(), f"Negative code found: {codes}"
+            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
+            seg_idx += 1

-    if is_streaming:
-        # This indicates the end of the current sample
-        yield "next"
-    else:
-        all_codes = torch.cat(all_codes, dim=1)
-        assert (all_codes >= 0).all(), f"Negative code found: {codes}"
-        yield all_codes
+    # This indicates the end of the current sample
+    yield GenerateResponse(action="next")
+
+
+@dataclass
+class WrappedGenerateResponse:
+    status: Literal["success", "error"]
+    response: Optional[GenerateResponse | Exception] = None
+
+
+@dataclass
+class GenerateRequest:
+    request: dict
+    response_queue: queue.Queue


 def launch_thread_safe_queue(
@@ -617,8 +625,8 @@ def launch_thread_safe_queue(
     checkpoint_path,
     device,
     precision,
-    max_length,
-    compile=False,
+    max_length: int,
+    compile: bool = False,
 ):
     input_queue = queue.Queue()
     init_event = threading.Event()
@@ -630,26 +638,22 @@ def launch_thread_safe_queue(
         init_event.set()

         while True:
-            item = input_queue.get()
+            item: GenerateRequest | None = input_queue.get()
             if item is None:
                 break

-            kwargs = item["request"]
-            response_queue = item["response_queue"]
+            kwargs = item.request
+            response_queue = item.response_queue

             try:
-                item["success"] = True
                 for chunk in generate_long(
                     model=model, decode_one_token=decode_one_token, **kwargs
                 ):
-                    response_queue.put(chunk)
-
-                response_queue.put("done")
+                    response_queue.put(
+                        WrappedGenerateResponse(status="success", response=chunk)
+                    )
             except Exception as e:
-                item["success"] = False
-                item["response"] = e
-
-                response_queue.put("done")
+                response_queue.put(WrappedGenerateResponse(status="error", response=e))

     threading.Thread(target=worker, daemon=True).start()
     init_event.wait()
@@ -753,9 +757,21 @@ def main(
         prompt_tokens=prompt_tokens,
     )

-    for idx, codes in enumerate(generator):
-        np.save(f"codes_{idx}.npy", codes.cpu().numpy())
-        logger.info(f"Saved codes to codes_{idx}.npy")
+    idx = 0
+    codes = []
+
+    for response in generator:
+        if response.action == "sample":
+            codes.append(response.codes)
+            logger.info(f"Sampled text: {response.text}")
+        elif response.action == "next":
+            if codes:
+                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
+                logger.info(f"Saved codes to codes_{idx}.npy")
+            logger.info(f"Next sample")
+            idx += 1
+        else:
+            logger.error(f"Error: {response}")


 if __name__ == "__main__":
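
Consumer-side sketch (not part of the diff above): one way a caller might drive the new thread-safe queue interface. The assumptions here are mine, not the patch's: that launch_thread_safe_queue() returns its input queue, that "text" and "num_samples" are valid generate_long kwargs to pass in the request dict, and that a request is complete after num_samples "next" responses, since the old "done" sentinel is gone.

import queue

import torch

# launch_thread_safe_queue, GenerateRequest, WrappedGenerateResponse and
# GenerateResponse come from the module patched above.

# Hypothetical checkpoint path; adjust to the actual layout.
llama_queue = launch_thread_safe_queue(
    checkpoint_path="checkpoints/text2semantic",
    device="cuda",
    precision=torch.bfloat16,
    max_length=2048,
    compile=False,
)

response_queue: queue.Queue = queue.Queue()
num_samples = 2
llama_queue.put(
    GenerateRequest(
        request=dict(text="Hello world", num_samples=num_samples),
        response_queue=response_queue,
    )
)

finished = 0
codes = []
while finished < num_samples:
    wrapped: WrappedGenerateResponse = response_queue.get()
    if wrapped.status == "error":
        raise wrapped.response  # the worker forwards the original exception
    resp: GenerateResponse = wrapped.response
    if resp.action == "sample":
        codes.append(resp.codes)  # per-segment semantic codes
    elif resp.action == "next":
        finished += 1  # one "next" marker per finished sample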