فهرست منبع

refactor: Enhance user context and quota management

- Add new context keys for user-related information
- Modify user cache and authentication middleware to populate context
- Refactor quota and notification services to use context-based user data
- Remove redundant database queries by leveraging context information
- Update various components to use new context-based user retrieval methods
1808837298@qq.com 1 سال پیش
والد
کامیت
069f2672c1

+ 5 - 0
constant/context_key.go

@@ -2,4 +2,9 @@ package constant
 
 
 const (
 const (
 	ContextKeyRequestStartTime = "request_start_time"
 	ContextKeyRequestStartTime = "request_start_time"
+	ContextKeyUserSetting      = "user_setting"
+	ContextKeyUserQuota        = "user_quota"
+	ContextKeyUserStatus       = "user_status"
+	ContextKeyUserEmail        = "user_email"
+	ContextKeyUserGroup        = "user_group"
 )
 )

+ 1 - 1
controller/midjourney.go

@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
 					common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
 					common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
 				} else {
 				} else {
 					if shouldReturnQuota {
 					if shouldReturnQuota {
-						err = model.IncreaseUserQuota(task.UserId, task.Quota)
+						err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
 						if err != nil {
 						if err != nil {
 							common.LogError(ctx, "fail to increase user quota: "+err.Error())
 							common.LogError(ctx, "fail to increase user quota: "+err.Error())
 						}
 						}

+ 1 - 1
controller/task.go

@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
 			} else {
 			} else {
 				quota := task.Quota
 				quota := task.Quota
 				if quota != 0 {
 				if quota != 0 {
-					err = model.IncreaseUserQuota(task.UserId, quota)
+					err = model.IncreaseUserQuota(task.UserId, quota, false)
 					if err != nil {
 					if err != nil {
 						common.LogError(ctx, "fail to increase user quota: "+err.Error())
 						common.LogError(ctx, "fail to increase user quota: "+err.Error())
 					}
 					}

+ 1 - 1
controller/topup.go

@@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) {
 			}
 			}
 			//user, _ := model.GetUserById(topUp.UserId, false)
 			//user, _ := model.GetUserById(topUp.UserId, false)
 			//user.Quota += topUp.Amount * 500000
 			//user.Quota += topUp.Amount * 500000
-			err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
+			err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
 			if err != nil {
 			if err != nil {
 				log.Printf("易支付回调更新用户失败: %v", topUp)
 				log.Printf("易支付回调更新用户失败: %v", topUp)
 				return
 				return

+ 5 - 1
middleware/auth.go

@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
 			abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
 			abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
 			return
 			return
 		}
 		}
-		userEnabled, err := model.IsUserEnabled(token.UserId, false)
+		userCache, err := model.GetUserCache(token.UserId)
 		if err != nil {
 		if err != nil {
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
 			return
 			return
 		}
 		}
+		userEnabled := userCache.Status == common.UserStatusEnabled
 		if !userEnabled {
 		if !userEnabled {
 			abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
 			abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
 			return
 			return
 		}
 		}
+
+		userCache.WriteContext(c)
+
 		c.Set("id", token.UserId)
 		c.Set("id", token.UserId)
 		c.Set("token_id", token.Id)
 		c.Set("token_id", token.Id)
 		c.Set("token_key", token.Key)
 		c.Set("token_key", token.Key)

+ 1 - 2
middleware/distributor.go

@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
 				return
 				return
 			}
 			}
 		}
 		}
-		userId := c.GetInt("id")
 		var channel *model.Channel
 		var channel *model.Channel
 		channelId, ok := c.Get("specific_channel_id")
 		channelId, ok := c.Get("specific_channel_id")
 		modelRequest, shouldSelectChannel, err := getModelRequest(c)
 		modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
 			abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
 			abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
 			return
 			return
 		}
 		}
-		userGroup, _ := model.GetUserGroup(userId, false)
+		userGroup := c.GetString(constant.ContextKeyUserGroup)
 		tokenGroup := c.GetString("token_group")
 		tokenGroup := c.GetString("token_group")
 		if tokenGroup != "" {
 		if tokenGroup != "" {
 			// check common.UserUsableGroups[userGroup]
 			// check common.UserUsableGroups[userGroup]

+ 5 - 5
model/log.go

@@ -1,8 +1,8 @@
 package model
 package model
 
 
 import (
 import (
-	"context"
 	"fmt"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"one-api/common"
 	"one-api/common"
 	"os"
 	"os"
 	"strings"
 	"strings"
@@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) {
 	}
 	}
 }
 }
 
 
