generation_utils.py

import time

from tools.schema import (
    ServeStreamDelta,
    ServeStreamResponse,
    ServeTextPart,
    ServeVQPart,
)
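
# As used below: ServeTextPart and ServeVQPart carry decoded text and VQ
# codebook codes respectively, while ServeStreamDelta/ServeStreamResponse
# wrap them for streamed replies.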


def initialize_decode_buffers(num_samples):
    """Initialise the decode buffers for each sample."""
    decode_buffer = [[] for _ in range(num_samples)]
    parts = [[] for _ in range(num_samples)]
    finished = [False for _ in range(num_samples)]
    return decode_buffer, parts, finished
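
# For example, initialize_decode_buffers(2) returns
# ([[], []], [[], []], [False, False]): one pending-token buffer, one list
# of accumulated response parts, and one finished flag per sample.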


def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
    """Send the remaining text buffer for a sample."""
    if len(decode_buffer[sample_id]) == 0:
        return []

    decoded = tokenizer.decode(decode_buffer[sample_id])
    part = ServeTextPart(text=decoded)

    responses = []
    if request.streaming:
        # Tag with sample_id so multi-sample streams stay distinguishable,
        # matching the other streaming responses in this module
        responses.append(
            ServeStreamResponse(
                sample_id=sample_id, delta=ServeStreamDelta(part=part)
            )
        )
    else:
        parts[sample_id].append(part)

    decode_buffer[sample_id] = []
    return responses
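
# send_reset_buffer is the only place buffered text tokens are decoded:
# streaming requests get the text back immediately as a ServeStreamResponse,
# non-streaming requests accumulate it in parts[sample_id], and the buffer
# is cleared either way so no token is decoded twice.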


def handle_semantic_tokens(tokens, config, sample_id, parts, request):
    """Handle the semantic tokens returned by the model."""
    responses = []
    _tokens = tokens[1:].clone()

    if not config.share_codebook_embeddings:
        for i in range(len(_tokens)):
            _tokens[i] -= config.codebook_size * i

    # If streaming, send the VQ parts directly
    if request.streaming:
        responses.append(
            ServeStreamResponse(
                sample_id=sample_id,
                delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
            )
        )
    else:
        # If not streaming, accumulate the VQ parts
        if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
            parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
        else:
            # Accumulate the codes
            for codebook_id, value in enumerate(_tokens):
                parts[sample_id][-1].codes[codebook_id].append(value.item())

    return responses
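
# When codebooks do not share an embedding table, codebook i occupies its
# own slice of the flat vocabulary, offset by i * codebook_size. As an
# illustration, assuming codebook_size = 1024, a raw column of
# [42, 1066, 2090] maps back to per-codebook codes [42, 42, 42] after the
# subtraction above.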


def process_response_tokens(
    response,
    tokenizer,
    config,
    request,
    decode_buffer,
    parts,
    finished,
    im_end_id,
    stats,
    start,
    is_first_token,
):
    """Process the response tokens returned by the model."""
    responses = []
    for sample_id, tokens in enumerate(response):
        if finished[sample_id]:
            continue

        # End of the conversation
        if tokens[0] == im_end_id:
            finished[sample_id] = True
            # Send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            if request.streaming:
                responses.append(
                    ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                )
            continue

        # Check if the token is semantic
        is_semantic = (
            tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
        )

        if is_semantic:
            # Before the semantic tokens, send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            responses.extend(
                handle_semantic_tokens(tokens, config, sample_id, parts, request)
            )
        else:
            # Accumulate the text tokens; they are decoded later by
            # send_reset_buffer once a flush point is reached
            decode_buffer[sample_id].append(tokens[0, 0])

    if is_first_token:
        stats["time_to_first_token"] = (time.time() - start) * 1000

    return responses
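

# The helpers above are building blocks for a per-step decode loop. The
# sketch below is illustrative only and not part of the original module:
# `model_step` is an assumed callable returning one batch of tokens per
# call, shaped like `response` above (one tensor of codes per sample).
def example_decode_loop(
    model_step, tokenizer, config, request, num_samples, im_end_id
):
    decode_buffer, parts, finished = initialize_decode_buffers(num_samples)
    stats = {}
    start = time.time()
    is_first_token = True
    while not all(finished):
        response = model_step()
        # Yield streaming responses as they are produced; for non-streaming
        # requests the results accumulate in `parts` instead.
        yield from process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            is_first_token,
        )
        is_first_token = False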