| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- package relay
- import (
- "encoding/json"
- "errors"
- "fmt"
- "github.com/gin-gonic/gin"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/service"
- )
- func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
- audioRequest := &dto.AudioRequest{}
- err := common.UnmarshalBodyReusable(c, audioRequest)
- if err != nil {
- return nil, err
- }
- switch info.RelayMode {
- case relayconstant.RelayModeAudioSpeech:
- if audioRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- if constant.ShouldCheckPromptSensitive() {
- err := service.CheckSensitiveInput(audioRequest.Input)
- if err != nil {
- return nil, err
- }
- }
- default:
- if audioRequest.Model == "" {
- audioRequest.Model = c.PostForm("model")
- }
- if audioRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- if audioRequest.ResponseFormat == "" {
- audioRequest.ResponseFormat = "json"
- }
- }
- return audioRequest, nil
- }
- func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
- relayInfo := relaycommon.GenRelayInfo(c)
- audioRequest, err := getAndValidAudioRequest(c, relayInfo)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
- return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
- }
- promptTokens := 0
- preConsumedTokens := common.PreConsumedQuota
- if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
- promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
- }
- preConsumedTokens = promptTokens
- relayInfo.PromptTokens = promptTokens
- }
- modelRatio := common.GetModelRatio(audioRequest.Model)
- groupRatio := common.GetGroupRatio(relayInfo.Group)
- ratio := modelRatio * groupRatio
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
- userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota-preConsumedQuota < 0 {
- return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
- }
- err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota > 100*preConsumedQuota {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- }
- if preConsumedQuota > 0 {
- userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
- }
- }
- // map model name
- modelMapping := c.GetString("model_mapping")
- if modelMapping != "" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[audioRequest.Model] != "" {
- audioRequest.Model = modelMap[audioRequest.Model]
- }
- }
- relayInfo.UpstreamModelName = audioRequest.Model
- adaptor := GetAdaptor(relayInfo.ApiType)
- if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
- }
- adaptor.Init(relayInfo)
- ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
- }
- resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
- statusCodeMappingStr := c.GetString("status_code_mapping")
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- openaiErr := service.RelayErrorHandler(httpResp)
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- }
- usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
- return nil
- }
|