-func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
+func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
 	modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
 	modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
 	isStream bool, group string, other map[string]interface{}) {
 	isStream bool, group string, other map[string]interface{}) {
-	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+	common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 	if !common.LogConsumeEnabled {
 	if !common.LogConsumeEnabled {
 		return
 		return
 	}
 	}
-	username, _ := GetUsernameById(userId, false)
+	username := c.GetString("username")
 	otherStr := common.MapToJsonStr(other)
 	otherStr := common.MapToJsonStr(other)
 	log := &Log{
 	log := &Log{
 		UserId:           userId,
 		UserId:           userId,
@@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
 	}
 	}
 	err := LOG_DB.Create(log).Error
 	err := LOG_DB.Create(log).Error
 	if err != nil {
 	if err != nil {
-		common.LogError(ctx, "failed to record log: "+err.Error())
+		common.LogError(c, "failed to record log: "+err.Error())
 	}
 	}
 	if common.DataExportEnabled {
 	if common.DataExportEnabled {
 		gopool.Go(func() {
 		gopool.Go(func() {

+ 33 - 33
model/user.go

@@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
 	}
 	}
 	if inviterId != 0 {
 	if inviterId != 0 {
 		if common.QuotaForInvitee > 0 {
 		if common.QuotaForInvitee > 0 {
-			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
+			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
 		}
 		}
 		if common.QuotaForInviter > 0 {
 		if common.QuotaForInviter > 0 {
@@ -502,35 +502,35 @@ func IsAdmin(userId int) bool {
 	return user.Role >= common.RoleAdminUser
 	return user.Role >= common.RoleAdminUser
 }
 }
 
 
-// IsUserEnabled checks user status from Redis first, falls back to DB if needed
-func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
-	defer func() {
-		// Update Redis cache asynchronously on successful DB read
-		if shouldUpdateRedis(fromDB, err) {
-			gopool.Go(func() {
-				if err := updateUserStatusCache(id, status); err != nil {
-					common.SysError("failed to update user status cache: " + err.Error())
-				}
-			})
-		}
-	}()
-	if !fromDB && common.RedisEnabled {
-		// Try Redis first
-		status, err := getUserStatusCache(id)
-		if err == nil {
-			return status == common.UserStatusEnabled, nil
-		}
-		// Don't return error - fall through to DB
-	}
-	fromDB = true
-	var user User
-	err = DB.Where("id = ?", id).Select("status").Find(&user).Error
-	if err != nil {
-		return false, err
-	}
-
-	return user.Status == common.UserStatusEnabled, nil
-}
+//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
+//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
+//	defer func() {
+//		// Update Redis cache asynchronously on successful DB read
+//		if shouldUpdateRedis(fromDB, err) {
+//			gopool.Go(func() {
+//				if err := updateUserStatusCache(id, status); err != nil {
+//					common.SysError("failed to update user status cache: " + err.Error())
+//				}
+//			})
+//		}
+//	}()
+//	if !fromDB && common.RedisEnabled {
+//		// Try Redis first
+//		status, err := getUserStatusCache(id)
+//		if err == nil {
+//			return status == common.UserStatusEnabled, nil
+//		}
+//		// Don't return error - fall through to DB
+//	}
+//	fromDB = true
+//	var user User
+//	err = DB.Where("id = ?", id).Select("status").Find(&user).Error
+//	if err != nil {
+//		return false, err
+//	}
+//
+//	return user.Status == common.UserStatusEnabled, nil
+//}
 
 
 func ValidateAccessToken(token string) (user *User) {
 func ValidateAccessToken(token string) (user *User) {
 	if token == "" {
 	if token == "" {
@@ -639,7 +639,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
 	return common.StrToMap(setting), nil
 	return common.StrToMap(setting), nil
 }
 }
 
 
-func IncreaseUserQuota(id int, quota int) (err error) {
+func IncreaseUserQuota(id int, quota int, db bool) (err error) {
 	if quota < 0 {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 		return errors.New("quota 不能为负数!")
 	}
 	}
@@ -649,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
 			common.SysError("failed to increase user quota: " + err.Error())
 			common.SysError("failed to increase user quota: " + err.Error())
 		}
 		}
 	})
 	})
-	if common.BatchUpdateEnabled {
+	if !db && common.BatchUpdateEnabled {
 		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
 		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
 		return nil
 		return nil
 	}
 	}
@@ -694,7 +694,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
 		return nil
 		return nil
 	}
 	}
 	if delta > 0 {
 	if delta > 0 {
-		return IncreaseUserQuota(id, delta)
+		return IncreaseUserQuota(id, delta, false)
 	} else {
 	} else {
 		return DecreaseUserQuota(id, -delta)
 		return DecreaseUserQuota(id, -delta)
 	}
 	}

+ 10 - 0
model/user_cache.go

@@ -3,6 +3,7 @@ package model
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"time"
 	"time"
@@ -21,6 +22,15 @@ type UserBase struct {
 	Setting  string `json:"setting"`
 	Setting  string `json:"setting"`
 }
 }
 
 
+func (user *UserBase) WriteContext(c *gin.Context) {
+	c.Set(constant.ContextKeyUserGroup, user.Group)
+	c.Set(constant.ContextKeyUserQuota, user.Quota)
+	c.Set(constant.ContextKeyUserStatus, user.Status)
+	c.Set(constant.ContextKeyUserEmail, user.Email)
+	c.Set("username", user.Username)
+	c.Set(constant.ContextKeyUserSetting, user.GetSetting())
+}
+
 func (user *UserBase) GetSetting() map[string]interface{} {
 func (user *UserBase) GetSetting() map[string]interface{} {
 	if user.Setting == "" {
 	if user.Setting == "" {
 		return nil
 		return nil

+ 1 - 1
relay/channel/api_request.go

@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("setup request header failed: %w", err)
 		return nil, fmt.Errorf("setup request header failed: %w", err)
 	}
 	}
-	resp, err := doRequest(c, req, info.ToRelayInfo())
+	resp, err := doRequest(c, req, info.RelayInfo)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("do request failed: %w", err)
 		return nil, fmt.Errorf("do request failed: %w", err)
 	}
 	}

+ 8 - 54
relay/common/relay_info.go

@@ -50,6 +50,9 @@ type RelayInfo struct {
 	AudioUsage           bool
 	AudioUsage           bool
 	ReasoningEffort      string
 	ReasoningEffort      string
 	ChannelSetting       map[string]interface{}
 	ChannelSetting       map[string]interface{}
+	UserSetting          map[string]interface{}
+	UserEmail            string
+	UserQuota            int
 }
 }
 
 
 // 定义支持流式选项的通道类型
 // 定义支持流式选项的通道类型
