|
|
@@ -1026,6 +1026,7 @@ def launch_thread_safe_queue_agent(
|
|
|
@click.option("--half/--no-half", default=False)
|
|
|
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
|
|
|
@click.option("--chunk-length", type=int, default=100)
|
|
|
+@click.option("--output-dir", type=Path, default="temp")
|
|
|
def main(
|
|
|
text: str,
|
|
|
prompt_text: Optional[list[str]],
|
|
|
@@ -1042,8 +1043,9 @@ def main(
|
|
|
half: bool,
|
|
|
iterative_prompt: bool,
|
|
|
chunk_length: int,
|
|
|
+ output_dir: Path,
|
|
|
) -> None:
|
|
|
-
|
|
|
+ os.makedirs(output_dir, exist_ok=True)
|
|
|
precision = torch.half if half else torch.bfloat16
|
|
|
|
|
|
if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
|
|
|
@@ -1101,8 +1103,9 @@ def main(
|
|
|
logger.info(f"Sampled text: {response.text}")
|
|
|
elif response.action == "next":
|
|
|
if codes:
|
|
|
- np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
|
|
|
- logger.info(f"Saved codes to codes_{idx}.npy")
|
|
|
+ codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
|
|
|
+ np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
|
|
|
+ logger.info(f"Saved codes to {codes_npy_path}")
|
|
|
logger.info(f"Next sample")
|
|
|
codes = []
|
|
|
idx += 1
|