Просмотр исходного кода

Merge pull request #4 from Calcium-Ion/main

1
luxl 2 лет назад
Родитель
Сommit
98326bcdb4

+ 16 - 38
.github/workflows/docker-image-amd64.yml

@@ -1,54 +1,32 @@
-name: Publish Docker image (amd64)
+name: Docker Image CI
 
 on:
   push:
-    tags:
-      - '*'
-  workflow_dispatch:
-    inputs:
-      name:
-        description: 'reason'
-        required: false
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
 jobs:
-  push_to_registries:
-    name: Push Docker image to multiple registries
-    runs-on: ubuntu-latest
-    permissions:
-      packages: write
-      contents: read
-    steps:
-      - name: Check out the repo
-        uses: actions/checkout@v3
 
-      - name: Save version info
-        run: |
-          git describe --tags > VERSION 
+  build:
+
+    runs-on: ubuntu-latest
 
-      - name: Log in to Docker Hub
-        uses: docker/login-action@v2
+    steps:
+      - uses: actions/checkout@v3
+      - uses: docker/login-action@v3.0.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
-
-      - name: Log in to the Container registry
-        uses: docker/login-action@v2
-        with:
-          registry: ghcr.io
-          username: ${{ github.actor }}
-          password: ${{ secrets.GITHUB_TOKEN }}
-
       - name: Extract metadata (tags, labels) for Docker
         id: meta
-        uses: docker/metadata-action@v4
+        uses: docker/metadata-action@v3
         with:
-          images: |
-            justsong/one-api
-            ghcr.io/${{ github.repository }}
-
-      - name: Build and push Docker images
-        uses: docker/build-push-action@v3
+          images: calciumion/neko-api
+      - name: Build the Docker image
+        uses: docker/build-push-action@v5.0.0
         with:
           context: .
           push: true
           tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
+          labels: ${{ steps.meta.outputs.labels }}

+ 4 - 0
README.md

@@ -28,6 +28,10 @@
     + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况,方便二次分销
 5. 渠道显示已使用额度,支持指定组织访问
 6. 分页支持选择每页显示数量
+7. 支持gpt-4-1106-vision-preview,dall-e-3,tts-1
+
+## 交流群
+<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="500">
 
 ## 界面截图
 ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)  

+ 5 - 1
common/model-ratio.go

@@ -37,7 +37,11 @@ var ModelRatio = map[string]float64{
 	"text-davinci-003":          10,
 	"text-davinci-edit-001":     10,
 	"code-davinci-edit-001":     10,
-	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+	"whisper-1":                 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+	"tts-1":                     7.5, // 1k characters -> $0.015
+	"tts-1-1106":                7.5, // 1k characters -> $0.015
+	"tts-1-hd":                  15,  // 1k characters -> $0.03
+	"tts-1-hd-1106":             15,  // 1k characters -> $0.03
 	"davinci":                   10,
 	"curie":                     10,
 	"babbage":                   10,

+ 9 - 0
common/utils.go

@@ -207,3 +207,12 @@ func String2Int(str string) int {
 	}
 	return num
 }
+
+func StringsContains(strs []string, str string) bool {
+	for _, s := range strs {
+		if s == str {
+			return true
+		}
+	}
+	return false
+}

+ 9 - 3
controller/channel-test.go

@@ -86,9 +86,10 @@ func buildTestRequest() *ChatRequest {
 		Model:     "", // this will be set later
 		MaxTokens: 1,
 	}
+	content, _ := json.Marshal("hi")
 	testMessage := Message{
 		Role:    "user",
-		Content: "hi",
+		Content: content,
 	}
 	testRequest.Messages = append(testRequest.Messages, testMessage)
 	return testRequest
