tts.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. package volcengine
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strings"
  11. "github.com/QuantumNous/new-api/dto"
  12. relaycommon "github.com/QuantumNous/new-api/relay/common"
  13. "github.com/QuantumNous/new-api/types"
  14. "github.com/gin-gonic/gin"
  15. "github.com/google/uuid"
  16. "github.com/gorilla/websocket"
  17. )
  18. type VolcengineTTSRequest struct {
  19. App VolcengineTTSApp `json:"app"`
  20. User VolcengineTTSUser `json:"user"`
  21. Audio VolcengineTTSAudio `json:"audio"`
  22. Request VolcengineTTSReqInfo `json:"request"`
  23. }
  // VolcengineTTSApp is the "app" section of a TTS request: the application
  // ID, its token and the cluster name. All three fields are always
  // serialized, even when empty.
  type VolcengineTTSApp struct {
  	AppID   string `json:"appid"`
  	Token   string `json:"token"`
  	Cluster string `json:"cluster"`
  }
  // VolcengineTTSUser is the "user" section of a TTS request. It carries a
  // single user identifier, serialized as "uid".
  type VolcengineTTSUser struct {
  	UID string `json:"uid"`
  }
  // VolcengineTTSAudio is the "audio" section of a TTS request.
  //
  // VoiceType, Encoding, SpeedRatio and Rate are always serialized. Every other
  // field is tagged omitempty and is therefore dropped from the JSON when left
  // at its zero value.
  type VolcengineTTSAudio struct {
  	// Always-present settings.
  	VoiceType  string  `json:"voice_type"`
  	Encoding   string  `json:"encoding"`
  	SpeedRatio float64 `json:"speed_ratio"`
  	Rate       int     `json:"rate"`

  	// Optional tuning, omitted when zero.
  	Bitrate       int     `json:"bitrate,omitempty"`
  	LoudnessRatio float64 `json:"loudness_ratio,omitempty"`

  	// Optional emotion controls, omitted when zero.
  	EnableEmotion bool    `json:"enable_emotion,omitempty"`
  	Emotion       string  `json:"emotion,omitempty"`
  	EmotionScale  float64 `json:"emotion_scale,omitempty"`

  	// Optional language hints, omitted when empty.
  	ExplicitLanguage string `json:"explicit_language,omitempty"`
  	ContextLanguage  string `json:"context_language,omitempty"`
  }
  45. type VolcengineTTSReqInfo struct {
  46. ReqID string `json:"reqid"`
  47. Text string `json:"text"`
  48. Operation string `json:"operation"`
  49. Model string `json:"model,omitempty"`
  50. TextType string `json:"text_type,omitempty"`
  51. SilenceDuration float64 `json:"silence_duration,omitempty"`
  52. WithTimestamp interface{} `json:"with_timestamp,omitempty"`
  53. ExtraParam *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
  54. }
  55. type VolcengineTTSExtraParam struct {
  56. DisableMarkdownFilter bool `json:"disable_markdown_filter,omitempty"`
  57. EnableLatexTn bool `json:"enable_latex_tn,omitempty"`
  58. MuteCutThreshold string `json:"mute_cut_threshold,omitempty"`
  59. MuteCutRemainMs string `json:"mute_cut_remain_ms,omitempty"`
  60. DisableEmojiFilter bool `json:"disable_emoji_filter,omitempty"`
  61. UnsupportedCharRatioThresh float64 `json:"unsupported_char_ratio_thresh,omitempty"`
  62. AigcWatermark bool `json:"aigc_watermark,omitempty"`
  63. CacheConfig *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
  64. }
  // VolcengineTTSCacheConfig is the optional "cache_config" object nested in
  // the extra parameters. Both fields are omitted from the JSON when zero.
  //
  // TextType here is an int, unlike the string-typed TextType on
  // VolcengineTTSReqInfo.
  type VolcengineTTSCacheConfig struct {
  	TextType int  `json:"text_type,omitempty"`
  	UseCache bool `json:"use_cache,omitempty"`
  }
  69. type VolcengineTTSResponse struct {
  70. ReqID string `json:"reqid"`
  71. Code int `json:"code"`
  72. Message string `json:"message"`
  73. Sequence int `json:"sequence"`
  74. Data string `json:"data"`
  75. Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
  76. }
  // VolcengineTTSAdditionInfo is the optional "addition" object of a TTS
  // response. Duration is carried as a string, not a number.
  type VolcengineTTSAdditionInfo struct {
  	Duration string `json:"duration"`
  }
  // openAIToVolcengineVoiceMap translates the six OpenAI voice names into
  // Volcengine voice_type identifiers. Names that are not listed here have no
  // translation.
  var openAIToVolcengineVoiceMap = map[string]string{
  	// Male voices.
  	"alloy": "zh_male_M392_conversation_wvae_bigtts",
  	"echo":  "zh_male_wenhao_mars_bigtts",
  	"onyx":  "zh_male_zhibei_mars_bigtts",

  	// Female voices.
  	"fable":   "zh_female_tianmei_mars_bigtts",
  	"nova":    "zh_female_shuangkuaisisi_mars_bigtts",
  	"shimmer": "zh_female_cancan_mars_bigtts",
  }
  // responseFormatToEncodingMap translates an OpenAI response_format value
  // into the Volcengine audio encoding name.
  var responseFormatToEncodingMap = map[string]string{
  	// Formats with a direct counterpart.
  	"mp3":  "mp3",
  	"wav":  "wav",
  	"pcm":  "pcm",
  	"opus": "ogg_opus",

  	// Formats that are mapped to mp3 instead of a same-named encoding.
  	"aac":  "mp3",
  	"flac": "mp3",
  }
  // parseVolcengineAuth splits a channel API key of the form
  // "appid|access_token" into its two components.
  //
  // Whitespace around either component is trimmed. An error is returned when
  // the key does not contain exactly one "|" separator, or when either
  // component is empty after trimming; such a key can never authenticate, so
  // it is rejected here rather than being sent upstream.
  func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
  	errInvalidFormat := errors.New("invalid api key format, expected: appid|access_token")

  	left, right, found := strings.Cut(apiKey, "|")
  	// A second separator means the key has more than two parts.
  	if !found || strings.Contains(right, "|") {
  		return "", "", errInvalidFormat
  	}

  	left, right = strings.TrimSpace(left), strings.TrimSpace(right)
  	if left == "" || right == "" {
  		return "", "", errInvalidFormat
  	}
  	return left, right, nil
  }
  103. func mapVoiceType(openAIVoice string) string {
  104. if voice, ok := openAIToVolcengineVoiceMap[openAIVoice]; ok {
  105. return voice
  106. }
  107. return openAIVoice
  108. }
  109. func mapEncoding(responseFormat string) string {
  110. if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
  111. return encoding
  112. }
  113. return "mp3"
  114. }
  // getContentTypeByEncoding returns the HTTP Content-Type for a Volcengine
  // audio encoding name. Encodings other than the four listed below yield
  // "application/octet-stream".
  //
  // A switch is used instead of a map literal so that no map is allocated on
  // each call.
  func getContentTypeByEncoding(encoding string) string {
  	switch encoding {
  	case "mp3":
  		return "audio/mpeg"
  	case "ogg_opus":
  		return "audio/ogg"
  	case "wav":
  		return "audio/wav"
  	case "pcm":
  		return "audio/pcm"
  	default:
  		return "application/octet-stream"
  	}
  }
  127. func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
  128. body, readErr := io.ReadAll(resp.Body)
  129. if readErr != nil {
  130. return nil, types.NewErrorWithStatusCode(
  131. errors.New("failed to read volcengine response"),
  132. types.ErrorCodeReadResponseBodyFailed,
  133. http.StatusInternalServerError,
  134. )
  135. }
  136. defer resp.Body.Close()
  137. var volcResp VolcengineTTSResponse
  138. if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
  139. return nil, types.NewErrorWithStatusCode(
  140. errors.New("failed to parse volcengine response"),
  141. types.ErrorCodeBadResponseBody,
  142. http.StatusInternalServerError,
  143. )
  144. }
  145. if volcResp.Code != 3000 {
  146. return nil, types.NewErrorWithStatusCode(
  147. errors.New(volcResp.Message),
  148. types.ErrorCodeBadResponse,
  149. http.StatusBadRequest,
  150. )
  151. }
  152. audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
  153. if decodeErr != nil {
  154. return nil, types.NewErrorWithStatusCode(
  155. errors.New("failed to decode audio data"),
  156. types.ErrorCodeBadResponseBody,
  157. http.StatusInternalServerError,
  158. )
  159. }
  160. contentType := getContentTypeByEncoding(encoding)
  161. c.Header("Content-Type", contentType)
  162. c.Data(http.StatusOK, contentType, audioData)
  163. usage = &dto.Usage{
  164. PromptTokens: info.GetEstimatePromptTokens(),
  165. CompletionTokens: 0,
  166. TotalTokens: info.GetEstimatePromptTokens(),
  167. }
  168. return usage, nil
  169. }
  170. func generateRequestID() string {
  171. return uuid.New().String()
  172. }
  173. func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
  174. _, token, parseErr := parseVolcengineAuth(info.ApiKey)
  175. if parseErr != nil {
  176. return nil, types.NewErrorWithStatusCode(
  177. parseErr,
  178. types.ErrorCodeChannelInvalidKey,
  179. http.StatusUnauthorized,
  180. )
  181. }
  182. header := http.Header{}
  183. header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
  184. conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
  185. if dialErr != nil {
  186. if resp != nil {
  187. return nil, types.NewErrorWithStatusCode(
  188. fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode),
  189. types.ErrorCodeBadResponseStatusCode,
  190. http.StatusBadGateway,
  191. )
  192. }
  193. return nil, types.NewErrorWithStatusCode(
  194. fmt.Errorf("failed to connect to websocket: %w", dialErr),
  195. types.ErrorCodeBadResponseStatusCode,
  196. http.StatusBadGateway,
  197. )
  198. }
  199. defer conn.Close()
  200. payload, marshalErr := json.Marshal(volcRequest)
  201. if marshalErr != nil {
  202. return nil, types.NewErrorWithStatusCode(
  203. fmt.Errorf("failed to marshal request: %w", marshalErr),
  204. types.ErrorCodeBadRequestBody,
  205. http.StatusInternalServerError,
  206. )
  207. }
  208. if sendErr := FullClientRequest(conn, payload); sendErr != nil {
  209. return nil, types.NewErrorWithStatusCode(
  210. fmt.Errorf("failed to send request: %w", sendErr),
  211. types.ErrorCodeBadRequestBody,
  212. http.StatusInternalServerError,
  213. )
  214. }
  215. contentType := getContentTypeByEncoding(encoding)
  216. c.Header("Content-Type", contentType)
  217. c.Header("Transfer-Encoding", "chunked")
  218. for {
  219. msg, recvErr := ReceiveMessage(conn)
  220. if recvErr != nil {
  221. if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  222. break
  223. }
  224. return nil, types.NewErrorWithStatusCode(
  225. fmt.Errorf("failed to receive message: %w", recvErr),
  226. types.ErrorCodeBadResponse,
  227. http.StatusInternalServerError,
  228. )
  229. }
  230. switch msg.MsgType {
  231. case MsgTypeError:
  232. return nil, types.NewErrorWithStatusCode(
  233. fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)),
  234. types.ErrorCodeBadResponse,
  235. http.StatusBadRequest,
  236. )
  237. case MsgTypeFrontEndResultServer:
  238. continue
  239. case MsgTypeAudioOnlyServer:
  240. if len(msg.Payload) > 0 {
  241. if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
  242. return nil, types.NewErrorWithStatusCode(
  243. fmt.Errorf("failed to write audio data: %w", writeErr),
  244. types.ErrorCodeBadResponse,
  245. http.StatusInternalServerError,
  246. )
  247. }
  248. c.Writer.Flush()
  249. }
  250. if msg.Sequence < 0 {
  251. c.Status(http.StatusOK)
  252. usage = &dto.Usage{
  253. PromptTokens: info.GetEstimatePromptTokens(),
  254. CompletionTokens: 0,
  255. TotalTokens: info.GetEstimatePromptTokens(),
  256. }
  257. return usage, nil
  258. }
  259. default:
  260. continue
  261. }
  262. }
  263. c.Status(http.StatusOK)
  264. usage = &dto.Usage{
  265. PromptTokens: info.GetEstimatePromptTokens(),
  266. CompletionTokens: 0,
  267. TotalTokens: info.GetEstimatePromptTokens(),
  268. }
  269. return usage, nil
  270. }