@@ -339,7 +339,7 @@ def generate_long(
     temperature: float = 0.8,
     compile: bool = False,
     iterative_prompt: bool = True,
-    chunk_length: int = 150,
+    chunk_length: int = 512,
     prompt_text: Optional[str | list[str]] = None,
     prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
 ):
@@ -365,6 +365,24 @@ def generate_long(
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
     max_length = model.config.max_seq_len
 
+    # if use_prompt:
+    #     base_content_sequence.append(
+    #         [
+    #             TextPart(text=prompt_text[0]),
+    #             VQPart(codes=prompt_tokens[0]),
+    #         ],
+    #         add_end=True,
+    #     )
+
+    # for text in texts:
+    #     content_sequence = ContentSequence(modality=None)
+    #     base_content_sequence.append(
+    #         [
+    #             TextPart(text=text),
+    #         ],
+    #         add_end=True,
+    #     )
+
     if use_prompt:
         for t, c in zip(prompt_text, prompt_tokens):
             base_content_sequence.append(
@@ -385,7 +403,7 @@ def generate_long(
 
     encoded = []
     for text in texts:
-        content_sequence = ContentSequence(modality=None)
+        content_sequence = ContentSequence(modality="text")
         content_sequence.append(TextPart(text=text))
         encoded.append(
             content_sequence.encode_for_inference(
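
For context, a minimal sketch of what raising the default chunk_length from 150 to 512 means for iterative prompting. split_text_sketch below is a hypothetical stand-in for split_text (the real implementation may split on sentence or punctuation boundaries rather than raw character counts); it only illustrates that the new default produces fewer, larger chunks per generate_long call.

def split_text_sketch(text: str, chunk_length: int) -> list[str]:
    # Greedily pack whitespace-separated words into chunks of at most
    # chunk_length characters (hypothetical; not the project's split_text).
    chunks, current = [], ""
    for word in text.split():
        candidate = f"{current} {word}".strip()
        if len(candidate) > chunk_length and current:
            chunks.append(current)
            current = word
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks

long_text = "word " * 400  # roughly 2000 characters of input
print(len(split_text_sketch(long_text, 150)))  # old default: 14 chunks
print(len(split_text_sketch(long_text, 512)))  # new default: 4 chunks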