@@ -89,6 +92,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	apiType, _ := relayconstant.ChannelType2APIType(channelType)
 	apiType, _ := relayconstant.ChannelType2APIType(channelType)
 
 
 	info := &RelayInfo{
 	info := &RelayInfo{
+		UserQuota:         c.GetInt(constant.ContextKeyUserQuota),
+		UserSetting:       c.GetStringMap(constant.ContextKeyUserSetting),
+		UserEmail:         c.GetString(constant.ContextKeyUserEmail),
 		IsFirstResponse:   true,
 		IsFirstResponse:   true,
 		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
 		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
 		BaseUrl:           c.GetString("base_url"),
 		BaseUrl:           c.GetString("base_url"),
@@ -148,19 +154,7 @@ func (info *RelayInfo) SetFirstResponseTime() {
 }
 }
 
 
 type TaskRelayInfo struct {
 type TaskRelayInfo struct {
-	ChannelType       int
-	ChannelId         int
-	TokenId           int
-	UserId            int
-	Group             string
-	StartTime         time.Time
-	ApiType           int
-	RelayMode         int
-	UpstreamModelName string
-	RequestURLPath    string
-	ApiKey            string
-	BaseUrl           string
-
+	*RelayInfo
 	Action       string
 	Action       string
 	OriginTaskID string
 	OriginTaskID string
 
 
@@ -168,48 +162,8 @@ type TaskRelayInfo struct {
 }
 }
 
 
 func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
 func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
-	channelType := c.GetInt("channel_type")
-	channelId := c.GetInt("channel_id")
-
-	tokenId := c.GetInt("token_id")
-	userId := c.GetInt("id")
-	group := c.GetString("group")
-	startTime := time.Now()
-
-	apiType, _ := relayconstant.ChannelType2APIType(channelType)
-
 	info := &TaskRelayInfo{
 	info := &TaskRelayInfo{
-		RelayMode:      relayconstant.Path2RelayMode(c.Request.URL.Path),
-		BaseUrl:        c.GetString("base_url"),
-		RequestURLPath: c.Request.URL.String(),
-		ChannelType:    channelType,
-		ChannelId:      channelId,
-		TokenId:        tokenId,
-		UserId:         userId,
-		Group:          group,
-		StartTime:      startTime,
-		ApiType:        apiType,
-		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
-	}
-	if info.BaseUrl == "" {
-		info.BaseUrl = common.ChannelBaseURLs[channelType]
+		RelayInfo: GenRelayInfo(c),
 	}
 	}
 	return info
 	return info
 }
 }
