tts.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package volcengine
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "github.com/QuantumNous/new-api/dto"
  10. relaycommon "github.com/QuantumNous/new-api/relay/common"
  11. "github.com/QuantumNous/new-api/types"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/uuid"
  14. )
  15. type VolcengineTTSRequest struct {
  16. App VolcengineTTSApp `json:"app"`
  17. User VolcengineTTSUser `json:"user"`
  18. Audio VolcengineTTSAudio `json:"audio"`
  19. Request VolcengineTTSReqInfo `json:"request"`
  20. }
  21. type VolcengineTTSApp struct {
  22. AppID string `json:"appid"`
  23. Token string `json:"token"`
  24. Cluster string `json:"cluster"`
  25. }
  26. type VolcengineTTSUser struct {
  27. UID string `json:"uid"`
  28. }
  29. type VolcengineTTSAudio struct {
  30. VoiceType string `json:"voice_type"`
  31. Encoding string `json:"encoding"`
  32. SpeedRatio float64 `json:"speed_ratio"`
  33. Rate int `json:"rate"`
  34. Bitrate int `json:"bitrate,omitempty"`
  35. LoudnessRatio float64 `json:"loudness_ratio,omitempty"`
  36. EnableEmotion bool `json:"enable_emotion,omitempty"`
  37. Emotion string `json:"emotion,omitempty"`
  38. EmotionScale float64 `json:"emotion_scale,omitempty"`
  39. ExplicitLanguage string `json:"explicit_language,omitempty"`
  40. ContextLanguage string `json:"context_language,omitempty"`
  41. }
  42. type VolcengineTTSReqInfo struct {
  43. ReqID string `json:"reqid"`
  44. Text string `json:"text"`
  45. Operation string `json:"operation"`
  46. Model string `json:"model,omitempty"`
  47. TextType string `json:"text_type,omitempty"`
  48. SilenceDuration float64 `json:"silence_duration,omitempty"`
  49. WithTimestamp interface{} `json:"with_timestamp,omitempty"`
  50. ExtraParam *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
  51. }
  52. type VolcengineTTSExtraParam struct {
  53. DisableMarkdownFilter bool `json:"disable_markdown_filter,omitempty"`
  54. EnableLatexTn bool `json:"enable_latex_tn,omitempty"`
  55. MuteCutThreshold string `json:"mute_cut_threshold,omitempty"`
  56. MuteCutRemainMs string `json:"mute_cut_remain_ms,omitempty"`
  57. DisableEmojiFilter bool `json:"disable_emoji_filter,omitempty"`
  58. UnsupportedCharRatioThresh float64 `json:"unsupported_char_ratio_thresh,omitempty"`
  59. AigcWatermark bool `json:"aigc_watermark,omitempty"`
  60. CacheConfig *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
  61. }
  62. type VolcengineTTSCacheConfig struct {
  63. TextType int `json:"text_type,omitempty"`
  64. UseCache bool `json:"use_cache,omitempty"`
  65. }
  66. type VolcengineTTSResponse struct {
  67. ReqID string `json:"reqid"`
  68. Code int `json:"code"`
  69. Message string `json:"message"`
  70. Sequence int `json:"sequence"`
  71. Data string `json:"data"`
  72. Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
  73. }
  74. type VolcengineTTSAdditionInfo struct {
  75. Duration string `json:"duration"`
  76. }
  77. var openAIToVolcengineVoiceMap = map[string]string{
  78. "alloy": "zh_male_M392_conversation_wvae_bigtts",
  79. "echo": "zh_male_wenhao_mars_bigtts",
  80. "fable": "zh_female_tianmei_mars_bigtts",
  81. "onyx": "zh_male_zhibei_mars_bigtts",
  82. "nova": "zh_female_shuangkuaisisi_mars_bigtts",
  83. "shimmer": "zh_female_cancan_mars_bigtts",
  84. }
  85. var responseFormatToEncodingMap = map[string]string{
  86. "mp3": "mp3",
  87. "opus": "ogg_opus",
  88. "aac": "mp3",
  89. "flac": "mp3",
  90. "wav": "wav",
  91. "pcm": "pcm",
  92. }
  93. func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
  94. parts := strings.Split(apiKey, "|")
  95. if len(parts) != 2 {
  96. return "", "", errors.New("invalid api key format, expected: appid:access_token")
  97. }
  98. return parts[0], parts[1], nil
  99. }
  100. func mapVoiceType(openAIVoice string) string {
  101. if voice, ok := openAIToVolcengineVoiceMap[openAIVoice]; ok {
  102. return voice
  103. }
  104. return openAIVoice
  105. }
  106. // [0.1,2],默认为 1,通常保留一位小数即可
  107. func mapSpeedRatio(speed float64) float64 {
  108. if speed == 0 {
  109. return 1.0
  110. }
  111. if speed < 0.1 {
  112. return 0.1
  113. }
  114. if speed > 2.0 {
  115. return 2.0
  116. }
  117. return speed
  118. }
  119. func mapEncoding(responseFormat string) string {
  120. if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
  121. return encoding
  122. }
  123. return "mp3"
  124. }
  125. func getContentTypeByEncoding(encoding string) string {
  126. contentTypeMap := map[string]string{
  127. "mp3": "audio/mpeg",
  128. "ogg_opus": "audio/ogg",
  129. "wav": "audio/wav",
  130. "pcm": "audio/pcm",
  131. }
  132. if ct, ok := contentTypeMap[encoding]; ok {
  133. return ct
  134. }
  135. return "application/octet-stream"
  136. }
  137. func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
  138. body, readErr := io.ReadAll(resp.Body)
  139. if readErr != nil {
  140. return nil, types.NewErrorWithStatusCode(
  141. errors.New("failed to read volcengine response"),
  142. types.ErrorCodeReadResponseBodyFailed,
  143. http.StatusInternalServerError,
  144. )
  145. }
  146. defer resp.Body.Close()
  147. var volcResp VolcengineTTSResponse
  148. if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
  149. return nil, types.NewErrorWithStatusCode(
  150. errors.New("failed to parse volcengine response"),
  151. types.ErrorCodeBadResponseBody,
  152. http.StatusInternalServerError,
  153. )
  154. }
  155. if volcResp.Code != 3000 {
  156. return nil, types.NewErrorWithStatusCode(
  157. errors.New(volcResp.Message),
  158. types.ErrorCodeBadResponse,
  159. http.StatusBadRequest,
  160. )
  161. }
  162. audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
  163. if decodeErr != nil {
  164. return nil, types.NewErrorWithStatusCode(
  165. errors.New("failed to decode audio data"),
  166. types.ErrorCodeBadResponseBody,
  167. http.StatusInternalServerError,
  168. )
  169. }
  170. contentType := getContentTypeByEncoding(encoding)
  171. c.Header("Content-Type", contentType)
  172. c.Data(http.StatusOK, contentType, audioData)
  173. usage = &dto.Usage{
  174. PromptTokens: info.PromptTokens,
  175. CompletionTokens: 0,
  176. TotalTokens: info.PromptTokens,
  177. }
  178. return usage, nil
  179. }
  180. func generateRequestID() string {
  181. return uuid.New().String()
  182. }