```python
import time

from tools.schema import (
    ServeStreamDelta,
    ServeStreamResponse,
    ServeTextPart,
    ServeVQPart,
)


def initialize_decode_buffers(num_samples):
    """Initialise the decode buffers for each sample."""
    decode_buffer = [[] for _ in range(num_samples)]
    parts = [[] for _ in range(num_samples)]
    finished = [False for _ in range(num_samples)]
    return decode_buffer, parts, finished


def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
    """Send the remaining text buffer for a sample."""
    if len(decode_buffer[sample_id]) == 0:
        return []

    decoded = tokenizer.decode(decode_buffer[sample_id])
    part = ServeTextPart(text=decoded)

    responses = []
    if request.streaming:
        responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
    else:
        parts[sample_id].append(part)

    decode_buffer[sample_id] = []
    return responses


def handle_semantic_tokens(tokens, config, sample_id, parts, request):
    """Handle the semantic tokens returned by the model."""
    responses = []
    _tokens = tokens[1:].clone()

    if not config.share_codebook_embeddings:
        for i in range(len(_tokens)):
            _tokens[i] -= config.codebook_size * i

    # If streaming, send the VQ parts directly
    if request.streaming:
        responses.append(
            ServeStreamResponse(
                sample_id=sample_id,
                delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
            )
        )
    else:
        # If not streaming, accumulate the VQ parts
        if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
            parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
        else:
            # Accumulate the codes
            for codebook_id, value in enumerate(_tokens):
                parts[sample_id][-1].codes[codebook_id].append(value.item())

    return responses


def process_response_tokens(
    response,
    tokenizer,
    config,
    request,
    decode_buffer,
    parts,
    finished,
    im_end_id,
    stats,
    start,
    is_first_token,
):
    """Process the response tokens returned by the model."""
    responses = []

    for sample_id, tokens in enumerate(response):
        if finished[sample_id]:
            continue

        # End of the conversation
        if tokens[0] == im_end_id:
            finished[sample_id] = True

            # Send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )

            if request.streaming:
                responses.append(
                    ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                )

            continue

        # Check if the token is semantic
        is_semantic = (
            tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
        )

        if is_semantic:
            # Before the semantic tokens, send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            responses.extend(
                handle_semantic_tokens(tokens, config, sample_id, parts, request)
            )
        else:
            # Accumulate the text tokens
            decode_buffer[sample_id].append(tokens[0, 0])

    if is_first_token:
        stats["time_to_first_token"] = (time.time() - start) * 1000

    return responses
```
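
For context, here is a minimal sketch of how these helpers might be driven from a decode loop. It is illustrative only: `run_decode_loop`, `model_generator`, `num_samples`, and the surrounding `request` / `tokenizer` / `config` objects are assumed stand-ins supplied by the caller, not part of the code above.

```python
def run_decode_loop(model_generator, request, tokenizer, config, im_end_id, num_samples):
    """Hypothetical wiring of the helpers above; argument names are assumptions."""
    decode_buffer, parts, finished = initialize_decode_buffers(num_samples)
    stats = {}
    start = time.time()

    for step, response in enumerate(model_generator):
        responses = process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            is_first_token=(step == 0),
        )
        # In streaming mode, each ServeStreamResponse is forwarded as soon as
        # it is produced; otherwise `parts` accumulates the final result.
        yield from responses

        if all(finished):
            break
```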