فهرست منبع

Merge pull request #13 from Calcium-Ion/main

1
luxl 2 سال پیش
والد
کامیت
734628e27f
4فایلهای تغییر یافته به همراه63 افزوده شده و 66 حذف شده
  1. 1 1
      common/database.go
  2. 21 23
      controller/relay-openai.go
  3. 39 40
      controller/relay-text.go
  4. 2 2
      model/ability.go

+ 1 - 1
common/database.go

@@ -3,4 +3,4 @@ package common
 var UsingSQLite = false
 var UsingPostgreSQL = false
 
-var SQLitePath = "one-api.db"
+var SQLitePath = "one-api.db?_busy_timeout=5000"

+ 21 - 23
controller/relay-openai.go

@@ -88,30 +88,28 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 	return nil, responseText
 }
 
-func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 	var textResponse TextResponse
-	if consumeQuota {
-		responseBody, err := io.ReadAll(resp.Body)
-		if err != nil {
-			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-		}
-		err = resp.Body.Close()
-		if err != nil {
-			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-		}
-		err = json.Unmarshal(responseBody, &textResponse)
-		if err != nil {
-			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
-		}
-		if textResponse.Error.Type != "" {
-			return &OpenAIErrorWithStatusCode{
-				OpenAIError: textResponse.Error,
-				StatusCode:  resp.StatusCode,
-			}, nil
-		}
-		// Reset response body
-		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &textResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if textResponse.Error.Type != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: textResponse.Error,
+			StatusCode:  resp.StatusCode,
+		}, nil
 	}
+	// Reset response body
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 	// And then we will have to send an error response, but in this case, the header has already been set.
 	// So the httpClient will be confused by the response.
@@ -120,7 +118,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
 		c.Writer.Header().Set(k, v[0])
 	}
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err := io.Copy(c.Writer, resp.Body)
+	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}

+ 39 - 40
controller/relay-text.go

@@ -50,14 +50,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelId := c.GetInt("channel_id")
 	tokenId := c.GetInt("token_id")
 	userId := c.GetInt("id")
-	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
+	startTime := time.Now()
 	var textRequest GeneralOpenAIRequest
-	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
-		err := common.UnmarshalBodyReusable(c, &textRequest)
-		if err != nil {
-			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-		}
+
+	err := common.UnmarshalBodyReusable(c, &textRequest)
+	if err != nil {
+		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 	}
 	if relayMode == RelayModeModerations && textRequest.Model == "" {
 		textRequest.Model = "text-moderation-latest"
@@ -199,7 +198,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	}
 	var promptTokens int
 	var completionTokens int
-	var err error
 	switch relayMode {
 	case RelayModeChatCompletions:
 		promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
@@ -236,7 +234,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		preConsumedQuota = 0
 		//common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
 	}
-	if consumeQuota && preConsumedQuota > 0 {
+	if preConsumedQuota > 0 {
 		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
@@ -420,39 +418,40 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	defer func(ctx context.Context) {
 		// c.Writer.Flush()
 		go func() {
-			if consumeQuota {
-				quota := 0
-				completionRatio := common.GetCompletionRatio(textRequest.Model)
-				promptTokens = textResponse.Usage.PromptTokens
-				completionTokens = textResponse.Usage.CompletionTokens
+			quota := 0
+			completionRatio := common.GetCompletionRatio(textRequest.Model)
+			promptTokens = textResponse.Usage.PromptTokens
+			completionTokens = textResponse.Usage.CompletionTokens
 
-				quota = promptTokens + int(float64(completionTokens)*completionRatio)
-				quota = int(float64(quota) * ratio)
-				if ratio != 0 && quota <= 0 {
-					quota = 1
-				}
-				totalTokens := promptTokens + completionTokens
-				if totalTokens == 0 {
-					// in this case, must be some error happened
-					// we cannot just return, because we may have to return the pre-consumed quota
-					quota = 0
-				}
-				quotaDelta := quota - preConsumedQuota
-				err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
-				if err != nil {
-					common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-				}
-				err = model.CacheUpdateUserQuota(userId)
-				if err != nil {
-					common.LogError(ctx, "error update user quota cache: "+err.Error())
-				}
-				if quota != 0 {
-					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
-					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-					model.UpdateChannelUsedQuota(channelId, quota)
-				}
+			quota = promptTokens + int(float64(completionTokens)*completionRatio)
+			quota = int(float64(quota) * ratio)
+			if ratio != 0 && quota <= 0 {
+				quota = 1
+			}
+			totalTokens := promptTokens + completionTokens
+			if totalTokens == 0 {
+				// in this case, must be some error happened
+				// we cannot just return, because we may have to return the pre-consumed quota
+				quota = 0
+			}
+			quotaDelta := quota - preConsumedQuota
+			err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
+			if err != nil {
+				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 			}
+			err = model.CacheUpdateUserQuota(userId)
+			if err != nil {
+				common.LogError(ctx, "error update user quota cache: "+err.Error())
+			}
+			// record all the consume log even if quota is 0
+			useTimeSeconds := time.Now().Unix() - startTime.Unix()
+			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
+			model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
+			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+			model.UpdateChannelUsedQuota(channelId, quota)
+			//if quota != 0 {
+			//
+			//}
 		}()
 	}(c.Request.Context())
 	switch apiType {
@@ -466,7 +465,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 			return nil
 		} else {
-			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
+			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
 			if err != nil {
 				return err
 			}

+ 2 - 2
model/ability.go

@@ -15,8 +15,8 @@ type Ability struct {
 
 func GetGroupModels(group string) []string {
 	var abilities []Ability
-	//去重
-	DB.Where("`group` = ?", group).Distinct("model").Find(&abilities)
+	//去重 enabled = true
+	DB.Where("`group` = ? and enabled = ?", group, true).Find(&abilities)
 	models := make([]string, 0, len(abilities))
 	for _, ability := range abilities {
 		models = append(models, ability.Model)