-
-func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
-	return &RelayInfo{
-		ChannelType:       info.ChannelType,
-		ChannelId:         info.ChannelId,
-		TokenId:           info.TokenId,
-		UserId:            info.UserId,
-		Group:             info.Group,
-		StartTime:         info.StartTime,
-		ApiType:           info.ApiType,
-		RelayMode:         info.RelayMode,
-		UpstreamModelName: info.UpstreamModelName,
-		RequestURLPath:    info.RequestURLPath,
-		ApiKey:            info.ApiKey,
-		BaseUrl:           info.BaseUrl,
-	}
-}

+ 6 - 7
relay/relay-mj.go

@@ -2,7 +2,6 @@ package relay
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"context"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
@@ -192,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 	if err != nil {
 	if err != nil {
 		return &mjResp.Response
 		return &mjResp.Response
 	}
 	}
-	defer func(ctx context.Context) {
+	defer func() {
 		if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
 		if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
 			err := service.PostConsumeQuota(relayInfo, quota, 0, true)
 			err := service.PostConsumeQuota(relayInfo, quota, 0, true)
 			if err != nil {
 			if err != nil {
@@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 				other := make(map[string]interface{})
 				other := make(map[string]interface{})
 				other["model_price"] = modelPrice
 				other["model_price"] = modelPrice
 				other["group_ratio"] = groupRatio
 				other["group_ratio"] = groupRatio
-				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
+				model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
 					quota, logContent, tokenId, userQuota, 0, false, group, other)
 					quota, logContent, tokenId, userQuota, 0, false, group, other)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
 				model.UpdateChannelUsedQuota(channelId, quota)
 			}
 			}
 		}
 		}
-	}(c.Request.Context())
+	}()
 	midjResponse := &mjResp.Response
 	midjResponse := &mjResp.Response
 	midjourneyTask := &model.Midjourney{
 	midjourneyTask := &model.Midjourney{
 		UserId:      userId,
 		UserId:      userId,
@@ -498,7 +497,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 	}
 	}
 	midjResponse := &midjResponseWithStatus.Response
 	midjResponse := &midjResponseWithStatus.Response
 
 
-	defer func(ctx context.Context) {
+	defer func() {
 		if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
 		if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
 			err := service.PostConsumeQuota(relayInfo, quota, 0, true)
 			err := service.PostConsumeQuota(relayInfo, quota, 0, true)
 			if err != nil {
 			if err != nil {
@@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 				other := make(map[string]interface{})
 				other := make(map[string]interface{})
 				other["model_price"] = modelPrice
 				other["model_price"] = modelPrice
 				other["group_ratio"] = groupRatio
 				other["group_ratio"] = groupRatio
-				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
+				model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
 					quota, logContent, tokenId, userQuota, 0, false, group, other)
 					quota, logContent, tokenId, userQuota, 0, false, group, other)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
 				model.UpdateChannelUsedQuota(channelId, quota)
 			}
 			}
 		}
 		}
-	}(c.Request.Context())
+	}()
 
 
 	// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
 	// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
 	//1-提交成功
 	//1-提交成功

+ 2 - 1
relay/relay-text.go

@@ -248,6 +248,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	if userQuota-preConsumedQuota < 0 {
 	if userQuota-preConsumedQuota < 0 {
 		return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
 		return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
 	}
 	}
