| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Literal

import torch
from transformers import PreTrainedTokenizerFast

from fish_speech.content_sequence import (
    AudioPart,
    BasePart,
    ContentSequence,
    EncodedMessage,
    TextPart,
    VQPart,
)
from fish_speech.tokenizer import IM_END_TOKEN, IM_START_TOKEN, MODALITY_TOKENS
- @dataclass(kw_only=True)
- class Message:
- role: Literal["system", "user", "assistant"]
- parts: list[BasePart] = field(default_factory=list)
- add_im_start: bool = True
- add_im_end: bool = True
- cal_loss: bool = False
- modality: Literal["text", "voice", "interleave"] | None = None
- # By default, ignore the loss of the auto-generated im_start token
- ignore_im_start_loss: bool = True
@dataclass
class Conversation:
    """An ordered list of ``Message`` objects that can be flattened and encoded.

    The class defines its own ``__init__`` so it can be created without
    arguments; ``@dataclass`` does not replace an explicitly defined
    ``__init__`` and here only contributes the generated ``__repr__`` and
    ``__eq__``.
    """

    messages: list[Message]

    def __init__(self, messages: list[Message] | None = None):
        """Create a conversation, starting empty when no messages are given.

        Args:
            messages: Initial messages. A falsy value (``None`` or an empty
                list) results in a fresh empty list.
        """
        self.messages = messages or []

    def _build_content_sequence(
        self,
        metadata: dict | None = None,
    ) -> ContentSequence:
        """Flatten all messages into a single ``ContentSequence``.

        For each message, in order, this emits an optional im_start text
        part, the message's own parts, and an optional im_end text part.
        Handles ``cal_loss`` inheritance from message to part level.

        Args:
            metadata: Passed through unchanged to the ``ContentSequence``.

        Returns:
            A ``ContentSequence`` built with ``modality=None``.
        """
        all_parts = []

        for message in self.messages:
            if message.add_im_start:
                # A falsy modality contributes no modality token at all.
                modality_token = (
                    MODALITY_TOKENS[message.modality] if message.modality else ""
                )
                all_parts.append(
                    TextPart(
                        text=f"{IM_START_TOKEN}{message.role}\n{modality_token}",
                        cal_loss=not message.ignore_im_start_loss,
                    )
                )

            for part in message.parts:
                # A part whose cal_loss is missing or exactly False inherits
                # the message-level flag. It is deep-copied first so the
                # caller's part object is never mutated.
                if getattr(part, "cal_loss", False) is False:
                    inherited = deepcopy(part)
                    inherited.cal_loss = message.cal_loss
                    all_parts.append(inherited)
                else:
                    # Any other part-level value wins and the part is reused
                    # as-is (no copy).
                    all_parts.append(part)

            if message.add_im_end:
                all_parts.append(
                    TextPart(text=IM_END_TOKEN + "\n", cal_loss=message.cal_loss)
                )

        return ContentSequence(parts=all_parts, modality=None, metadata=metadata)

    def encode(
        self,
        tokenizer: Any,
        add_shift: bool = True,
        ignore_loss_tokens: list[str] | None = None,
        metadata: dict | None = None,
        max_length: int | None = None,
    ) -> EncodedMessage:
        """Build the content sequence and encode it with ``tokenizer``.

        Args:
            tokenizer: Forwarded to ``ContentSequence.encode``.
            add_shift: Forwarded to ``ContentSequence.encode``.
            ignore_loss_tokens: Forwarded to ``ContentSequence.encode``.
                ``None`` (the default) is replaced by a fresh empty list on
                every call, so no list object is shared between calls.
            metadata: Attached to the intermediate ``ContentSequence``.
            max_length: Forwarded to ``ContentSequence.encode``.

        Returns:
            Whatever ``ContentSequence.encode`` returns.
        """
        content_seq = self._build_content_sequence(metadata=metadata)
        return content_seq.encode(
            tokenizer,
            add_shift=add_shift,
            ignore_loss_tokens=[] if ignore_loss_tokens is None else ignore_loss_tokens,
            max_length=max_length,
        )

    def encode_for_inference(
        self,
        tokenizer: Any,
        num_codebooks: int,
        metadata: dict | None = None,
    ):
        """Build the content sequence and encode it for inference.

        Args:
            tokenizer: Forwarded to ``ContentSequence.encode_for_inference``.
            num_codebooks: Forwarded as the ``num_codebooks`` keyword.
            metadata: Attached to the intermediate ``ContentSequence``.

        Returns:
            Whatever ``ContentSequence.encode_for_inference`` returns.
        """
        content_seq = self._build_content_sequence(metadata=metadata)
        return content_seq.encode_for_inference(tokenizer, num_codebooks=num_codebooks)

    def visualize(
        self,
        tokenizer: PreTrainedTokenizerFast,
        ignore_loss_tokens: list[str] | None = None,
        merge_semantic_tokens: bool = False,
        merge_audio_tokens: bool = False,
        use_color: bool = True,
    ):
        """Visualize the encoded sequence via ``ContentSequence.visualize``.

        Blue/cyan tokens contribute to loss, green tokens do not.

        Args:
            tokenizer: Forwarded to ``ContentSequence.visualize``.
            ignore_loss_tokens: Forwarded; ``None`` (the default) is replaced
                by a fresh empty list on every call.
            merge_semantic_tokens: Forwarded to ``ContentSequence.visualize``.
            merge_audio_tokens: Accepted but currently NOT forwarded.
            use_color: Accepted but currently NOT forwarded.
        """
        # NOTE(review): merge_audio_tokens and use_color are not passed on.
        # Confirm ContentSequence.visualize accepts them before wiring them
        # through; forwarding an unsupported keyword would raise TypeError.
        content_seq = self._build_content_sequence()
        content_seq.visualize(
            tokenizer,
            ignore_loss_tokens=[] if ignore_loss_tokens is None else ignore_loss_tokens,
            merge_semantic_tokens=merge_semantic_tokens,
        )

    def append(self, message: Message):
        """Append ``message`` to the end of the conversation."""
        self.messages.append(message)

    def to_content_sequence(
        self,
        metadata: dict | None = None,
    ) -> ContentSequence:
        """Convert the conversation to a ``ContentSequence``.

        Public entry point for ``_build_content_sequence``; it builds a
        ``ContentSequence`` from all messages, handling ``cal_loss``
        inheritance from message to part level.

        Args:
            metadata: Optional metadata to include in the ContentSequence.

        Returns:
            ContentSequence with all messages converted to parts.
        """
        return self._build_content_sequence(metadata=metadata)
if __name__ == "__main__":
    # Manual smoke test: build a two-turn conversation and print its
    # visualization in a few configurations. Requires the debug checkpoint
    # to be present on disk.
    user_message = Message(
        role="user",
        parts=[
            TextPart(text="Hello, how are you?"),
            VQPart(codes=torch.zeros((4, 10))),
        ],
        cal_loss=False,
    )
    assistant_message = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,
    )

    conversation = Conversation([user_message, assistant_message])
    tokenizer = PreTrainedTokenizerFast.from_pretrained("checkpoints/agent-0.6b-debug")

    # Each entry is (heading to print, keyword options for visualize).
    # NOTE(review): Conversation.visualize in this file accepts use_color but
    # does not forward it, so the last run may look identical to the first.
    demo_runs = (
        ("Basic visualization:", {}),
        ("\nWith merged semantic tokens:", {"merge_semantic_tokens": True}),
        ("\nWithout colors:", {"use_color": False}),
    )
    for heading, options in demo_runs:
        print(heading)
        conversation.visualize(tokenizer, **options)
|