tts.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. package volcengine
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strings"
  11. "github.com/QuantumNous/new-api/dto"
  12. relaycommon "github.com/QuantumNous/new-api/relay/common"
  13. "github.com/QuantumNous/new-api/types"
  14. "github.com/gin-gonic/gin"
  15. "github.com/google/uuid"
  16. "github.com/gorilla/websocket"
  17. )
  18. type VolcengineTTSRequest struct {
  19. App VolcengineTTSApp `json:"app"`
  20. User VolcengineTTSUser `json:"user"`
  21. Audio VolcengineTTSAudio `json:"audio"`
  22. Request VolcengineTTSReqInfo `json:"request"`
  23. }
  24. type VolcengineTTSApp struct {
  25. AppID string `json:"appid"`
  26. Token string `json:"token"`
  27. Cluster string `json:"cluster"`
  28. }
  29. type VolcengineTTSUser struct {
  30. UID string `json:"uid"`
  31. }
  32. type VolcengineTTSAudio struct {
  33. VoiceType string `json:"voice_type"`
  34. Encoding string `json:"encoding"`
  35. SpeedRatio float64 `json:"speed_ratio"`
  36. Rate int `json:"rate"`
  37. Bitrate int `json:"bitrate,omitempty"`
  38. LoudnessRatio float64 `json:"loudness_ratio,omitempty"`
  39. EnableEmotion bool `json:"enable_emotion,omitempty"`
  40. Emotion string `json:"emotion,omitempty"`
  41. EmotionScale float64 `json:"emotion_scale,omitempty"`
  42. ExplicitLanguage string `json:"explicit_language,omitempty"`
  43. ContextLanguage string `json:"context_language,omitempty"`
  44. }
  45. type VolcengineTTSReqInfo struct {
  46. ReqID string `json:"reqid"`
  47. Text string `json:"text"`
  48. Operation string `json:"operation"`
  49. Model string `json:"model,omitempty"`
  50. TextType string `json:"text_type,omitempty"`
  51. SilenceDuration float64 `json:"silence_duration,omitempty"`
  52. WithTimestamp interface{} `json:"with_timestamp,omitempty"`
  53. ExtraParam *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
  54. }
  55. type VolcengineTTSExtraParam struct {
  56. DisableMarkdownFilter bool `json:"disable_markdown_filter,omitempty"`
  57. EnableLatexTn bool `json:"enable_latex_tn,omitempty"`
  58. MuteCutThreshold string `json:"mute_cut_threshold,omitempty"`
  59. MuteCutRemainMs string `json:"mute_cut_remain_ms,omitempty"`
  60. DisableEmojiFilter bool `json:"disable_emoji_filter,omitempty"`
  61. UnsupportedCharRatioThresh float64 `json:"unsupported_char_ratio_thresh,omitempty"`
  62. AigcWatermark bool `json:"aigc_watermark,omitempty"`
  63. CacheConfig *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
  64. }
  65. type VolcengineTTSCacheConfig struct {
  66. TextType int `json:"text_type,omitempty"`
  67. UseCache bool `json:"use_cache,omitempty"`
  68. }
  69. type VolcengineTTSResponse struct {
  70. ReqID string `json:"reqid"`
  71. Code int `json:"code"`
  72. Message string `json:"message"`
  73. Sequence int `json:"sequence"`
  74. Data string `json:"data"`
  75. Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
  76. }
  77. type VolcengineTTSAdditionInfo struct {
  78. Duration string `json:"duration"`
  79. }
  80. var openAIToVolcengineVoiceMap = map[string]string{
  81. "alloy": "zh_male_M392_conversation_wvae_bigtts",
  82. "echo": "zh_male_wenhao_mars_bigtts",
  83. "fable": "zh_female_tianmei_mars_bigtts",
  84. "onyx": "zh_male_zhibei_mars_bigtts",
  85. "nova": "zh_female_shuangkuaisisi_mars_bigtts",
  86. "shimmer": "zh_female_cancan_mars_bigtts",
  87. }
  88. var responseFormatToEncodingMap = map[string]string{
  89. "mp3": "mp3",
  90. "opus": "ogg_opus",
  91. "aac": "mp3",
  92. "flac": "mp3",
  93. "wav": "wav",
  94. "pcm": "pcm",
  95. }
  96. func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
  97. parts := strings.Split(apiKey, "|")
  98. if len(parts) != 2 {
  99. return "", "", errors.New("invalid api key format, expected: appid|access_token")
  100. }
  101. return parts[0], parts[1], nil
  102. }
  103. func mapVoiceType(openAIVoice string) string {
  104. if voice, ok := openAIToVolcengineVoiceMap[openAIVoice]; ok {
  105. return voice
  106. }
  107. return openAIVoice
  108. }
  109. func mapEncoding(responseFormat string) string {
  110. if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
  111. return encoding
  112. }
  113. return "mp3"
  114. }
  115. func getContentTypeByEncoding(encoding string) string {
  116. contentTypeMap := map[string]string{
  117. "mp3": "audio/mpeg",
  118. "ogg_opus": "audio/ogg",
  119. "wav": "audio/wav",
  120. "pcm": "audio/pcm",
  121. }
  122. if ct, ok := contentTypeMap[encoding]; ok {
  123. return ct
  124. }
  125. return "application/octet-stream"
  126. }
  127. func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
  128. body, readErr := io.ReadAll(resp.Body)
  129. if readErr != nil {
  130. return nil, types.NewErrorWithStatusCode(
  131. errors.New("failed to read volcengine response"),
  132. types.ErrorCodeReadResponseBodyFailed,
  133. http.StatusInternalServerError,
  134. )
  135. }
  136. defer resp.Body.Close()
  137. var volcResp VolcengineTTSResponse
  138. if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
  139. return nil, types.NewErrorWithStatusCode(
  140. errors.New("failed to parse volcengine response"),
  141. types.ErrorCodeBadResponseBody,
  142. http.StatusInternalServerError,
  143. )
  144. }
  145. if volcResp.Code != 3000 {
  146. return nil, types.NewErrorWithStatusCode(
  147. errors.New(volcResp.Message),
  148. types.ErrorCodeBadResponse,
  149. http.StatusBadRequest,
  150. )
  151. }
  152. audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
  153. if decodeErr != nil {
  154. return nil, types.NewErrorWithStatusCode(
  155. errors.New("failed to decode audio data"),
  156. types.ErrorCodeBadResponseBody,
  157. http.StatusInternalServerError,
  158. )
  159. }
  160. contentType := getContentTypeByEncoding(encoding)
  161. c.Header("Content-Type", contentType)
  162. c.Data(http.StatusOK, contentType, audioData)
  163. usage = &dto.Usage{
  164. PromptTokens: info.PromptTokens,
  165. CompletionTokens: 0,
  166. TotalTokens: info.PromptTokens,
  167. }
  168. return usage, nil
  169. }
  170. func generateRequestID() string {
  171. return uuid.New().String()
  172. }
  173. // handleTTSWebSocketResponse handles streaming TTS response via WebSocket
  174. func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
  175. // Parse API key for auth
  176. _, token, parseErr := parseVolcengineAuth(info.ApiKey)
  177. if parseErr != nil {
  178. return nil, types.NewErrorWithStatusCode(
  179. parseErr,
  180. types.ErrorCodeChannelInvalidKey,
  181. http.StatusUnauthorized,
  182. )
  183. }
  184. // Setup WebSocket headers
  185. header := http.Header{}
  186. header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
  187. // Dial WebSocket connection
  188. conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
  189. if dialErr != nil {
  190. if resp != nil {
  191. return nil, types.NewErrorWithStatusCode(
  192. fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode),
  193. types.ErrorCodeBadResponseStatusCode,
  194. http.StatusBadGateway,
  195. )
  196. }
  197. return nil, types.NewErrorWithStatusCode(
  198. fmt.Errorf("failed to connect to websocket: %w", dialErr),
  199. types.ErrorCodeBadResponseStatusCode,
  200. http.StatusBadGateway,
  201. )
  202. }
  203. defer conn.Close()
  204. // Update request operation to "submit" for WebSocket
  205. volcRequest.Request.Operation = "submit"
  206. // Marshal request payload
  207. payload, marshalErr := json.Marshal(volcRequest)
  208. if marshalErr != nil {
  209. return nil, types.NewErrorWithStatusCode(
  210. fmt.Errorf("failed to marshal request: %w", marshalErr),
  211. types.ErrorCodeBadRequestBody,
  212. http.StatusInternalServerError,
  213. )
  214. }
  215. // Send full client request
  216. if sendErr := FullClientRequest(conn, payload); sendErr != nil {
  217. return nil, types.NewErrorWithStatusCode(
  218. fmt.Errorf("failed to send request: %w", sendErr),
  219. types.ErrorCodeBadRequestBody,
  220. http.StatusInternalServerError,
  221. )
  222. }
  223. // Set response headers
  224. contentType := getContentTypeByEncoding(encoding)
  225. c.Header("Content-Type", contentType)
  226. c.Header("Transfer-Encoding", "chunked")
  227. // Stream audio data
  228. var audioBuffer []byte
  229. for {
  230. msg, recvErr := ReceiveMessage(conn)
  231. if recvErr != nil {
  232. if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  233. break
  234. }
  235. return nil, types.NewErrorWithStatusCode(
  236. fmt.Errorf("failed to receive message: %w", recvErr),
  237. types.ErrorCodeBadResponse,
  238. http.StatusInternalServerError,
  239. )
  240. }
  241. switch msg.MsgType {
  242. case MsgTypeError:
  243. return nil, types.NewErrorWithStatusCode(
  244. fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)),
  245. types.ErrorCodeBadResponse,
  246. http.StatusBadRequest,
  247. )
  248. case MsgTypeFrontEndResultServer:
  249. // Metadata response, can be logged or processed
  250. continue
  251. case MsgTypeAudioOnlyServer:
  252. // Stream audio chunk to client
  253. if len(msg.Payload) > 0 {
  254. audioBuffer = append(audioBuffer, msg.Payload...)
  255. if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
  256. return nil, types.NewErrorWithStatusCode(
  257. fmt.Errorf("failed to write audio data: %w", writeErr),
  258. types.ErrorCodeBadResponse,
  259. http.StatusInternalServerError,
  260. )
  261. }
  262. c.Writer.Flush()
  263. }
  264. // Check if this is the last packet (negative sequence)
  265. if msg.Sequence < 0 {
  266. c.Status(http.StatusOK)
  267. usage = &dto.Usage{
  268. PromptTokens: info.PromptTokens,
  269. CompletionTokens: 0,
  270. TotalTokens: info.PromptTokens,
  271. }
  272. return usage, nil
  273. }
  274. default:
  275. // Unknown message type, log and continue
  276. continue
  277. }
  278. }
  279. // If we reach here, connection closed without final packet
  280. c.Status(http.StatusOK)
  281. usage = &dto.Usage{
  282. PromptTokens: info.PromptTokens,
  283. CompletionTokens: 0,
  284. TotalTokens: info.PromptTokens,
  285. }
  286. return usage, nil
  287. }