Просмотр исходного кода

Merge branch 'Calcium-Ion:main' into main

H. 1 год назад
Родитель
Commit
0b3a00640e

+ 1 - 1
.github/workflows/docker-image-amd64.yml

@@ -12,7 +12,7 @@ on:
 jobs:
   push_to_registries:
     name: Push Docker image to multiple registries
-    runs-on: self-hosted
+    runs-on: ubuntu-latest
     permissions:
       packages: write
       contents: read

+ 1 - 1
.github/workflows/linux-release.yml

@@ -9,7 +9,7 @@ on:
       - '!*-alpha*'
 jobs:
   release:
-    runs-on: self-hosted
+    runs-on: ubuntu-latest
     steps:
       - name: Checkout
         uses: actions/checkout@v3

+ 1 - 1
.github/workflows/macos-release.yml

@@ -9,7 +9,7 @@ on:
       - '!*-alpha*'
 jobs:
   release:
-    runs-on: self-hosted
+    runs-on: ubuntu-latest
     steps:
       - name: Checkout
         uses: actions/checkout@v3

+ 1 - 1
.github/workflows/windows-release.yml

@@ -9,7 +9,7 @@ on:
       - '!*-alpha*'
 jobs:
   release:
-    runs-on: self-hosted
+    runs-on: ubuntu-latest
     defaults:
       run:
         shell: bash

+ 1 - 1
Dockerfile

@@ -24,7 +24,7 @@ FROM alpine
 
 RUN apk update \
     && apk upgrade \
-    && apk add --no-cache ca-certificates tzdata \
+    && apk add --no-cache ca-certificates tzdata ffmpeg\
     && update-ca-certificates 2>/dev/null || true
 
 COPY --from=builder2 /build/one-api /

+ 4 - 1
common/model-ratio.go

@@ -401,10 +401,13 @@ func GetCompletionRatio(name string) float64 {
 		case "command-r-plus-08-2024":
 			return 4
 		default:
-			return 2
+			return 4
 		}
 	}
 	if strings.HasPrefix(name, "deepseek") {
+		if name == "deepseek-reasoner" {
+			return 4
+		}
 		return 2
 	}
 	if strings.HasPrefix(name, "ERNIE-Speed-") {

+ 33 - 0
common/utils.go

@@ -1,14 +1,19 @@
 package common
 
 import (
+	"bytes"
+	"context"
 	crand "crypto/rand"
 	"encoding/base64"
 	"fmt"
+	"github.com/pkg/errors"
 	"html/template"
+	"io"
 	"log"
 	"math/big"
 	"math/rand"
 	"net"
+	"os"
 	"os/exec"
 	"runtime"
 	"strconv"
@@ -207,3 +212,31 @@ func RandomSleep() {
 	// Sleep for 0-3000 ms
 	time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
 }
+
+// SaveTmpFile saves data to a temporary file. A random string is appended to the filename.
+func SaveTmpFile(filename string, data io.Reader) (string, error) {
+	f, err := os.CreateTemp(os.TempDir(), filename)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
+	}
+	defer f.Close()
+
+	_, err = io.Copy(f, data)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
+	}
+
+	return f.Name(), nil
+}
+
+// GetAudioDuration returns the duration of an audio file in seconds.
+func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
+	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
+	c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
+	output, err := c.Output()
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to get audio duration")
+	}
+
+	return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
+}

+ 9 - 0
model/log.go

@@ -27,6 +27,7 @@ type Log struct {
 	UseTime          int    `json:"use_time" gorm:"default:0"`
 	IsStream         bool   `json:"is_stream" gorm:"default:false"`
 	ChannelId        int    `json:"channel" gorm:"index"`
+	ChannelName      string `json:"channel_name" gorm:"->"`
 	TokenId          int    `json:"token_id" gorm:"default:0;index"`
 	Group            string `json:"group" gorm:"index"`
 	Other            string `json:"other"`
@@ -130,6 +131,10 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 	} else {
 		tx = LOG_DB.Where("type = ?", logType)
 	}
+
+	tx = tx.Joins("LEFT JOIN channels ON logs.channel_id = channels.id")
+	tx = tx.Select("logs.*, channels.name as channel_name")
+
 	if modelName != "" {
 		tx = tx.Where("model_name like ?", modelName)
 	}
@@ -169,6 +174,10 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
 	} else {
 		tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
 	}
+
+	tx = tx.Joins("LEFT JOIN channels ON logs.channel_id = channels.id")
+	tx = tx.Select("logs.*, channels.name as channel_name")
+
 	if modelName != "" {
 		tx = tx.Where("model_name like ?", modelName)
 	}

+ 1 - 1
relay/channel/deepseek/constants.go

@@ -1,7 +1,7 @@
 package deepseek
 
 var ModelList = []string{
-	"deepseek-chat", "deepseek-coder",
+	"deepseek-chat", "deepseek-reasoner",
 }
 
 var ChannelName = "deepseek"

+ 42 - 51
relay/channel/openai/relay-openai.go

@@ -5,7 +5,10 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"github.com/pkg/errors"
 	"io"