@@ -180,11 +181,16 @@ func testAllChannels(notify bool) error {
 			err, openaiErr := testChannel(channel, *testRequest)
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
+
+			ban := false
 			if milliseconds > disableThreshold {
 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
-				disableChannel(channel.Id, channel.Name, err.Error())
+				ban = true
+			}
+			if openaiErr != nil {
+				err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
+				ban = true
 			}
-			ban := true
 			// parse *int to bool
 			if channel.AutoBan != nil && *channel.AutoBan == 0 {
 				ban = false

+ 36 - 0
controller/model.go

@@ -90,6 +90,42 @@ func init() {
 			Root:       "whisper-1",
 			Parent:     nil,
 		},
+		{
+			Id:         "tts-1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "tts-1",
+			Parent:     nil,
+		},
+		{
+			Id:         "tts-1-1106",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "tts-1-1106",
+			Parent:     nil,
+		},
+		{
+			Id:         "tts-1-hd",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "tts-1-hd",
+			Parent:     nil,
+		},
+		{
+			Id:         "tts-1-hd-1106",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "tts-1-hd-1106",
+			Parent:     nil,
+		},
 		{
 			Id:         "gpt-3.5-turbo",
 			Object:     "model",

+ 2 - 2
controller/relay-aiproxy.go

@@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
 func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
 	query := ""
 	if len(request.Messages) != 0 {
-		query = request.Messages[len(request.Messages)-1].Content
+		query = string(request.Messages[len(request.Messages)-1].Content)
 	}
 	return &AIProxyLibraryRequest{
 		Model:  request.Model,
@@ -69,7 +69,7 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
 }
 
 func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
-	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
+	content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents))
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{

+ 6 - 5
controller/relay-ali.go

@@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 		message := request.Messages[i]
 		if message.Role == "system" {
 			messages = append(messages, AliMessage{
-				User: message.Content,
+				User: string(message.Content),
 				Bot:  "Okay",
 			})
 			continue
 		} else {
 			if i == len(request.Messages)-1 {
-				prompt = message.Content
+				prompt = string(message.Content)
 				break
 			}
 			messages = append(messages, AliMessage{
-				User: message.Content,
-				Bot:  request.Messages[i+1].Content,
+				User: string(message.Content),
+				Bot:  string(request.Messages[i+1].Content),
 			})
 			i++
 		}
@@ -184,11 +184,12 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
 }
 
 func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
+	content, _ := json.Marshal(response.Output.Text)
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
 			Role:    "assistant",
-			Content: response.Output.Text,
+			Content: content,
 		},
 		FinishReason: response.Output.FinishReason,
 	}

+ 54 - 10
controller/relay-audio.go

@@ -11,10 +11,19 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"strings"
 )
 
+var availableVoices = []string{
+	"alloy",
+	"echo",
+	"fable",
+	"onyx",
+	"nova",
+	"shimmer",
+}
+
 func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
-	audioModel := "whisper-1"
 
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
@@ -22,8 +31,28 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	userId := c.GetInt("id")
 	group := c.GetString("group")
 
+	var audioRequest AudioRequest
+	err := common.UnmarshalBodyReusable(c, &audioRequest)
+	if err != nil {
+		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+	}
+
+	// request validation
+	if audioRequest.Model == "" {
+		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
+	}
+
+	if strings.HasPrefix(audioRequest.Model, "tts-1") {
+		if audioRequest.Voice == "" {
+			return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
+		}
+		if !common.StringsContains(availableVoices, audioRequest.Voice) {
+			return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
+		}
+	}
+
 	preConsumedTokens := common.PreConsumedQuota
