__init__.py

import struct
from functools import partial

import ormsgpack

from tools.server.agent.generate import generate_responses
from tools.server.agent.pre_generation_utils import prepare_messages


def execute_request(input_queue, tokenizer, config, request, device):
    """
    Prepare the conversation, encode the request, submit it for generation,
    and handle decoding/streaming. Yields ServeResponse or ServeStreamResponse
    objects.
    """
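    # prepare_messages builds the encoded prompt and returns the im_end token
    # id, which is forwarded to generate_responses (end-of-message marker).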
    prompt, im_end_id = prepare_messages(request, tokenizer, config)
    yield from generate_responses(
        input_queue, tokenizer, config, request, prompt, im_end_id, device
    )


def response_generator(req, llama_queue, tokenizer, config, device):
    """
    Non-streaming response wrapper for the chat endpoint.
    Returns only the final result.
    """
    generator = execute_request(llama_queue, tokenizer, config, req, device)
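    # For a non-streaming request the generator is expected to yield a single
    # final response, so next() retrieves it directly.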
    return next(generator)


async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
    """
    Streaming response wrapper for the chat endpoint.
    Returns the response in chunks.
    """
    generator = execute_request(llama_queue, tokenizer, config, req, device)
    for chunk in generator:
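        # Two wire formats: json_mode frames each chunk as a server-sent event
        # (SSE); otherwise chunks are emitted as length-prefixed msgpack records.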
        if json_mode:
            body = chunk.model_dump_json().encode("utf-8")
            yield b"data: " + body + b"\n\n"
        else:
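            # Length-prefix framing: a native-byte-order unsigned int (4 bytes
            # on common platforms) giving the payload size, then the msgpack
            # payload. OPT_SERIALIZE_PYDANTIC lets ormsgpack serialize the
            # Pydantic model directly.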
            body = ormsgpack.packb(chunk, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
            yield struct.pack("I", len(body)) + body


def get_response_generator(
    llama_queue, tokenizer, config, req, device, json_mode
) -> partial:
    """
    Pick the non-streaming or streaming response generator based on the
    request's streaming flag.
    """
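    # All arguments are bound up front; the caller invokes the returned
    # partial with no arguments to produce the response.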
    if not req.streaming:
        return partial(response_generator, req, llama_queue, tokenizer, config, device)
    else:
        return partial(
            streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
        )
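

# Usage sketch (hypothetical caller, not part of this module): an HTTP handler
# would resolve the callable once, then invoke it with no arguments. The names
# in scope here (`req`, `llama_queue`, ...) are assumptions, not defined by
# this file.
#
#     respond = get_response_generator(
#         llama_queue, tokenizer, config, req, device, json_mode=True
#     )
#     if req.streaming:
#         async for chunk in respond():   # bytes frames from streaming_generator
#             await send(chunk)           # write each frame to the HTTP response
#     else:
#         result = respond()              # a single final response object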