+	relayInfo.UserQuota = userQuota
 	if userQuota > 100*preConsumedQuota {
 	if userQuota > 100*preConsumedQuota {
 		// 用户额度充足,判断令牌额度是否充足
 		// 用户额度充足,判断令牌额度是否充足
 		if !relayInfo.TokenUnlimited {
 		if !relayInfo.TokenUnlimited {
@@ -267,7 +268,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	}
 	}
 
 
 	if preConsumedQuota > 0 {
 	if preConsumedQuota > 0 {
-		err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+		err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
 		if err != nil {
 		if err != nil {
 			return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 			return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 		}

+ 4 - 5
relay/relay_task.go

@@ -2,7 +2,6 @@ package relay
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"context"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
 		return
 		return
 	}
 	}
 
 
-	defer func(ctx context.Context) {
+	defer func() {
 		// release quota
 		// release quota
 		if relayInfo.ConsumeQuota && taskErr == nil {
 		if relayInfo.ConsumeQuota && taskErr == nil {
 
 
-			err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
+			err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
 			if err != nil {
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())
 				common.SysError("error consuming token remain quota: " + err.Error())
 			}
 			}
@@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
 				other := make(map[string]interface{})
 				other := make(map[string]interface{})
 				other["model_price"] = modelPrice
 				other["model_price"] = modelPrice
 				other["group_ratio"] = groupRatio
 				other["group_ratio"] = groupRatio
-				model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
+				model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
 					modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
 					modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
 				model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 				model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 				model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 				model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 			}
 			}
 		}
 		}
-	}(c.Request.Context())
+	}()
 
 
 	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
 	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
 	if taskErr != nil {
 	if taskErr != nil {

+ 7 - 11
service/quota.go

@@ -276,7 +276,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
 	if quota > 0 {
 	if quota > 0 {
 		err = model.DecreaseUserQuota(relayInfo.UserId, quota)
 		err = model.DecreaseUserQuota(relayInfo.UserId, quota)
 	} else {
 	} else {
-		err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
+		err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -295,20 +295,16 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
 
 
 	if sendEmail {
 	if sendEmail {
 		if (quota + preConsumedQuota) != 0 {
 		if (quota + preConsumedQuota) != 0 {
-			checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
+			checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
 		}
 		}
 	}
 	}
 
 
 	return nil
 	return nil
 }
 }
 
 
-func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
+func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
 	gopool.Go(func() {
 	gopool.Go(func() {
-		userCache, err := model.GetUserCache(userId)
-		if err != nil {
-			common.SysError("failed to get user cache: " + err.Error())
-		}
-		userSetting := userCache.GetSetting()
+		userSetting := relayInfo.UserSetting
 		threshold := common.QuotaRemindThreshold
 		threshold := common.QuotaRemindThreshold
 		if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
 		if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
 			threshold = int(userCustomThreshold.(float64))
 			threshold = int(userCustomThreshold.(float64))
@@ -317,16 +313,16 @@ func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
 		//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
 		//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
 		quotaTooLow := false
 		quotaTooLow := false
 		consumeQuota := quota + preConsumedQuota
 		consumeQuota := quota + preConsumedQuota
-		if userCache.Quota-consumeQuota < threshold {
+		if relayInfo.UserQuota-consumeQuota < threshold {
 			quotaTooLow = true
 			quotaTooLow = true
 		}
 		}
 		if quotaTooLow {
 		if quotaTooLow {
 			prompt := "您的额度即将用尽"
 			prompt := "您的额度即将用尽"
 			topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
 			topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
 			content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
 			content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
-			err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
+			err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
 			if err != nil {
 			if err != nil {
-				common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
+				common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
 			}
 			}
 		}
 		}
 	})
 	})

+ 7 - 9
service/user_notify.go

@@ -11,47 +11,45 @@ import (
 
 
 func NotifyRootUser(t string, subject string, content string) {
 func NotifyRootUser(t string, subject string, content string) {
 	user := model.GetRootUser().ToBaseUser()
 	user := model.GetRootUser().ToBaseUser()
-	_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
+	_ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
 }
 }
 
 
-func NotifyUser(user *model.UserBase, data dto.Notify) error {
-	userSetting := user.GetSetting()
+func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
 	notifyType, ok := userSetting[constant.UserSettingNotifyType]
 	notifyType, ok := userSetting[constant.UserSettingNotifyType]
 	if !ok {
 	if !ok {
 		notifyType = constant.NotifyTypeEmail
 		notifyType = constant.NotifyTypeEmail
 	}
 	}
 
 
 	// Check notification limit
 	// Check notification limit
-	canSend, err := CheckNotificationLimit(user.Id, data.Type)
+	canSend, err := CheckNotificationLimit(userId, data.Type)
 	if err != nil {
 	if err != nil {
 		common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
 		common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
 		return err
 		return err
 	}
 	}
 	if !canSend {
 	if !canSend {
-		return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
+		return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
 	}
 	}
 
 
 	switch notifyType {
 	switch notifyType {
 	case constant.NotifyTypeEmail:
 	case constant.NotifyTypeEmail:
-		userEmail := user.Email
 		// check setting email
 		// check setting email
 		if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
 		if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
 			userEmail = settingEmail.(string)
 			userEmail = settingEmail.(string)
 		}
 		}
 		if userEmail == "" {
 		if userEmail == "" {
-			common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
+			common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
 			return nil
 			return nil
 		}
 		}
 		return sendEmailNotify(userEmail, data)
 		return sendEmailNotify(userEmail, data)
 	case constant.NotifyTypeWebhook:
 	case constant.NotifyTypeWebhook:
 		webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
 		webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
 		if !ok {
 		if !ok {
-			common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
+			common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
 			return nil
 			return nil
 		}
 		}
 		webhookURLStr, ok := webhookURL.(string)
 		webhookURLStr, ok := webhookURL.(string)
 		if !ok {
 		if !ok {
-			common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
+			common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
 			return nil
 			return nil
 		}
 		}