-	modelRatio := common.GetModelRatio(audioModel)
+	modelRatio := common.GetModelRatio(audioRequest.Model)
 	groupRatio := common.GetGroupRatio(group)
 	ratio := modelRatio * groupRatio
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
@@ -58,8 +87,8 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 		if err != nil {
 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
-		if modelMap[audioModel] != "" {
-			audioModel = modelMap[audioModel]
+		if modelMap[audioRequest.Model] != "" {
+			audioRequest.Model = modelMap[audioRequest.Model]
 		}
 	}
 
@@ -97,9 +126,20 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	defer func(ctx context.Context) {
 		go func() {
-			quota := countTokenText(audioResponse.Text, audioModel)
+			quota := 0
+			var promptTokens = 0
+			if strings.HasPrefix(audioRequest.Model, "tts-1") {
+				quota = countAudioToken(audioRequest.Input, audioRequest.Model)
+				promptTokens = quota
+			} else {
+				quota = countAudioToken(audioResponse.Text, audioRequest.Model)
+			}
+			quota = int(float64(quota) * ratio)
+			if ratio != 0 && quota <= 0 {
+				quota = 1
+			}
 			quotaDelta := quota - preConsumedQuota
-			err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota)
+			err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())
 			}
@@ -110,7 +150,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId)
+				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
@@ -127,9 +167,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	if err != nil {
 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
-	err = json.Unmarshal(responseBody, &audioResponse)
-	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	if strings.HasPrefix(audioRequest.Model, "tts-1") {
+
+	} else {
+		err = json.Unmarshal(responseBody, &audioResponse)
+		if err != nil {
+			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+		}
 	}
 
 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

+ 4 - 3
controller/relay-baidu.go

@@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 		if message.Role == "system" {
 			messages = append(messages, BaiduMessage{
 				Role:    "user",
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 			messages = append(messages, BaiduMessage{
 				Role:    "assistant",
@@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 		} else {
 			messages = append(messages, BaiduMessage{
 				Role:    message.Role,
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 		}
 	}
@@ -109,11 +109,12 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 }
 
 func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
+	content, _ := json.Marshal(response.Result)
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
 			Role:    "assistant",
-			Content: response.Result,
+			Content: content,
 		},
 		FinishReason: "stop",
 	}

+ 2 - 1
controller/relay-claude.go

@@ -93,11 +93,12 @@ func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletion
 }
 
 func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
+	content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
 			Role:    "assistant",
-			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
+			Content: content,
 			Name:    nil,
 		},
 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),

+ 1 - 1
controller/relay-image.go

@@ -147,7 +147,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	var textResponse ImageResponse
 	defer func(ctx context.Context) {
 		if consumeQuota {
-			err := model.PostConsumeTokenQuota(tokenId, userId, quota, 0)
+			err := model.PostConsumeTokenQuota(tokenId, userId, quota, 0, true)
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())
 			}

+ 1 - 1
controller/relay-mj.go

@@ -359,7 +359,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 
 	defer func(ctx context.Context) {
 		if consumeQuota {
-			err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0)
+			err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())
 			}

+ 1 - 1
controller/relay-openai.go

