Przeglądaj źródła

Update document

Lengyue 2 lat temu
rodzic
commit
9e664473b9
3 zmienionych plików z 11 dodań i 7 usunięć
  1. 5 1
      README.zh.md
  2. BIN
      figs/diagram.png
  3. 6 6
      tools/llama/generate.py

+ 5 - 1
README.zh.md

@@ -2,6 +2,10 @@
 
 此代码库根据 BSD-3-Clause 许可证发布,所有模型根据 CC-BY-NC-SA-4.0 许可证发布。请参阅 [LICENSE](LICENSE) 了解更多细节。
 
+<p align="center">
+<img src="figs/diagram.png" width="75%">
+</p>
+
 ## 免责声明
 我们不对代码库的任何非法使用承担任何责任。请参阅您当地关于DMCA和其他相关法律的法律。
 
@@ -45,7 +49,7 @@ python tools/vqgan/inference.py -i paimon.wav --checkpoint-path checkpoints/vqga
 ```bash
 python tools/llama/generate.py \
     --text "要转换的文本" \
-    --prompt-string "你的参考文本" \
+    --prompt-text "你的参考文本" \
     --prompt-tokens "fake.npy" \
     --checkpoint-path "checkpoints/text2semantic-400m-v0.1-4k.pth" \
     --num-samples 2 \

BIN
figs/diagram.png


+ 6 - 6
tools/llama/generate.py

@@ -260,12 +260,12 @@ def encode_tokens(
     string,
     bos=True,
     device="cuda",
-    prompt_string=None,
+    prompt_text=None,
     prompt_tokens=None,
     use_g2p=False,
 ):
-    if prompt_string is not None:
-        string = prompt_string + " " + string
+    if prompt_text is not None:
+        string = prompt_text + " " + string
 
     if use_g2p:
         prompt = g2p(string)
@@ -353,7 +353,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 
 @click.command()
 @click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
-@click.option("--prompt-string", type=str, default=None)
+@click.option("--prompt-text", type=str, default=None)
 @click.option(
     "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
 )
@@ -375,7 +375,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 @click.option("--seed", type=int, default=42)
 def main(
     text: str,
-    prompt_string: Optional[str],
+    prompt_text: Optional[str],
     prompt_tokens: Optional[Path],
     num_samples: int,
     max_new_tokens: int,
@@ -410,7 +410,7 @@ def main(
     encoded = encode_tokens(
         tokenizer,
         text,
-        prompt_string=prompt_string,
+        prompt_text=prompt_text,
         prompt_tokens=prompt_tokens,
         bos=True,
         device=device,