| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import time
- from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
- from tools.server.agent.generation_utils import (
- initialize_decode_buffers,
- process_response_tokens,
- send_reset_buffer,
- )
- from tools.server.agent.pre_generation_utils import (
- create_generation_request,
- send_generation_request,
- )
def generate_responses(
    input_queue, tokenizer, config, request, prompt, im_end_id, device
):
    """
    Drive one generation round for a conversation request.

    Builds the model request from *prompt*, submits it to the worker via
    *input_queue*, then consumes the worker's response queue: decoded
    chunks are yielded immediately when ``request.streaming`` is set, and
    the closing responses from ``finalize_response`` are yielded at the
    end (ServeResponse or ServeStreamResponse objects).
    """
    start = time.time()
    stats = {"start_time": start, "tokens_count": 0}

    # Build the model request and hand it off to the generation worker.
    generation_request = create_generation_request(prompt, request, im_end_id, device)
    response_queue = send_generation_request(input_queue, generation_request)

    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)

    while True:
        response = response_queue.get()

        # The worker signals abnormal termination with a bare string.
        if response in ["stop", "error"]:
            finish_reason = response
            break

        first_token = stats["tokens_count"] == 0
        chunk_responses = process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            first_token,
        )

        # Stream partial results to the caller as soon as they decode.
        if request.streaming and chunk_responses:
            yield from chunk_responses

        stats["tokens_count"] += 1

        # Stop once every requested sample has reported completion.
        if all(finished):
            finish_reason = "stop"
            break

    # Flush remaining buffers and emit the terminating responses.
    yield from finalize_response(
        request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
    )
def finalize_response(
    request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
):
    """
    Flush any remaining decode buffers and build the closing responses.

    Returns a list of ServeStreamResponse objects in streaming mode (one
    per sample not yet marked finished), or a single-element list holding
    one ServeResponse with the full assistant messages otherwise. Also
    records aggregate timing/token statistics into *stats*.
    """
    sample_ids = range(request.num_samples)

    # Flush whatever text is still sitting in each sample's buffer.
    responses = []
    for sid in sample_ids:
        responses += send_reset_buffer(sid, decode_buffer, tokenizer, parts, request)

    # Aggregate stats: wall-clock time in milliseconds plus token count.
    stats["total_time"] = (time.time() - stats["start_time"]) * 1000
    stats["total_tokens"] = stats["tokens_count"]

    if request.streaming:
        # Emit a terminating chunk for every sample that has not already
        # reported completion.
        responses.extend(
            ServeStreamResponse(
                finish_reason=finish_reason, stats=stats, sample_id=sid
            )
            for sid in sample_ids
            if not finished[sid]
        )
    else:
        # Package the accumulated parts into full assistant messages.
        messages = [
            ServeMessage(role="assistant", parts=parts[i]) for i in sample_ids
        ]
        responses.append(
            ServeResponse(
                messages=messages,
                finish_reason=finish_reason,
                stats=stats,
            )
        )
    return responses
|