@@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
 	if textResponse.Usage.TotalTokens == 0 {
 		completionTokens := 0
 		for _, choice := range textResponse.Choices {
-			completionTokens += countTokenText(choice.Message.Content, model)
+			completionTokens += countTokenText(string(choice.Message.Content), model)
 		}
 		textResponse.Usage = Usage{
 			PromptTokens:     promptTokens,

+ 3 - 2
controller/relay-palm.go

@@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
 	}
 	for _, message := range textRequest.Messages {
 		palmMessage := PaLMChatMessage{
-			Content: message.Content,
+			Content: string(message.Content),
 		}
 		if message.Role == "user" {
 			palmMessage.Author = "0"
@@ -76,11 +76,12 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	for i, candidate := range response.Candidates {
+		content, _ := json.Marshal(candidate.Content)
 		choice := OpenAITextResponseChoice{
 			Index: i,
 			Message: Message{
 				Role:    "assistant",
-				Content: candidate.Content,
+				Content: content,
 			},
 			FinishReason: "stop",
 		}

+ 4 - 3
controller/relay-tencent.go

@@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 		if message.Role == "system" {
 			messages = append(messages, TencentMessage{
 				Role:    "user",
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 			messages = append(messages, TencentMessage{
 				Role:    "assistant",
@@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 			continue
 		}
 		messages = append(messages, TencentMessage{
-			Content: message.Content,
+			Content: string(message.Content),
 			Role:    message.Role,
 		})
 	}
@@ -119,11 +119,12 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
 		Usage:   response.Usage,
 	}
 	if len(response.Choices) > 0 {
+		content, _ := json.Marshal(response.Choices[0].Messages.Content)
 		choice := OpenAITextResponseChoice{
 			Index: 0,
 			Message: Message{
 				Role:    "assistant",
-				Content: response.Choices[0].Messages.Content,
+				Content: content,
 			},
 			FinishReason: response.Choices[0].FinishReason,
 		}

+ 7 - 3
controller/relay-text.go

@@ -199,9 +199,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	}
 	var promptTokens int
 	var completionTokens int
+	var err error
 	switch relayMode {
 	case RelayModeChatCompletions:
-		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
+		promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
+		if err != nil {
+			return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
+		}
 	case RelayModeCompletions:
 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
 	case RelayModeModerations:
@@ -400,7 +404,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			if preConsumedQuota != 0 {
 				go func(ctx context.Context) {
 					// return pre-consumed quota
-					err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0)
+					err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
 					if err != nil {
 						common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
 					}
@@ -434,7 +438,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 					quota = 0
 				}
 				quotaDelta := quota - preConsumedQuota
-				err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota)
+				err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
 				if err != nil {
 					common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 				}

+ 104 - 6
controller/relay-utils.go

@@ -2,14 +2,23 @@ package controller
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
+	"github.com/chai2010/webp"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
+	"image"
+	_ "image/gif"
+	_ "image/jpeg"
+	_ "image/png"
 	"io"
+	"log"
+	"math"
 	"net/http"
 	"one-api/common"
 	"strconv"
 	"strings"
+	"unicode/utf8"
 )
 
 var stopFinishReason = "stop"
@@ -62,7 +71,66 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
-func countTokenMessages(messages []Message, model string) int {
+func getImageToken(imageUrl MessageImageUrl) (int, error) {
+	if imageUrl.Detail == "low" {
+		return 85, nil
+	}
+
+	response, err := http.Get(imageUrl.Url)
+	if err != nil {
+		fmt.Println("Error: Failed to get the URL")
+		return 0, err
+	}
+
+	defer response.Body.Close()
+
+	// 限制读取的字节数,防止下载整个图片
+	limitReader := io.LimitReader(response.Body, 8192)
+
+	// 读取图片的头部信息来获取图片尺寸
+	config, _, err := image.DecodeConfig(limitReader)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
+		config, err = webp.DecodeConfig(limitReader)
+		if err != nil {
+			common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
+		}
+	}
+	if config.Width == 0 || config.Height == 0 {
+		return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error()))
+	}
+	if config.Width < 512 && config.Height < 512 {
+		if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
+			return 85, nil
+		}
+	}
+
+	shortSide := config.Width
+	otherSide := config.Height
+	log.Printf("width: %d, height: %d", config.Width, config.Height)
+	// 缩放倍数
+	scale := 1.0
+	if config.Height < shortSide {
+		shortSide = config.Height
+		otherSide = config.Width
+	}
+
+	// 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768
+	if shortSide > 768 {
+		scale = float64(shortSide) / 768
+		shortSide = 768
+	}
+	// 将另一边按照相同的比例缩小,向上取整
+	otherSide = int(math.Ceil(float64(otherSide) / scale))
+	log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
+	// 计算图片的token数量(边的长度除以512,向上取整)
+	tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
+	log.Printf("tiles: %d", tiles)
+	return tiles*170 + 85, nil
+}
+
+func countTokenMessages(messages []Message, model string) (int, error) {
+	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -81,15 +149,37 @@ func countTokenMessages(messages []Message, model string) int {
 	tokenNum := 0
 	for _, message := range messages {
 		tokenNum += tokensPerMessage
-		tokenNum += getTokenNum(tokenEncoder, message.Content)
 		tokenNum += getTokenNum(tokenEncoder, message.Role)
-		if message.Name != nil {
-			tokenNum += tokensPerName
-			tokenNum += getTokenNum(tokenEncoder, *message.Name)
+		var arrayContent []MediaMessage
+		if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
+
+			var stringContent string
+			if err := json.Unmarshal(message.Content, &stringContent); err != nil {
+				return 0, err
+			} else {
+				tokenNum += getTokenNum(tokenEncoder, stringContent)
+				if message.Name != nil {
+					tokenNum += tokensPerName
+					tokenNum += getTokenNum(tokenEncoder, *message.Name)
+				}
+			}
+		} else {
+			for _, m := range arrayContent {
+				if m.Type == "image_url" {
+					imageTokenNum, err := getImageToken(m.ImageUrl)
+					if err != nil {
+						return 0, err
+					}
+					tokenNum += imageTokenNum
+					log.Printf("image token num: %d", imageTokenNum)
+				} else {
+					tokenNum += getTokenNum(tokenEncoder, m.Text)
+				}
+			}
 		}
 	}
 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
-	return tokenNum
+	return tokenNum, nil
 }
 
 func countTokenInput(input any, model string) int {
@@ -106,6 +196,14 @@ func countTokenInput(input any, model string) int {
 	return 0
 }
 
+func countAudioToken(text string, model string) int {
+	if strings.HasPrefix(model, "tts") {
+		return utf8.RuneCountInString(text)
+	} else {
+		return countTokenText(text, model)
+	}
+}
+
 func countTokenText(text string, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	return getTokenNum(tokenEncoder, text)

+ 4 - 3
controller/relay-xunfei.go

@@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
 		if message.Role == "system" {
 			messages = append(messages, XunfeiMessage{
 				Role:    "user",
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 			messages = append(messages, XunfeiMessage{
 				Role:    "assistant",
@@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
 		} else {
 			messages = append(messages, XunfeiMessage{
 				Role:    message.Role,
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 		}
 	}
@@ -112,11 +112,12 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 			},
 		}
 	}
+	content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
 			Role:    "assistant",
-			Content: response.Payload.Choices.Text[0].Content,
+			Content: content,
 		},
 		FinishReason: stopFinishReason,
 	}

+ 4 - 3
controller/relay-zhipu.go

@@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 		if message.Role == "system" {
 			messages = append(messages, ZhipuMessage{
 				Role:    "system",
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 			messages = append(messages, ZhipuMessage{
 				Role:    "user",
@@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 		} else {
 			messages = append(messages, ZhipuMessage{
 				Role:    message.Role,
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 		}
 	}
@@ -144,11 +144,12 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
 		Usage:   response.Data.Usage,
 	}
 	for i, choice := range response.Data.Choices {
+		content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
 		openaiChoice := OpenAITextResponseChoice{
 			Index: i,
 			Message: Message{
 				Role:    choice.Role,
-				Content: strings.Trim(choice.Content, "\""),
+				Content: content,
 			},
 			FinishReason: "",
 		}

+ 28 - 8
controller/relay.go

@@ -1,6 +1,7 @@
 package controller
 
 import (
+	"encoding/json"
 	"fmt"
 	"log"
 	"net/http"
@@ -12,9 +13,20 @@ import (
 )
 
 type Message struct {
-	Role    string  `json:"role"`
-	Content string  `json:"content"`
-	Name    *string `json:"name,omitempty"`
+	Role    string          `json:"role"`
+	Content json.RawMessage `json:"content"`
+	Name    *string         `json:"name,omitempty"`
+}
+
+type MediaMessage struct {
+	Type     string          `json:"type"`
+	Text     string          `json:"text"`
+	ImageUrl MessageImageUrl `json:"image_url,omitempty"`
+}
+
+type MessageImageUrl struct {
+	Url    string `json:"url"`
+	Detail string `json:"detail"`
 }
 
 const (
@@ -70,6 +82,12 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
 	return input
 }
 
+type AudioRequest struct {
+	Model string `json:"model"`
+	Voice string `json:"voice"`
+	Input string `json:"input"`
+}
+
 type ChatRequest struct {
 	Model     string    `json:"model"`
 	Messages  []Message `json:"messages"`
@@ -85,11 +103,13 @@ type TextRequest struct {
 }
 
 type ImageRequest struct {
-	Model   string `json:"model"`
-	Quality string `json:"quality"`
-	Prompt  string `json:"prompt"`
-	N       int    `json:"n"`
-	Size    string `json:"size"`
+	Model          string `json:"model"`
+	Prompt         string `json:"prompt"`
+	N              int    `json:"n"`
+	Size           string `json:"size"`
+	Quality        string `json:"quality,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+	Style          string `json:"style,omitempty"`
 }
 
 type AudioResponse struct {

+ 1 - 0
go.mod

@@ -4,6 +4,7 @@ module one-api
 go 1.18
 
 require (
+	github.com/chai2010/webp v1.1.1
 	github.com/gin-contrib/cors v1.4.0
 	github.com/gin-contrib/gzip v0.0.6
 	github.com/gin-contrib/sessions v0.0.5

+ 2 - 0
go.sum

@@ -3,6 +3,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s
 github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
 github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
+github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
 github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
 github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
 github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=

+ 6 - 3
middleware/distributor.go

@@ -46,9 +46,8 @@ func Distribute() func(c *gin.Context) {
 				if modelRequest.Model == "" {
 					modelRequest.Model = "midjourney"
 				}
-			} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
-				err = common.UnmarshalBodyReusable(c, &modelRequest)
 			}
+			err = common.UnmarshalBodyReusable(c, &modelRequest)
 			if err != nil {
 				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 				return
@@ -70,7 +69,11 @@ func Distribute() func(c *gin.Context) {
 			}
 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 				if modelRequest.Model == "" {
-					modelRequest.Model = "whisper-1"
+					if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+						modelRequest.Model = "tts-1"
+					} else {
+						modelRequest.Model = "whisper-1"
+					}
 				}
 			}
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)

+ 16 - 10
model/token.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"gorm.io/gorm"
 	"one-api/common"
+	"strconv"
 	"strings"
 )
 
@@ -194,22 +195,31 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
 		return 0, err
 	}
 	if userQuota < quota {
-		return userQuota, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
+		return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
 	}
 	if !token.UnlimitedQuota {
 		err = DecreaseTokenQuota(tokenId, quota)
 		if err != nil {
-			return userQuota, err
+			return 0, err
 		}
 	}
 	err = DecreaseUserQuota(token.UserId, quota)
-	return userQuota, err
+	return userQuota - quota, err
 }
 
-func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int) (err error) {
+func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
 	token, err := GetTokenById(tokenId)
 
 	if quota > 0 {
+		err = DecreaseUserQuota(token.UserId, quota)
+	} else {
+		err = IncreaseUserQuota(token.UserId, -quota)
+	}
+	if err != nil {
+		return err
+	}
+
+	if sendEmail {
 		quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
 		noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
 		if quotaTooLow || noMoreQuota {
@@ -229,16 +239,12 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
 					if err != nil {
 						common.SysError("failed to send email" + err.Error())
 					}
+					common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
 				}
 			}()
 		}
-		err = DecreaseUserQuota(token.UserId, quota)
-	} else {
-		err = IncreaseUserQuota(token.UserId, -quota)
-	}
-	if err != nil {
-		return err
 	}
+
 	if !token.UnlimitedQuota {
 		if quota > 0 {
 			err = DecreaseTokenQuota(tokenId, quota)

+ 1 - 0
router/relay-router.go

@@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 		relayV1Router.POST("/audio/transcriptions", controller.Relay)
 		relayV1Router.POST("/audio/translations", controller.Relay)
+		relayV1Router.POST("/audio/speech", controller.Relay)
 		relayV1Router.GET("/files", controller.RelayNotImplemented)
 		relayV1Router.POST("/files", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)