test_echo.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
import io
import logging
import queue
import threading
import wave
from typing import List

import av
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
  8. app = FastAPI()
  9. html = """
  10. <!DOCTYPE html>
  11. <html>
  12. <head>
  13. <title>Real-time Chat Room</title>
  14. </head>
  15. <body>
  16. <h1>Real-time Chat Room</h1>
  17. <button id="start">Start Streaming</button>
  18. <button id="stop">Stop Streaming</button>
  19. <script type="module">
  20. import { MediaRecorder, register } from 'https://dev.jspm.io/npm:extendable-media-recorder';
  21. import { connect } from 'https://dev.jspm.io/npm:extendable-media-recorder-wav-encoder';
  22. await register(await connect());
  23. let socket;
  24. let mediaRecorder;
  25. let audioContext;
  26. function startStreaming() {
  27. initWebSocket();
  28. audioContext = new (window.AudioContext || window.webkitAudioContext)();
  29. navigator.mediaDevices.getUserMedia({ audio: {
  30. channelCount: 1,
  31. sampleRate: 44100,
  32. sampleSize: 16,
  33. echoCancellation: true,
  34. noiseSuppression: true
  35. } })
  36. .then(function (stream) {
  37. mediaRecorder = new MediaRecorder(stream, { mimeType: 'audio/webm;codecs=opus' });
  38. mediaRecorder.start(100);
  39. mediaRecorder.addEventListener("dataavailable", function (event) {
  40. socket.send(event.data);
  41. });
  42. })
  43. .catch(function (err) {
  44. console.error("Error accessing microphone:", err);
  45. });
  46. // Create a MediaSource
  47. const mediaSource = new MediaSource();
  48. const mediaStream = new MediaStream();
  49. // Create an HTMLVideoElement and attach the MediaSource to it
  50. const audioElement = document.createElement('audio');
  51. audioElement.src = URL.createObjectURL(mediaSource);
  52. audioElement.autoplay = true;
  53. document.body.appendChild(audioElement);
  54. mediaSource.addEventListener('sourceopen', function() {
  55. const sourceBuffer = mediaSource.addSourceBuffer('audio/webm; codecs=opus');
  56. socket.onmessage = function(event) {
  57. const arrayBuffer = event.data;
  58. sourceBuffer.appendBuffer(arrayBuffer);
  59. };
  60. });
  61. }
  62. function stopStreaming() {
  63. mediaRecorder.stop();
  64. }
  65. function initWebSocket() {
  66. const is_wss = window.location.protocol === "https:";
  67. socket = new WebSocket(`${is_wss ? "wss" : "ws"}://${window.location.host}/ws`);
  68. socket.binaryType = 'arraybuffer';
  69. }
  70. document.getElementById("start").onclick = startStreaming;
  71. document.getElementById("stop").onclick = stopStreaming;
  72. </script>
  73. </body>
  74. </html>
  75. """
  76. def encode_wav(data):
  77. sample_rate = 44100
  78. samples = np.frombuffer(data, dtype=np.int16)
  79. buffer = io.BytesIO()
  80. with wave.open(buffer, "wb") as wav_file:
  81. wav_file.setnchannels(1)
  82. wav_file.setsampwidth(2)
  83. wav_file.setframerate(sample_rate)
  84. wav_file.writeframes(samples.tobytes())
  85. return buffer.getvalue()
  86. class ConnectionManager:
  87. def __init__(self):
  88. self.active_connections: List[WebSocket] = []
  89. async def connect(self, websocket: WebSocket):
  90. await websocket.accept()
  91. self.active_connections.append(websocket)
  92. def disconnect(self, websocket: WebSocket):
  93. self.active_connections.remove(websocket)
  94. async def broadcast(self, message: bytes, sender: WebSocket):
  95. for connection in self.active_connections:
  96. if connection == sender:
  97. # print("Sending message to client", connection)
  98. await connection.send_bytes(message)
  99. manager = ConnectionManager()
  100. @app.get("/")
  101. async def get():
  102. return HTMLResponse(html)
  103. @app.websocket("/ws")
  104. async def websocket_endpoint(websocket: WebSocket):
  105. await manager.connect(websocket)
  106. try:
  107. buffer = io.BytesIO()
  108. container = None
  109. cur_pos = 0
  110. total_size = 0
  111. while True:
  112. data = await websocket.receive_bytes()
  113. # data = encode_wav(data)
  114. # if len(data) == 1:
  115. # print(f"len(data): {len(data)}, data: {data}")
  116. # if len(data) > 1:
  117. # data = b'\x1a' + data
  118. # with open("output.webm", "wb") as f:
  119. # f.write(data)
  120. # exit()
  121. # print(f"len(data): {len(data)}")
  122. # print("Received data:", data)
  123. # Save as webm file and exit
  124. # with open("output.wav", "wb") as f:
  125. # f.write(encode_wav(data))
  126. buffer.write(data)
  127. buffer.seek(cur_pos)
  128. total_size += len(data)
  129. if not container and total_size > 1000:
  130. container = av.open(buffer, "r", format="webm")
  131. print(container)
  132. elif container:
  133. for packet in container.decode(video=0):
  134. if packet.size == 0:
  135. continue
  136. cur_pos += packet.size
  137. for frame in packet.decode():
  138. print(frame.to_ndarray().shape)
  139. await manager.broadcast(data, websocket)
  140. except WebSocketDisconnect:
  141. manager.disconnect(websocket)
  142. if __name__ == "__main__":
  143. import uvicorn
  144. uvicorn.run(app, host="0.0.0.0", port=8000)