Jelajahi Sumber

feat: implement audio duration retrieval without ffmpeg dependencies

CaIon 4 bulan lalu
induk
melakukan
2b70095b47
10 mengubah file dengan 430 tambahan dan 139 penghapusan
  1. 1 1
      Dockerfile
  2. 295 0
      common/audio.go
  3. 4 4
      common/gin.go
  4. 0 39
      common/utils.go
  5. 13 0
      go.mod
  6. 40 0
      go.sum
  7. 47 30
      middleware/distributor.go
  8. 1 53
      relay/channel/openai/relay-openai.go
  9. 0 10
      relay/helper/valid_request.go
  10. 29 2
      service/token_counter.go

+ 1 - 1
Dockerfile

@@ -28,7 +28,7 @@ RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$
 FROM alpine
 
 RUN apk upgrade --no-cache \
-    && apk add --no-cache ca-certificates tzdata ffmpeg \
+    && apk add --no-cache ca-certificates tzdata \
     && update-ca-certificates
 
 COPY --from=builder2 /build/new-api /

+ 295 - 0
common/audio.go

@@ -0,0 +1,295 @@
+package common
+
+import (
+	"context"
+	"encoding/binary"
+	"fmt"
+	"io"
+
+	"github.com/abema/go-mp4"
+	"github.com/go-audio/aiff"
+	"github.com/go-audio/wav"
+	"github.com/jfreymuth/oggvorbis"
+	"github.com/mewkiz/flac"
+	"github.com/pkg/errors"
+	"github.com/tcolgate/mp3"
+	"github.com/yapingcat/gomedia/go-codec"
+)
+
+// GetAudioDuration 使用纯 Go 库获取音频文件的时长(秒)。
+// 它不再依赖外部的 ffmpeg 或 ffprobe 程序。
+func GetAudioDuration(ctx context.Context, f io.ReadSeeker, ext string) (duration float64, err error) {
+	SysLog(fmt.Sprintf("GetAudioDuration: ext=%s", ext))
+	// 根据文件扩展名选择解析器
+	switch ext {
+	case ".mp3":
+		duration, err = getMP3Duration(f)
+	case ".wav":
+		duration, err = getWAVDuration(f)
+	case ".flac":
+		duration, err = getFLACDuration(f)
+	case ".m4a", ".mp4":
+		duration, err = getM4ADuration(f)
+	case ".ogg", ".oga":
+		duration, err = getOGGDuration(f)
+	case ".opus":
+		duration, err = getOpusDuration(f)
+	case ".aiff", ".aif", ".aifc":
+		duration, err = getAIFFDuration(f)
+	case ".webm":
+		duration, err = getWebMDuration(f)
+	case ".aac":
+		duration, err = getAACDuration(f)
+	default:
+		return 0, fmt.Errorf("unsupported audio format: %s", ext)
+	}
+	SysLog(fmt.Sprintf("GetAudioDuration: duration=%f", duration))
+	return duration, err
+}
+
+// getMP3Duration 解析 MP3 文件以获取时长。
+// 注意:对于 VBR (Variable Bitrate) MP3,这个估算可能不完全精确,但通常足够好。
+// FFmpeg 在这种情况下会扫描整个文件来获得精确值,但这里的库提供了快速估算。
+func getMP3Duration(r io.Reader) (float64, error) {
+	d := mp3.NewDecoder(r)
+	var f mp3.Frame
+	skipped := 0
+	duration := 0.0
+
+	for {
+		if err := d.Decode(&f, &skipped); err != nil {
+			if err == io.EOF {
+				break
+			}
+			return 0, errors.Wrap(err, "failed to decode mp3 frame")
+		}
+		duration += f.Duration().Seconds()
+	}
+	return duration, nil
+}
+
+// getWAVDuration 解析 WAV 文件头以获取时长。
+func getWAVDuration(r io.ReadSeeker) (float64, error) {
+	dec := wav.NewDecoder(r)
+	if !dec.IsValidFile() {
+		return 0, errors.New("invalid wav file")
+	}
+	d, err := dec.Duration()
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to get wav duration")
+	}
+	return d.Seconds(), nil
+}
+
+// getFLACDuration 解析 FLAC 文件的 STREAMINFO 块。
+func getFLACDuration(r io.Reader) (float64, error) {
+	stream, err := flac.Parse(r)
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to parse flac stream")
+	}
+	defer stream.Close()
+
+	// 时长 = 总采样数 / 采样率
+	duration := float64(stream.Info.NSamples) / float64(stream.Info.SampleRate)
+	return duration, nil
+}
+
+// getM4ADuration 解析 M4A/MP4 文件的 'mvhd' box。
+func getM4ADuration(r io.ReadSeeker) (float64, error) {
+	// go-mp4 库需要 ReadSeeker 接口
+	info, err := mp4.Probe(r)
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to probe m4a/mp4 file")
+	}
+	// 时长 = Duration / Timescale
+	return float64(info.Duration) / float64(info.Timescale), nil
+}
+
+// getOGGDuration 解析 OGG/Vorbis 文件以获取时长。
+func getOGGDuration(r io.ReadSeeker) (float64, error) {
+	// 重置 reader 到开头
+	if _, err := r.Seek(0, io.SeekStart); err != nil {
+		return 0, errors.Wrap(err, "failed to seek ogg file")
+	}
+
+	reader, err := oggvorbis.NewReader(r)
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to create ogg vorbis reader")
+	}
+
+	// 计算时长 = 总采样数 / 采样率
+	// 需要读取整个文件来获取总采样数
+	channels := reader.Channels()
+	sampleRate := reader.SampleRate()
+
+	// 估算方法:读取到文件结尾
+	var totalSamples int64
+	buf := make([]float32, 4096*channels)
+	for {
+		n, err := reader.Read(buf)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return 0, errors.Wrap(err, "failed to read ogg samples")
+		}
+		totalSamples += int64(n / channels)
+	}
+
+	duration := float64(totalSamples) / float64(sampleRate)
+	return duration, nil
+}
+
+// getOpusDuration 解析 Opus 文件(在 OGG 容器中)以获取时长。
+func getOpusDuration(r io.ReadSeeker) (float64, error) {
+	// Opus 通常封装在 OGG 容器中
+	// 我们需要解析 OGG 页面来获取时长信息
+	if _, err := r.Seek(0, io.SeekStart); err != nil {
+		return 0, errors.Wrap(err, "failed to seek opus file")
+	}
+
+	// 读取 OGG 页面头部
+	var totalGranulePos int64
+	buf := make([]byte, 27) // OGG 页面头部最小大小
+
+	for {
+		n, err := r.Read(buf)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return 0, errors.Wrap(err, "failed to read opus/ogg page")
+		}
+		if n < 27 {
+			break
+		}
+
+		// 检查 OGG 页面标识 "OggS"
+		if string(buf[0:4]) != "OggS" {
+			// 跳过一些字节继续寻找
+			if _, err := r.Seek(-26, io.SeekCurrent); err != nil {
+				break
+			}
+			continue
+		}
+
+		// 读取 granule position (字节 6-13, 小端序)
+		granulePos := int64(binary.LittleEndian.Uint64(buf[6:14]))
+		if granulePos > totalGranulePos {
+			totalGranulePos = granulePos
+		}
+
+		// 读取段表大小
+		numSegments := int(buf[26])
+		segmentTable := make([]byte, numSegments)
+		if _, err := io.ReadFull(r, segmentTable); err != nil {
+			break
+		}
+
+		// 计算页面数据大小并跳过
+		var pageSize int
+		for _, segSize := range segmentTable {
+			pageSize += int(segSize)
+		}
+		if _, err := r.Seek(int64(pageSize), io.SeekCurrent); err != nil {
+			break
+		}
+	}
+
+	// Opus 的采样率固定为 48000 Hz
+	duration := float64(totalGranulePos) / 48000.0
+	return duration, nil
+}
+
+// getAIFFDuration 解析 AIFF 文件头以获取时长。
+func getAIFFDuration(r io.ReadSeeker) (float64, error) {
+	if _, err := r.Seek(0, io.SeekStart); err != nil {
+		return 0, errors.Wrap(err, "failed to seek aiff file")
+	}
+
+	dec := aiff.NewDecoder(r)
+	if !dec.IsValidFile() {
+		return 0, errors.New("invalid aiff file")
+	}
+
+	d, err := dec.Duration()
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to get aiff duration")
+	}
+
+	return d.Seconds(), nil
+}
+
+// getWebMDuration 解析 WebM 文件以获取时长。
+// WebM 使用 Matroska 容器格式
+func getWebMDuration(r io.ReadSeeker) (float64, error) {
+	if _, err := r.Seek(0, io.SeekStart); err != nil {
+		return 0, errors.Wrap(err, "failed to seek webm file")
+	}
+
+	// WebM/Matroska 文件的解析比较复杂
+	// 这里提供一个简化的实现,读取 EBML 头部
+	// 对于完整的 WebM 解析,可能需要使用专门的库
+
+	// 简单实现:查找 Duration 元素
+	// WebM Duration 的 Element ID 是 0x4489
+	// 这是一个简化版本,可能不适用于所有 WebM 文件
+	buf := make([]byte, 8192)
+	n, err := r.Read(buf)
+	if err != nil && err != io.EOF {
+		return 0, errors.Wrap(err, "failed to read webm file")
+	}
+
+	// 尝试查找 Duration 元素(这是一个简化的方法)
+	// 实际的 WebM 解析需要完整的 EBML 解析器
+	// 这里返回错误,建议使用专门的库
+	if n > 0 {
+		// 检查 EBML 标识
+		if len(buf) >= 4 && binary.BigEndian.Uint32(buf[0:4]) == 0x1A45DFA3 {
+			// 这是一个有效的 EBML 文件
+			// 但完整解析需要更复杂的逻辑
+			return 0, errors.New("webm duration parsing requires full EBML parser (consider using ffprobe for webm files)")
+		}
+	}
+
+	return 0, errors.New("failed to parse webm file")
+}
+
+// getAACDuration 解析 AAC (ADTS格式) 文件以获取时长。
+// 使用 gomedia 库来解析 AAC ADTS 帧
+func getAACDuration(r io.ReadSeeker) (float64, error) {
+	if _, err := r.Seek(0, io.SeekStart); err != nil {
+		return 0, errors.Wrap(err, "failed to seek aac file")
+	}
+
+	// 读取整个文件内容
+	data, err := io.ReadAll(r)
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to read aac file")
+	}
+
+	var totalFrames int64
+	var sampleRate int
+
+	// 使用 gomedia 的 SplitAACFrame 函数来分割 AAC 帧
+	codec.SplitAACFrame(data, func(aac []byte) {
+		// 解析 ADTS 头部以获取采样率信息
+		if len(aac) >= 7 {
+			// 使用 ConvertADTSToASC 来获取音频配置信息
+			asc, err := codec.ConvertADTSToASC(aac)
+			if err == nil && sampleRate == 0 {
+				sampleRate = codec.AACSampleIdxToSample(int(asc.Sample_freq_index))
+			}
+			totalFrames++
+		}
+	})
+
+	if sampleRate == 0 || totalFrames == 0 {
+		return 0, errors.New("no valid aac frames found")
+	}
+
+	// 每个 AAC ADTS 帧包含 1024 个采样
+	totalSamples := totalFrames * 1024
+	duration := float64(totalSamples) / float64(sampleRate)
+	return duration, nil
+}

