| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
- package volcengine
- import (
- "context"
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strings"
- "github.com/QuantumNous/new-api/dto"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- "github.com/google/uuid"
- "github.com/gorilla/websocket"
- )
// Request payload types for the Volcengine TTS API. A VolcengineTTSRequest is
// serialized with encoding/json, so the field order and json tags below define
// the exact wire format; fields tagged omitempty are dropped when zero/nil.
type (
	// VolcengineTTSRequest is the top-level request body. It nests the
	// application credentials, the end-user identity, the audio output
	// settings and the per-request text and options.
	VolcengineTTSRequest struct {
		App     VolcengineTTSApp     `json:"app"`
		User    VolcengineTTSUser    `json:"user"`
		Audio   VolcengineTTSAudio   `json:"audio"`
		Request VolcengineTTSReqInfo `json:"request"`
	}

	// VolcengineTTSApp identifies the calling application: its app id,
	// access token and target cluster.
	VolcengineTTSApp struct {
		AppID   string `json:"appid"`
		Token   string `json:"token"`
		Cluster string `json:"cluster"`
	}

	// VolcengineTTSUser carries the end-user identifier sent upstream.
	VolcengineTTSUser struct {
		UID string `json:"uid"`
	}

	// VolcengineTTSAudio describes the requested voice and output audio.
	// VoiceType, Encoding, SpeedRatio and Rate are always serialized; the
	// remaining knobs are omitted when left at their zero value.
	VolcengineTTSAudio struct {
		VoiceType        string  `json:"voice_type"`
		Encoding         string  `json:"encoding"`
		SpeedRatio       float64 `json:"speed_ratio"`
		Rate             int     `json:"rate"`
		Bitrate          int     `json:"bitrate,omitempty"`
		LoudnessRatio    float64 `json:"loudness_ratio,omitempty"`
		EnableEmotion    bool    `json:"enable_emotion,omitempty"`
		Emotion          string  `json:"emotion,omitempty"`
		EmotionScale     float64 `json:"emotion_scale,omitempty"`
		ExplicitLanguage string  `json:"explicit_language,omitempty"`
		ContextLanguage  string  `json:"context_language,omitempty"`
	}

	// VolcengineTTSReqInfo holds the text to synthesize and per-request
	// options. Operation is overwritten with "submit" on the WebSocket path.
	// WithTimestamp is left untyped so callers can forward whatever JSON
	// value they received (NOTE(review): confirm the upstream's expected
	// type). ExtraParam is omitted entirely when nil.
	VolcengineTTSReqInfo struct {
		ReqID           string                   `json:"reqid"`
		Text            string                   `json:"text"`
		Operation       string                   `json:"operation"`
		Model           string                   `json:"model,omitempty"`
		TextType        string                   `json:"text_type,omitempty"`
		SilenceDuration float64                  `json:"silence_duration,omitempty"`
		WithTimestamp   any                      `json:"with_timestamp,omitempty"`
		ExtraParam      *VolcengineTTSExtraParam `json:"extra_param,omitempty"`
	}

	// VolcengineTTSExtraParam groups optional text-processing switches; every
	// field is omitted from the JSON when unset.
	VolcengineTTSExtraParam struct {
		DisableMarkdownFilter      bool                      `json:"disable_markdown_filter,omitempty"`
		EnableLatexTn              bool                      `json:"enable_latex_tn,omitempty"`
		MuteCutThreshold           string                    `json:"mute_cut_threshold,omitempty"`
		MuteCutRemainMs            string                    `json:"mute_cut_remain_ms,omitempty"`
		DisableEmojiFilter         bool                      `json:"disable_emoji_filter,omitempty"`
		UnsupportedCharRatioThresh float64                   `json:"unsupported_char_ratio_thresh,omitempty"`
		AigcWatermark              bool                      `json:"aigc_watermark,omitempty"`
		CacheConfig                *VolcengineTTSCacheConfig `json:"cache_config,omitempty"`
	}

	// VolcengineTTSCacheConfig configures upstream result caching; both
	// fields are omitted from the JSON when zero.
	VolcengineTTSCacheConfig struct {
		TextType int  `json:"text_type,omitempty"`
		UseCache bool `json:"use_cache,omitempty"`
	}
)
// Response types for the Volcengine TTS HTTP API, decoded with encoding/json.
type (
	// VolcengineTTSResponse is the JSON body returned by the upstream.
	// handleTTSResponse treats Code 3000 as success and base64-decodes Data
	// into the audio bytes. Addition is omitted from the JSON when nil.
	VolcengineTTSResponse struct {
		ReqID    string                     `json:"reqid"`
		Code     int                        `json:"code"`
		Message  string                     `json:"message"`
		Sequence int                        `json:"sequence"`
		Data     string                     `json:"data"`
		Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
	}

	// VolcengineTTSAdditionInfo carries supplementary response metadata.
	// Duration arrives as a string (NOTE(review): unit not visible here).
	VolcengineTTSAdditionInfo struct {
		Duration string `json:"duration"`
	}
)
// Lookup tables translating OpenAI-style speech parameters into the values
// the Volcengine TTS API expects.
var (
	// openAIToVolcengineVoiceMap maps the six OpenAI voice names onto
	// Volcengine voice_type identifiers. Names not listed here are passed
	// through unchanged by mapVoiceType.
	openAIToVolcengineVoiceMap = map[string]string{
		"alloy":   "zh_male_M392_conversation_wvae_bigtts",
		"echo":    "zh_male_wenhao_mars_bigtts",
		"fable":   "zh_female_tianmei_mars_bigtts",
		"nova":    "zh_female_shuangkuaisisi_mars_bigtts",
		"onyx":    "zh_male_zhibei_mars_bigtts",
		"shimmer": "zh_female_cancan_mars_bigtts",
	}

	// responseFormatToEncodingMap maps an OpenAI response_format onto a
	// Volcengine audio encoding. "aac" and "flac" have no direct entry of
	// their own and are served as mp3; unlisted formats fall back to mp3 in
	// mapEncoding.
	responseFormatToEncodingMap = map[string]string{
		"aac":  "mp3",
		"flac": "mp3",
		"mp3":  "mp3",
		"opus": "ogg_opus",
		"pcm":  "pcm",
		"wav":  "wav",
	}
)
// parseVolcengineAuth splits a channel key of the form "appid|access_token"
// into its app id and access token. Whitespace around either part (e.g. a
// trailing newline from a pasted key) is ignored.
//
// It returns an error when the '|' separator is missing, when the key has
// more than one '|', or when either part is empty after trimming. An empty
// part would otherwise produce an unauthenticated upstream request.
func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
	const formatMsg = "invalid api key format, expected: appid|access_token"

	id, tok, found := strings.Cut(apiKey, "|")
	if !found || strings.Contains(tok, "|") {
		return "", "", errors.New(formatMsg)
	}

	id = strings.TrimSpace(id)
	tok = strings.TrimSpace(tok)
	if id == "" || tok == "" {
		return "", "", errors.New(formatMsg)
	}
	return id, tok, nil
}
- func mapVoiceType(openAIVoice string) string {
- if voice, ok := openAIToVolcengineVoiceMap[openAIVoice]; ok {
- return voice
- }
- return openAIVoice
- }
- func mapEncoding(responseFormat string) string {
- if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
- return encoding
- }
- return "mp3"
- }
// getContentTypeByEncoding returns the HTTP Content-Type to send for a
// Volcengine audio encoding name, or "application/octet-stream" when the
// encoding is not recognized.
//
// A switch is used rather than a map literal so that no map is allocated and
// populated on every call.
func getContentTypeByEncoding(encoding string) string {
	switch encoding {
	case "mp3":
		return "audio/mpeg"
	case "ogg_opus":
		return "audio/ogg"
	case "wav":
		return "audio/wav"
	case "pcm":
		return "audio/pcm"
	default:
		return "application/octet-stream"
	}
}
- func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
- body, readErr := io.ReadAll(resp.Body)
- if readErr != nil {
- return nil, types.NewErrorWithStatusCode(
- errors.New("failed to read volcengine response"),
- types.ErrorCodeReadResponseBodyFailed,
- http.StatusInternalServerError,
- )
- }
- defer resp.Body.Close()
- var volcResp VolcengineTTSResponse
- if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
- return nil, types.NewErrorWithStatusCode(
- errors.New("failed to parse volcengine response"),
- types.ErrorCodeBadResponseBody,
- http.StatusInternalServerError,
- )
- }
- if volcResp.Code != 3000 {
- return nil, types.NewErrorWithStatusCode(
- errors.New(volcResp.Message),
- types.ErrorCodeBadResponse,
- http.StatusBadRequest,
- )
- }
- audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
- if decodeErr != nil {
- return nil, types.NewErrorWithStatusCode(
- errors.New("failed to decode audio data"),
- types.ErrorCodeBadResponseBody,
- http.StatusInternalServerError,
- )
- }
- contentType := getContentTypeByEncoding(encoding)
- c.Header("Content-Type", contentType)
- c.Data(http.StatusOK, contentType, audioData)
- usage = &dto.Usage{
- PromptTokens: info.PromptTokens,
- CompletionTokens: 0,
- TotalTokens: info.PromptTokens,
- }
- return usage, nil
- }
- func generateRequestID() string {
- return uuid.New().String()
- }
- // handleTTSWebSocketResponse handles streaming TTS response via WebSocket
- func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
- // Parse API key for auth
- _, token, parseErr := parseVolcengineAuth(info.ApiKey)
- if parseErr != nil {
- return nil, types.NewErrorWithStatusCode(
- parseErr,
- types.ErrorCodeChannelInvalidKey,
- http.StatusUnauthorized,
- )
- }
- // Setup WebSocket headers
- header := http.Header{}
- header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
- // Dial WebSocket connection
- conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
- if dialErr != nil {
- if resp != nil {
- return nil, types.NewErrorWithStatusCode(
- fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode),
- types.ErrorCodeBadResponseStatusCode,
- http.StatusBadGateway,
- )
- }
- return nil, types.NewErrorWithStatusCode(
- fmt.Errorf("failed to connect to websocket: %w", dialErr),
- types.ErrorCodeBadResponseStatusCode,
- http.StatusBadGateway,
- )
- }
- defer conn.Close()
- // Update request operation to "submit" for WebSocket
- volcRequest.Request.Operation = "submit"
- // Marshal request payload
- payload, marshalErr := json.Marshal(volcRequest)
- if marshalErr != nil {
- return nil, types.NewErrorWithStatusCode(
- fmt.Errorf("failed to marshal request: %w", marshalErr),
- types.ErrorCodeBadRequestBody,
- http.StatusInternalServerError,
- )
- }
- // Send full client request
- if sendErr := FullClientRequest(conn, payload); sendErr != nil {
- return nil, types.NewErrorWithStatusCode(
- fmt.Errorf("failed to send request: %w", sendErr),
- types.ErrorCodeBadRequestBody,
- http.StatusInternalServerError,
- )
- }
- // Set response headers
- contentType := getContentTypeByEncoding(encoding)
- c.Header("Content-Type", contentType)
- c.Header("Transfer-Encoding", "chunked")
- // Stream audio data
- var audioBuffer []byte
- for {
- msg, recvErr := ReceiveMessage(conn)
- if recvErr != nil {
- if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
- break
- }
- return nil, types.NewErrorWithStatusCode(
- fmt.Errorf("failed to receive message: %w", recvErr),
- types.ErrorCodeBadResponse,
- http.StatusInternalServerError,
- )
- }
- switch msg.MsgType {
- case MsgTypeError:
- return nil, types.NewErrorWithStatusCode(
- fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)),
- types.ErrorCodeBadResponse,
- http.StatusBadRequest,
- )
- case MsgTypeFrontEndResultServer:
- // Metadata response, can be logged or processed
- continue
- case MsgTypeAudioOnlyServer:
- // Stream audio chunk to client
- if len(msg.Payload) > 0 {
- audioBuffer = append(audioBuffer, msg.Payload...)
- if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
- return nil, types.NewErrorWithStatusCode(
- fmt.Errorf("failed to write audio data: %w", writeErr),
- types.ErrorCodeBadResponse,
- http.StatusInternalServerError,
- )
- }
- c.Writer.Flush()
- }
- // Check if this is the last packet (negative sequence)
- if msg.Sequence < 0 {
- c.Status(http.StatusOK)
- usage = &dto.Usage{
- PromptTokens: info.PromptTokens,
- CompletionTokens: 0,
- TotalTokens: info.PromptTokens,
- }
- return usage, nil
- }
- default:
- // Unknown message type, log and continue
- continue
- }
- }
- // If we reach here, connection closed without final packet
- c.Status(http.StatusOK)
- usage = &dto.Usage{
- PromptTokens: info.PromptTokens,
- CompletionTokens: 0,
- TotalTokens: info.PromptTokens,
- }
- return usage, nil
- }
|