# generate.py
import time

from fish_speech.utils.schema import ServeMessage, ServeResponse, ServeStreamResponse
from tools.server.agent.generation_utils import (
    initialize_decode_buffers,
    process_response_tokens,
    send_reset_buffer,
)
from tools.server.agent.pre_generation_utils import (
    create_generation_request,
    send_generation_request,
)
  12. def generate_responses(
  13. input_queue, tokenizer, config, request, prompt, im_end_id, device
  14. ):
  15. """
  16. Main generation function that handles the conversation, encodes the request,
  17. sends the generation request, and handles decoding/streaming.
  18. It returns a response generator (ServeResponse or ServeStreamResponse).
  19. """
  20. stats = {}
  21. start = time.time()
  22. stats["start_time"] = start
  23. stats["tokens_count"] = 0
  24. # Prepare and send the generation request
  25. req = create_generation_request(prompt, request, im_end_id, device)
  26. response_queue = send_generation_request(input_queue, req)
  27. decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
  28. while True:
  29. response = response_queue.get()
  30. # Handle abnormal finish or error
  31. if response in ["stop", "error"]:
  32. finish_reason = response
  33. break
  34. # Process the response tokens
  35. is_first_token = stats["tokens_count"] == 0
  36. responses = process_response_tokens(
  37. response,
  38. tokenizer,
  39. config,
  40. request,
  41. decode_buffer,
  42. parts,
  43. finished,
  44. im_end_id,
  45. stats,
  46. start,
  47. is_first_token,
  48. )
  49. # Yield the responses if streaming
  50. if request.streaming and responses:
  51. for r in responses:
  52. yield r
  53. stats["tokens_count"] += 1
  54. # Check if all samples are finished
  55. if all(finished):
  56. finish_reason = "stop"
  57. break
  58. # Finalize the response
  59. final_responses = finalize_response(
  60. request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
  61. )
  62. for fr in final_responses:
  63. yield fr
  64. def finalize_response(
  65. request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
  66. ):
  67. """
  68. Finalize the response by sending the remaining text buffers.
  69. """
  70. responses = []
  71. # Send the remaining text buffers
  72. for sample_id in range(request.num_samples):
  73. responses.extend(
  74. send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
  75. )
  76. # Calculate the final stats
  77. stats["total_time"] = (time.time() - stats["start_time"]) * 1000
  78. stats["total_tokens"] = stats["tokens_count"]
  79. # If streaming, send the final chunks for each sample
  80. if request.streaming:
  81. for sample_id in range(request.num_samples):
  82. if finished[sample_id]:
  83. continue
  84. responses.append(
  85. ServeStreamResponse(
  86. finish_reason=finish_reason, stats=stats, sample_id=sample_id
  87. )
  88. )
  89. else:
  90. # If not streaming, send the full messages for each sample
  91. full_messages = [
  92. ServeMessage(role="assistant", parts=parts[i])
  93. for i in range(request.num_samples)
  94. ]
  95. responses.append(
  96. ServeResponse(
  97. messages=full_messages,
  98. finish_reason=finish_reason,
  99. stats=stats,
  100. )
  101. )
  102. return responses