+ 4 - 4
common/gin.go

@@ -163,7 +163,7 @@ func parseFormData(data []byte, v any) error {
 		return err
 	}
 
-	return json.Unmarshal(jsonData, v)
+	return Unmarshal(jsonData, v)
 }
 
 func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
@@ -174,7 +174,7 @@ func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
 	}
 
 	if boundary == "" {
-		return json.Unmarshal(data, v) // Fallback to JSON
+		return Unmarshal(data, v) // Fallback to JSON
 	}
 
 	reader := multipart.NewReader(bytes.NewReader(data), boundary)
@@ -191,10 +191,10 @@ func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
 			formMap[key] = vals
 		}
 	}
-	jsonData, err := json.Marshal(formMap)
+	jsonData, err := Marshal(formMap)
 	if err != nil {
 		return err
 	}
 
-	return json.Unmarshal(jsonData, v)
+	return Unmarshal(jsonData, v)
 }

+ 0 - 39
common/utils.go

@@ -1,8 +1,6 @@
 package common
 
 import (
-	"bytes"
-	"context"
 	crand "crypto/rand"
 	"encoding/base64"
 	"encoding/json"
@@ -329,43 +327,6 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
 	return f.Name(), nil
 }
 
-// GetAudioDuration returns the duration of an audio file in seconds.
-func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
-	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
-	c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
-	output, err := c.Output()
-	if err != nil {
-		return 0, errors.Wrap(err, "failed to get audio duration")
-	}
-	durationStr := string(bytes.TrimSpace(output))
-	if durationStr == "N/A" {
-		// Create a temporary output file name
-		tmpFp, err := os.CreateTemp("", "audio-*"+ext)
-		if err != nil {
-			return 0, errors.Wrap(err, "failed to create temporary file")
-		}
-		tmpName := tmpFp.Name()
-		// Close immediately so ffmpeg can open the file on Windows.
-		_ = tmpFp.Close()
-		defer os.Remove(tmpName)
-
-		// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
-		ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
-		if err := ffmpegCmd.Run(); err != nil {
-			return 0, errors.Wrap(err, "failed to run ffmpeg")
-		}
-
-		// Recalculate the duration of the new file
-		c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
-		output, err := c.Output()
-		if err != nil {
-			return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
-		}
-		durationStr = string(bytes.TrimSpace(output))
-	}
-	return strconv.ParseFloat(durationStr, 64)
-}
-
 // BuildURL concatenates base and endpoint, returns the complete url string
 func BuildURL(base string, endpoint string) string {
 	u, err := url.Parse(base)

+ 13 - 0
go.mod

@@ -5,6 +5,7 @@ go 1.25.1
 
 require (
 	github.com/Calcium-Ion/go-epay v0.0.4
+	github.com/abema/go-mp4 v1.4.1
 	github.com/andybalholm/brotli v1.1.1
 	github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
 	github.com/aws/aws-sdk-go-v2 v1.37.2
@@ -18,24 +19,30 @@ require (
 	github.com/gin-contrib/static v0.0.1
 	github.com/gin-gonic/gin v1.9.1
 	github.com/glebarez/sqlite v1.9.0
+	github.com/go-audio/aiff v1.1.0
+	github.com/go-audio/wav v1.1.0
 	github.com/go-playground/validator/v10 v10.20.0
 	github.com/go-redis/redis/v8 v8.11.5
 	github.com/go-webauthn/webauthn v0.14.0
 	github.com/golang-jwt/jwt/v5 v5.3.0
 	github.com/google/uuid v1.6.0
 	github.com/gorilla/websocket v1.5.0
+	github.com/jfreymuth/oggvorbis v1.0.5
 	github.com/jinzhu/copier v0.4.0
 	github.com/joho/godotenv v1.5.1
+	github.com/mewkiz/flac v1.0.13
 	github.com/pkg/errors v0.9.1
 	github.com/pquerna/otp v1.5.0
 	github.com/samber/lo v1.39.0
 	github.com/shirou/gopsutil v3.21.11+incompatible
 	github.com/shopspring/decimal v1.4.0
 	github.com/stripe/stripe-go/v81 v81.4.0
+	github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300
 	github.com/thanhpk/randstr v1.0.6
 	github.com/tidwall/gjson v1.18.0
 	github.com/tidwall/sjson v1.2.5
 	github.com/tiktoken-go/tokenizer v0.6.2
+	github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c
 	golang.org/x/crypto v0.42.0
 	golang.org/x/image v0.23.0
 	golang.org/x/net v0.43.0
@@ -62,6 +69,8 @@ require (
 	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
 	github.com/gin-contrib/sse v0.1.0 // indirect
 	github.com/glebarez/go-sqlite v1.21.2 // indirect
+	github.com/go-audio/audio v1.0.0 // indirect
+	github.com/go-audio/riff v1.0.0 // indirect
 	github.com/go-ole/go-ole v1.2.6 // indirect
 	github.com/go-playground/locales v0.14.1 // indirect
 	github.com/go-playground/universal-translator v0.18.1 // indirect
@@ -73,16 +82,20 @@ require (
 	github.com/gorilla/context v1.1.1 // indirect
 	github.com/gorilla/securecookie v1.1.1 // indirect
 	github.com/gorilla/sessions v1.2.1 // indirect
+	github.com/icza/bitio v1.1.0 // indirect
 	github.com/jackc/pgpassfile v1.0.0 // indirect
 	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
 	github.com/jackc/pgx/v5 v5.7.1 // indirect
 	github.com/jackc/puddle/v2 v2.2.2 // indirect
+	github.com/jfreymuth/vorbis v1.0.2 // indirect
 	github.com/jinzhu/inflection v1.0.0 // indirect
 	github.com/jinzhu/now v1.1.5 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/klauspost/cpuid/v2 v2.3.0 // indirect
 	github.com/leodido/go-urn v1.4.0 // indirect
 	github.com/mattn/go-isatty v0.0.20 // indirect
+	github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d // indirect
+	github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985 // indirect
 	github.com/mitchellh/mapstructure v1.5.0 // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect

+ 40 - 0
go.sum

@@ -1,5 +1,7 @@
 github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
 github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
+github.com/abema/go-mp4 v1.4.1 h1:YoS4VRqd+pAmddRPLFf8vMk74kuGl6ULSjzhsIqwr6M=
+github.com/abema/go-mp4 v1.4.1/go.mod h1:vPl9t5ZK7K0x68jh12/+ECWBCXoWuIDtNgPtU2f04ws=
 github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
 github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
@@ -33,6 +35,7 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
 github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
 github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
+github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
 github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -67,6 +70,15 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
 github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
 github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
 github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
+github.com/go-audio/aiff v1.1.0 h1:m2LYgu/2BarpF2yZnFPWtY3Tp41k0A4y51gDRZZsEuU=
+github.com/go-audio/aiff v1.1.0/go.mod h1:sDik1muYvhPiccClfri0fv6U2fyH/dy4VRWmUz0cz9Q=
+github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
+github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
+github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
+github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
+github.com/go-audio/wav v1.0.0/go.mod h1:3yoReyQOsiARkvPl3ERCi8JFjihzG6WhjYpZCf5zAWE=
+github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
+github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
 github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
 github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
 github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -108,6 +120,7 @@ github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
 github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
+github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -118,6 +131,10 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
 github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
 github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0=
+github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A=
+github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k=
+github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA=
 github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
 github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
 github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -126,6 +143,10 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
 github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
 github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
 github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
+github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ=
+github.com/jfreymuth/oggvorbis v1.0.5/go.mod h1:1U4pqWmghcoVsCJJ4fRBKv9peUJMBHixthRlBeD6uII=
+github.com/jfreymuth/vorbis v1.0.2 h1:m1xH6+ZI4thH927pgKD8JOH4eaGRm18rEE9/0WKjvNE=
+github.com/jfreymuth/vorbis v1.0.2/go.mod h1:DoftRo4AznKnShRl1GxiTFCseHr4zR9BN3TWXyuzrqQ=
 github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
 github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
 github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -145,6 +166,7 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn
 github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
 github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
@@ -152,10 +174,17 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
 github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
 github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
 github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
+github.com/mattetti/audio v0.0.0-20180912171649-01576cde1f21/go.mod h1:LlQmBGkOuV/SKzEDXBPKauvN2UqCgzXO2XjecTGj40s=
 github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
 github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mewkiz/flac v1.0.13 h1:6wF8rRQKBFW159Daqx6Ro7K5ZnlVhHUKfS5aTsC4oXs=
+github.com/mewkiz/flac v1.0.13/go.mod h1:HfPYDA+oxjyuqMu2V+cyKcxF51KM6incpw5eZXmfA6k=
+github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d h1:IL2tii4jXLdhCeQN69HNzYYW1kl0meSG0wt5+sLwszU=
+github.com/mewkiz/pkg v0.0.0-20250417130911-3f050ff8c56d/go.mod h1:SIpumAnUWSy0q9RzKD3pyH3g1t5vdawUAPcW5tQrUtI=
+github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985 h1:h8O1byDZ1uk6RUXMhj1QJU3VXFKXHDZxr4TXRPGeBa8=
+github.com/mewpkg/term v0.0.0-20241026122259-37a80af23985/go.mod h1:uiPmbdUbdt1NkGApKl7htQjZ8S7XaGUAVulJUJ9v6q4=
 github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
 github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -170,6 +199,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
 github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
 github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
 github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
+github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw=
+github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0=
 github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
 github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
 github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
@@ -209,6 +240,9 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
 github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
 github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
 github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
+github.com/sunfish-shogi/bufseekio v0.0.0-20210207115823-a4185644b365/go.mod h1:dEzdXgvImkQ3WLI+0KQpmEx8T/C/ma9KeS3AfmU899I=
+github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300 h1:XQdibLKagjdevRB6vAjVY4qbSr8rQ610YzTkWcxzxSI=
+github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300/go.mod h1:FNa/dfN95vAYCNFrIKRrlRo+MBLbwmR9Asa5f2ljmBI=
 github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
 github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
 github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -238,6 +272,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
 github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
 github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
 github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
+github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c h1:xA2TJS9Hu/ivzaZIrDcwvpJ3Fnpsk5fDOJ4iSnL6J0w=
+github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c/go.mod h1:WSZ59bidJOO40JSJmLqlkBJrjZCtjbKKkygEMfzY/kc=
 github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
 github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
 go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
@@ -257,6 +293,7 @@ golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
 golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
 golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
 golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -270,6 +307,7 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
 golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -286,6 +324,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/src-d/go-billy.v4 v4.3.2 h1:0SQA1pRztfTFx2miS8sA97XvooFeNOmvUenF4o0EcVg=
+gopkg.in/src-d/go-billy.v4 v4.3.2/go.mod h1:nDjArDMp+XMs1aFAESLRjfGSgfvoYN0hDfzEk0GjC98=
 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

+ 47 - 30
middleware/distributor.go

@@ -86,7 +86,7 @@ func Distribute() func(c *gin.Context) {
 					playgroundRequest := &dto.PlayGroundRequest{}
 					err = common.UnmarshalBodyReusable(c, playgroundRequest)
 					if err != nil {
-						abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
+						abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的playground请求, "+err.Error())
 						return
 					}
 					if playgroundRequest.Group != "" {
@@ -124,6 +124,20 @@ func Distribute() func(c *gin.Context) {
 	}
 }
 
+// getModelFromRequest 从请求中读取模型信息
+// 根据 Content-Type 自动处理:
+// - application/json
+// - application/x-www-form-urlencoded
+// - multipart/form-data
+func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
+	var modelRequest ModelRequest
+	err := common.UnmarshalBodyReusable(c, &modelRequest)
+	if err != nil {
+		return nil, errors.New("无效的请求, " + err.Error())
+	}
+	return &modelRequest, nil
+}
+
 func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 	var modelRequest ModelRequest
 	shouldSelectChannel := true
@@ -139,7 +153,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 			midjourneyRequest := dto.MidjourneyRequest{}
 			err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
 			if err != nil {
-				return nil, false, err
+				return nil, false, errors.New("无效的midjourney请求, " + err.Error())
 			}
 			midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
 			if mjErr != nil {
@@ -176,23 +190,12 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		relayMode := relayconstant.RelayModeUnknown
 		if c.Request.Method == http.MethodPost {
 			relayMode = relayconstant.RelayModeVideoSubmit
-			contentType := c.Request.Header.Get("Content-Type")
-			if strings.HasPrefix(contentType, "multipart/form-data") {
-				form, err := common.ParseMultipartFormReusable(c)
-				if err != nil {
-					return nil, false, errors.New("无效的video请求, " + err.Error())
-				}
-				defer form.RemoveAll()
-				if form != nil {
-					if values, ok := form.Value["model"]; ok && len(values) > 0 {
-						modelRequest.Model = values[0]
-					}
-				}
-			} else if strings.HasPrefix(contentType, "application/json") {
-				err = common.UnmarshalBodyReusable(c, &modelRequest)
-				if err != nil {
-					return nil, false, errors.New("无效的video请求, " + err.Error())
-				}
+			req, err := getModelFromRequest(c)
+			if err != nil {
+				return nil, false, err
+			}
+			if req != nil {
+				modelRequest.Model = req.Model
 			}
 		} else if c.Request.Method == http.MethodGet {
 			relayMode = relayconstant.RelayModeVideoFetchByID
@@ -202,10 +205,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 	} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
 		relayMode := relayconstant.RelayModeUnknown
 		if c.Request.Method == http.MethodPost {
-			err = common.UnmarshalBodyReusable(c, &modelRequest)
+			req, err := getModelFromRequest(c)
 			if err != nil {
-				return nil, false, errors.New("video无效的请求, " + err.Error())
+				return nil, false, err
 			}
+			modelRequest.Model = req.Model
 			relayMode = relayconstant.RelayModeVideoSubmit
 		} else if c.Request.Method == http.MethodGet {
 			relayMode = relayconstant.RelayModeVideoFetchByID
@@ -223,10 +227,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("relay_mode", relayMode)
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
-		err = common.UnmarshalBodyReusable(c, &modelRequest)
-	}
-	if err != nil {
-		return nil, false, errors.New("无效的请求, " + err.Error())
+		req, err := getModelFromRequest(c)
+		if err != nil {
+			return nil, false, err
+		}
+		modelRequest.Model = req.Model
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
 		//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
@@ -248,19 +253,29 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
 		contentType := c.ContentType()
 		if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) {
-			modelRequest.Model = c.PostForm("model")
+			req, err := getModelFromRequest(c)
+			if err == nil && req.Model != "" {
+				modelRequest.Model = req.Model
+			}
 		}
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 		relayMode := relayconstant.RelayModeAudioSpeech
 		if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+
 			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
 		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
+			// 先尝试从请求读取
+			if req, err := getModelFromRequest(c); err == nil && req.Model != "" {
+				modelRequest.Model = req.Model
+			}
 			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
 			relayMode = relayconstant.RelayModeAudioTranslation
 		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
-			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
+			// 先尝试从请求读取
+			if req, err := getModelFromRequest(c); err == nil && req.Model != "" {
+				modelRequest.Model = req.Model
+			}
 			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
 			relayMode = relayconstant.RelayModeAudioTranscription
 		}
@@ -268,10 +283,12 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
 		// playground chat completions
-		err = common.UnmarshalBodyReusable(c, &modelRequest)
+		req, err := getModelFromRequest(c)
 		if err != nil {
-			return nil, false, errors.New("无效的请求, " + err.Error())
+			return nil, false, err
 		}
+		modelRequest.Model = req.Model
+		modelRequest.Group = req.Group
 		common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
 	}
 	return &modelRequest, shouldSelectChannel, nil

+ 1 - 53
relay/channel/openai/relay-openai.go

@@ -1,15 +1,10 @@
 package openai
 
 import (
-	"bytes"
 	"encoding/json"
 	"fmt"
 	"io"
-	"math"
-	"mime/multipart"
 	"net/http"
-	"os"
-	"path/filepath"
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
@@ -26,7 +21,6 @@ import (
 	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/gin-gonic/gin"
 	"github.com/gorilla/websocket"
-	"github.com/pkg/errors"
 )
 
 func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
@@ -361,59 +355,13 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		}
 	}
 
-	audioTokens, err := countAudioTokens(c)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
-	}
 	usage := &dto.Usage{}
-	usage.PromptTokens = audioTokens
+	usage.PromptTokens = info.PromptTokens
 	usage.CompletionTokens = 0
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return nil, usage
 }
 
-func countAudioTokens(c *gin.Context) (int, error) {
-	body, err := common.GetRequestBody(c)
-	if err != nil {
-		return 0, errors.WithStack(err)
-	}
-
-	var reqBody struct {
-		File *multipart.FileHeader `form:"file" binding:"required"`
-	}
-	c.Request.Body = io.NopCloser(bytes.NewReader(body))
-	if err = c.ShouldBind(&reqBody); err != nil {
-		return 0, errors.WithStack(err)
-	}
-	ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
-	reqFp, err := reqBody.File.Open()
-	if err != nil {
-		return 0, errors.WithStack(err)
-	}
-	defer reqFp.Close()
-
-	tmpFp, err := os.CreateTemp("", "audio-*"+ext)
-	if err != nil {
-		return 0, errors.WithStack(err)
-	}
-	defer os.Remove(tmpFp.Name())
-
-	_, err = io.Copy(tmpFp, reqFp)
-	if err != nil {
-		return 0, errors.WithStack(err)
-	}
-	if err = tmpFp.Close(); err != nil {
-		return 0, errors.WithStack(err)
-	}
-
-	duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
-	if err != nil {
-		return 0, errors.WithStack(err)
-	}
-
-	return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
-}
-
 func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
 	if info == nil || info.ClientWs == nil || info.TargetWs == nil {
 		return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil

+ 0 - 10
relay/helper/valid_request.go

@@ -62,19 +62,9 @@ func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest,
 			return nil, errors.New("model is required")
 		}
 	default:
-		err = c.Request.ParseForm()
-		if err != nil {
-			return nil, err
-		}
-		formData := c.Request.PostForm
-		if audioRequest.Model == "" {
-			audioRequest.Model = formData.Get("model")
-		}
-
 		if audioRequest.Model == "" {
 			return nil, errors.New("model is required")
 		}
-		audioRequest.ResponseFormat = formData.Get("response_format")
 		if audioRequest.ResponseFormat == "" {
 			audioRequest.ResponseFormat = "json"
 		}

+ 29 - 2
service/token_counter.go

@@ -10,6 +10,7 @@ import (
 	_ "image/png"
 	"log"
 	"math"
+	"path/filepath"
 	"strings"
 	"sync"
 	"unicode/utf8"
@@ -18,6 +19,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	constant2 "github.com/QuantumNous/new-api/relay/constant"
 	"github.com/QuantumNous/new-api/types"
 
 	"github.com/gin-gonic/gin"
@@ -254,6 +256,10 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
 }
 
 func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
+	if meta == nil {
+		return 0, errors.New("token count meta is nil")
+	}
+
 	if !constant.GetMediaToken {
 		return 0, nil
 	}
@@ -263,8 +269,29 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 	if info.RelayFormat == types.RelayFormatOpenAIRealtime {
 		return 0, nil
 	}
-	if meta == nil {
-		return 0, errors.New("token count meta is nil")
+	if info.RelayMode == constant2.RelayModeAudioTranscription || info.RelayMode == constant2.RelayModeAudioTranslation {
+		multiForm, err := common.ParseMultipartFormReusable(c)
+		if err != nil {
+			return 0, fmt.Errorf("error parsing multipart form: %v", err)
+		}
+		fileHeaders := multiForm.File["file"]
+		totalAudioToken := 0
+		for _, fileHeader := range fileHeaders {
+			file, err := fileHeader.Open()
+			if err != nil {
+				return 0, fmt.Errorf("error opening audio file: %v", err)
+			}
+			defer file.Close()
+			// get ext and io.seeker
+			ext := filepath.Ext(fileHeader.Filename)
+			duration, err := common.GetAudioDuration(c.Request.Context(), file, ext)
+			if err != nil {
+				return 0, fmt.Errorf("error getting audio duration: %v", err)
+			}
+			// 一分钟 1000 token,与 $price / minute 对齐
+			totalAudioToken += int(math.Round(math.Ceil(duration) / 60.0 * 1000))
+		}
+		return totalAudioToken, nil
 	}
 
 	model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)