# pre_generation_utils.py
  1. import queue
  2. from fish_speech.conversation import Conversation, Message
  3. from fish_speech.tokenizer import IM_END_TOKEN
  4. from tools.llama.generate import GenerateRequest
  5. def prepare_messages(request, tokenizer, config):
  6. """
  7. Reorganise the provided list of messages into a conversation.
  8. Encode the conversation for inference.
  9. """
  10. # Convert the messages to ConversationMessage objects
  11. messages = [msg.to_conversation_message() for msg in request.messages]
  12. if len(messages) < 1:
  13. raise ValueError("At least one message is required")
  14. # Check the last message to determine the next step
  15. last_role = messages[-1].role
  16. match last_role:
  17. case "user":
  18. # The last message is from the user, ask the assistant to respond with a new message
  19. messages.append(
  20. Message(role="assistant", parts=[], add_im_end=False, modality="voice")
  21. )
  22. case "raw":
  23. # The last message is raw text, ask the assistant to complete it
  24. messages[-1].add_im_start = False
  25. messages[-1].add_im_end = False
  26. messages[-1].modality = "voice"
  27. case "assistant":
  28. # The last message is from the assistant, ask the assistant to continue
  29. messages[-1].add_im_end = False
  30. case _:
  31. # We expect it to be assistant if not user or raw
  32. raise ValueError("The last message must be from the assistant, user or raw")
  33. # Create a conversation object and encode it for inference
  34. conv = Conversation(messages=messages)
  35. prompt = conv.encode_for_inference(
  36. tokenizer=tokenizer, num_codebooks=config.num_codebooks
  37. )
  38. im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
  39. return prompt, im_end_id
  40. def create_generation_request(prompt, request, im_end_id, device):
  41. """
  42. Convert the request into a dictionary that can be sent to the model for generation.
  43. """
  44. req = {
  45. "prompt": prompt.to(device),
  46. "max_new_tokens": request.max_new_tokens,
  47. "im_end_id": im_end_id,
  48. "temperature": request.temperature,
  49. "top_p": request.top_p,
  50. "repetition_penalty": request.repetition_penalty,
  51. "num_samples": request.num_samples,
  52. "early_stop_threshold": request.early_stop_threshold,
  53. }
  54. return req
  55. def send_generation_request(input_queue, req):
  56. """
  57. Send the generation request to the model and return a queue to get the response.
  58. """
  59. response_queue = queue.Queue()
  60. input_queue.put(GenerateRequest(req, response_queue))
  61. return response_queue