"""Conversation and Message containers for building fish_speech content sequences."""
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Literal

import torch
from transformers import PreTrainedTokenizerFast

from fish_speech.content_sequence import (
    AudioPart,
    BasePart,
    ContentSequence,
    EncodedMessage,
    TextPart,
    VQPart,
)
from fish_speech.tokenizer import IM_END_TOKEN, IM_START_TOKEN, MODALITY_TOKENS
@dataclass(kw_only=True)
class Message:
    """A single chat turn: a role plus an ordered list of content parts."""

    # Speaker of this turn.
    role: Literal["system", "user", "assistant"]
    # Ordered content parts (text, VQ codes, ...) that make up the turn.
    parts: list[BasePart] = field(default_factory=list)
    # Emit the im_start header (role line) before the parts.
    add_im_start: bool = True
    # Emit the im_end token after the parts.
    add_im_end: bool = True
    # Whether this message's tokens contribute to the training loss;
    # inherited by parts that do not set their own cal_loss.
    cal_loss: bool = False
    # Optional modality tag inserted after the role header
    # (must be a key of MODALITY_TOKENS when set).
    modality: Literal["text", "voice", "interleave"] | None = None
    # By default, ignore the loss of the auto-generated im_start token
    ignore_im_start_loss: bool = True
  25. @dataclass
  26. class Conversation:
  27. messages: list[Message]
  28. def __init__(self: "Conversation", messages: list[Message] | None = None):
  29. self.messages = messages or []
  30. def _build_content_sequence(
  31. self: "Conversation",
  32. metadata: dict | None = None,
  33. ) -> ContentSequence:
  34. """
  35. Build a ContentSequence from all messages.
  36. Handles cal_loss inheritance from message to part level.
  37. """
  38. all_parts = []
  39. for message in self.messages:
  40. # Add im_start
  41. if message.add_im_start:
  42. modality_token = (
  43. MODALITY_TOKENS[message.modality] if message.modality else ""
  44. )
  45. all_parts.append(
  46. TextPart(
  47. text=f"{IM_START_TOKEN}{message.role}\n{modality_token}",
  48. cal_loss=not message.ignore_im_start_loss,
  49. )
  50. )
  51. # Add message parts
  52. for part in message.parts:
  53. # Inherit cal_loss from message if not set at part level
  54. if not hasattr(part, "cal_loss") or part.cal_loss is False:
  55. new_part = deepcopy(part)
  56. new_part.cal_loss = message.cal_loss
  57. all_parts.append(new_part)
  58. else:
  59. all_parts.append(part)
  60. # Add im_end
  61. if message.add_im_end:
  62. all_parts.append(
  63. TextPart(text=IM_END_TOKEN + "\n", cal_loss=message.cal_loss)
  64. )
  65. return ContentSequence(parts=all_parts, modality=None, metadata=metadata)
  66. def encode(
  67. self: "Conversation",
  68. tokenizer: any,
  69. add_shift: bool = True,
  70. ignore_loss_tokens: list[str] = [],
  71. metadata: dict | None = None,
  72. max_length: int | None = None,
  73. ) -> EncodedMessage:
  74. # Build ContentSequence from messages
  75. content_seq = self._build_content_sequence(metadata=metadata)
  76. return content_seq.encode(
  77. tokenizer,
  78. add_shift=add_shift,
  79. ignore_loss_tokens=ignore_loss_tokens,
  80. max_length=max_length,
  81. )
  82. def encode_for_inference(
  83. self: "Conversation",
  84. tokenizer: any,
  85. num_codebooks: int,
  86. metadata: dict | None = None,
  87. ):
  88. content_seq = self._build_content_sequence(metadata=metadata)
  89. return content_seq.encode_for_inference(tokenizer, num_codebooks=num_codebooks)
  90. def visualize(
  91. self: "Conversation",
  92. tokenizer: PreTrainedTokenizerFast,
  93. ignore_loss_tokens: list[str] = [],
  94. merge_semantic_tokens: bool = False,
  95. merge_audio_tokens: bool = False,
  96. use_color: bool = True,
  97. ):
  98. """
  99. Visualize the encoded sequence with color-coded tokens.
  100. Blue/cyan tokens contribute to loss, green tokens do not.
  101. """
  102. # Build ContentSequence from messages and use its visualize method
  103. content_seq = self._build_content_sequence()
  104. content_seq.visualize(
  105. tokenizer,
  106. ignore_loss_tokens=ignore_loss_tokens,
  107. merge_semantic_tokens=merge_semantic_tokens,
  108. )
  109. def append(self: "Conversation", message: Message):
  110. self.messages.append(message)
  111. def to_content_sequence(
  112. self: "Conversation",
  113. metadata: dict | None = None,
  114. ) -> ContentSequence:
  115. """
  116. Convert the Conversation to a ContentSequence.
  117. This method builds a ContentSequence from all messages,
  118. handling cal_loss inheritance from message to part level.
  119. Args:
  120. metadata: Optional metadata to include in the ContentSequence
  121. Returns:
  122. ContentSequence with all messages converted to parts
  123. """
  124. return self._build_content_sequence(metadata=metadata)
  125. if __name__ == "__main__":
  126. # Test the new implementation with the same API
  127. message0 = Message(
  128. role="user",
  129. parts=[
  130. TextPart(text="Hello, how are you?"),
  131. VQPart(codes=torch.zeros((4, 10))),
  132. ],
  133. cal_loss=False,
  134. )
  135. message1 = Message(
  136. role="assistant",
  137. parts=[TextPart(text="I'm fine, thank you.")],
  138. cal_loss=True,
  139. )
  140. conversation = Conversation([message0, message1])
  141. tokenizer = PreTrainedTokenizerFast.from_pretrained("checkpoints/agent-0.6b-debug")
  142. # Test with enhanced visualization from ContentSequence
  143. print("Basic visualization:")
  144. conversation.visualize(tokenizer)
  145. print("\nWith merged semantic tokens:")
  146. conversation.visualize(tokenizer, merge_semantic_tokens=True)
  147. print("\nWithout colors:")
  148. conversation.visualize(tokenizer, use_color=False)