+	"math"
+	"mime/multipart"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -13,6 +16,7 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
+	"os"
 	"strings"
 	"sync"
 	"time"
@@ -316,6 +320,11 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 }
 
 func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	// count tokens by audio file duration
+	audioTokens, err := countAudioTokens(c)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
+	}
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -340,70 +349,52 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 	resp.Body.Close()
 
-	var text string
-	switch responseFormat {
-	case "json":
-		text, err = getTextFromJSON(responseBody)
-	case "text":
-		text, err = getTextFromText(responseBody)
-	case "srt":
-		text, err = getTextFromSRT(responseBody)
-	case "verbose_json":
-		text, err = getTextFromVerboseJSON(responseBody)
-	case "vtt":
-		text, err = getTextFromVTT(responseBody)
-	}
-
 	usage := &dto.Usage{}
-	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
+	usage.PromptTokens = audioTokens
+	usage.CompletionTokens = 0
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return nil, usage
 }
 
-func getTextFromVTT(body []byte) (string, error) {
-	return getTextFromSRT(body)
-}
+func countAudioTokens(c *gin.Context) (int, error) {
+	body, err := common.GetRequestBody(c)
+	if err != nil {
+		return 0, errors.WithStack(err)
+	}
 
-func getTextFromVerboseJSON(body []byte) (string, error) {
-	var whisperResponse dto.WhisperVerboseJSONResponse
-	if err := json.Unmarshal(body, &whisperResponse); err != nil {
-		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
+	var reqBody struct {
+		File *multipart.FileHeader `form:"file" binding:"required"`
+	}
+	c.Request.Body = io.NopCloser(bytes.NewReader(body))
+	if err = c.ShouldBind(&reqBody); err != nil {
+		return 0, errors.WithStack(err)
 	}
-	return whisperResponse.Text, nil
-}
 
-func getTextFromSRT(body []byte) (string, error) {
-	scanner := bufio.NewScanner(strings.NewReader(string(body)))
-	var builder strings.Builder
-	var textLine bool
-	for scanner.Scan() {
-		line := scanner.Text()
-		if textLine {
-			builder.WriteString(line)
-			textLine = false
-			continue
-		} else if strings.Contains(line, "-->") {
-			textLine = true
-			continue
-		}
+	reqFp, err := reqBody.File.Open()
+	if err != nil {
+		return 0, errors.WithStack(err)
 	}
-	if err := scanner.Err(); err != nil {
-		return "", err
+
+	tmpFp, err := os.CreateTemp("", "audio-*")
+	if err != nil {
+		return 0, errors.WithStack(err)
 	}
-	return builder.String(), nil
-}
+	defer os.Remove(tmpFp.Name())
 
-func getTextFromText(body []byte) (string, error) {
-	return strings.TrimSuffix(string(body), "\n"), nil
-}
+	_, err = io.Copy(tmpFp, reqFp)
+	if err != nil {
+		return 0, errors.WithStack(err)
+	}
+	if err = tmpFp.Close(); err != nil {
+		return 0, errors.WithStack(err)
+	}
 
-func getTextFromJSON(body []byte) (string, error) {
-	var whisperResponse dto.AudioResponse
-	if err := json.Unmarshal(body, &whisperResponse); err != nil {
-		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
+	duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
+	if err != nil {
+		return 0, errors.WithStack(err)
 	}
-	return whisperResponse.Text, nil
+
+	return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
 }
 
 func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {

+ 22 - 10
web/src/components/LogsTable.js

@@ -157,13 +157,15 @@ const LogsTable = () => {
           record.type === 0 || record.type === 2 ? (
             <div>
               {
-                <Tag
-                  color={colors[parseInt(text) % colors.length]}
-                  size='large'
-                >
-                  {' '}
-                  {text}{' '}
-                </Tag>
+                <Tooltip content={record.channel_name || '[未知]'}>
+                  <Tag
+                    color={colors[parseInt(text) % colors.length]}
+                    size='large'
+                  >
+                    {' '}
+                    {text}{' '}
+                  </Tag>
+                </Tooltip>
               }
             </div>
           ) : (
@@ -234,7 +236,12 @@ const LogsTable = () => {
               </>
             );
          } else {
-           let other = JSON.parse(record.other);
+           let other = null;
+           try {
+             other = JSON.parse(record.other);
+           } catch (e) {
+             console.error(`Failed to parse record.other: "${record.other}".`, e);
+           }
            if (other === null) {
              return <></>;
            }
@@ -543,6 +550,12 @@ const LogsTable = () => {
         //   key: '渠道重试',
         //   value: content,
         // })
+      }      
+      if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2)) {
+        expandDataLocal.push({
+          key: t('渠道信息'),
+          value: `${logs[i].channel} - ${logs[i].channel_name || '[未知]'}`
+        });
       }
       if (other?.ws || other?.audio) {
         expandDataLocal.push({
@@ -595,13 +608,12 @@ const LogsTable = () => {
           key: t('计费过程'),
           value: content,
         });
-      }
 
+      }
       expandDatesLocal[logs[i].key] = expandDataLocal;
     }
 
     setExpandData(expandDatesLocal);
-
     setLogs(logs);
   };