소스 검색

refactor: update logging related logic

JustSong 2 년 전
부모
커밋
42451d9d02
16개의 변경된 파일149개의 추가작업 그리고 93개의 파일을 삭제
  1. 2 1
      .gitignore
  2. 4 0
      common/constants.go
  3. 1 1
      common/init.go
  4. 33 9
      common/logger.go
  5. 9 0
      common/utils.go
  6. 4 3
      controller/relay-audio.go
  7. 4 3
      controller/relay-image.go
  8. 10 8
      controller/relay-text.go
  9. 3 1
      controller/relay.go
  10. 5 2
      main.go
  11. 4 28
      middleware/auth.go
  12. 5 35
      middleware/distributor.go
  13. 25 0
      middleware/logger.go
  14. 18 0
      middleware/request-id.go
  15. 17 0
      middleware/utils.go
  16. 5 2
      model/log.go

+ 2 - 1
.gitignore

@@ -4,4 +4,5 @@ upload
 *.exe
 *.db
 build
-*.db-journal
+*.db-journal
+logs

+ 4 - 0
common/constants.go

@@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
 var BatchUpdateEnabled = false
 var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 
+const (
+	RequestIdKey = "X-Oneapi-Request-Id"
+)
+
 const (
 	RoleGuestUser  = 0
 	RoleCommonUser = 1

+ 1 - 1
common/init.go

@@ -12,7 +12,7 @@ var (
 	Port         = flag.Int("port", 3000, "the listening port")
 	PrintVersion = flag.Bool("version", false, "print version and exit")
 	PrintHelp    = flag.Bool("help", false, "print help and exit")
-	LogDir       = flag.String("log-dir", "", "specify the log directory")
+	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
 )
 
 func printHelp() {

+ 33 - 9
common/logger.go

@@ -1,6 +1,7 @@
 package common
 
 import (
+	"context"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
@@ -10,20 +11,21 @@ import (
 	"time"
 )
 
+const (
+	loggerINFO  = "INFO"
+	loggerWarn  = "WARN"
+	loggerError = "ERR"
+)
+
 func SetupGinLog() {
 	if *LogDir != "" {
-		commonLogPath := filepath.Join(*LogDir, "common.log")
-		errorLogPath := filepath.Join(*LogDir, "error.log")
-		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
+		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 		if err != nil {
 			log.Fatal("failed to open log file")
 		}
-		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
-		if err != nil {
-			log.Fatal("failed to open log file")
-		}
-		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
-		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
+		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
+		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
 	}
 }
 
@@ -37,6 +39,28 @@ func SysError(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }
 
+func LogInfo(ctx context.Context, msg string) {
+	logHelper(ctx, loggerINFO, msg)
+}
+
+func LogWarn(ctx context.Context, msg string) {
+	logHelper(ctx, loggerWarn, msg)
+}
+
+func LogError(ctx context.Context, msg string) {
+	logHelper(ctx, loggerError, msg)
+}
+
+func logHelper(ctx context.Context, level string, msg string) {
+	writer := gin.DefaultErrorWriter
+	if level == loggerINFO {
+		writer = gin.DefaultWriter
+	}
+	id := ctx.Value(RequestIdKey)
+	now := time.Now()
+	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
+}
+
 func FatalLog(v ...any) {
 	t := time.Now()
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)

+ 9 - 0
common/utils.go

@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
 	return time.Now().Unix()
 }
 
+func GetTimeString() string {
+	now := time.Now()
+	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
+}
+
 func Max(a int, b int) int {
 	if a >= b {
 		return a
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
 	}
 	return num
 }
+
+func MessageWithRequestId(message string, id string) string {
+	return fmt.Sprintf("%s (request id: %s)", message, id)
+}

+ 4 - 3
controller/relay-audio.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -91,7 +92,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	}
 	var audioResponse AudioResponse
 
-	defer func() {
+	defer func(ctx context.Context) {
 		go func() {
 			quota := countTokenText(audioResponse.Text, audioModel)
 			quotaDelta := quota - preConsumedQuota
@@ -106,13 +107,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
+				model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
 			}
 		}()
-	}()
+	}(c.Request.Context())
 
 	responseBody, err := io.ReadAll(resp.Body)
 

+ 4 - 3
controller/relay-image.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	}
 	var textResponse ImageResponse
 
-	defer func() {
+	defer func(ctx context.Context) {
 		if consumeQuota {
 			err := model.PostConsumeTokenQuota(tokenId, quota)
 			if err != nil {
@@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
+				model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
 			}
 		}
-	}()
+	}(c.Request.Context())
 
 	if consumeQuota {
 		responseBody, err := io.ReadAll(resp.Body)

+ 10 - 8
controller/relay-text.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -210,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		// in this case, we do not pre-consume quota
 		// because the user has enough quota
 		preConsumedQuota = 0
+		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
 	}
 	if consumeQuota && preConsumedQuota > 0 {
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
@@ -348,13 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 
 		if resp.StatusCode != http.StatusOK {
 			if preConsumedQuota != 0 {
-				go func() {
+				go func(ctx context.Context) {
 					// return pre-consumed quota
 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
 					if err != nil {
-						common.SysError("error return pre-consumed quota: " + err.Error())
+						common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
 					}
-				}()
+				}(c.Request.Context())
 			}
 			return relayErrorHandler(resp)
 		}
@@ -364,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	tokenName := c.GetString("token_name")
 	channelId := c.GetInt("channel_id")
 
-	defer func() {
+	defer func(ctx context.Context) {
 		// c.Writer.Flush()
 		go func() {
 			if consumeQuota {
@@ -387,21 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 				quotaDelta := quota - preConsumedQuota
 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 				if err != nil {
-					common.SysError("error consuming token remain quota: " + err.Error())
+					common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 				}
 				err = model.CacheUpdateUserQuota(userId)
 				if err != nil {
-					common.SysError("error update user quota cache: " + err.Error())
+					common.LogError(ctx, "error update user quota cache: "+err.Error())
 				}
 				if quota != 0 {
 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
+					model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 					model.UpdateChannelUsedQuota(channelId, quota)
 				}
 			}
 		}()
-	}()
+	}(c.Request.Context())
 	switch apiType {
 	case APITypeOpenAI:
 		if isStream {

+ 3 - 1
controller/relay.go

@@ -196,6 +196,7 @@ func Relay(c *gin.Context) {
 		err = relayTextHelper(c, relayMode)
 	}
 	if err != nil {
+		requestId := c.GetString(common.RequestIdKey)
 		retryTimesStr := c.Query("retry")
 		retryTimes, _ := strconv.Atoi(retryTimesStr)
 		if retryTimesStr == "" {
@@ -207,12 +208,13 @@ func Relay(c *gin.Context) {
 			if err.StatusCode == http.StatusTooManyRequests {
 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 			}
+			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
 			c.JSON(err.StatusCode, gin.H{
 				"error": err.OpenAIError,
 			})
 		}
 		channelId := c.GetInt("channel_id")
-		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
+		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 			channelId := c.GetInt("channel_id")

+ 5 - 2
main.go

@@ -7,6 +7,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"one-api/common"
 	"one-api/controller"
+	"one-api/middleware"
 	"one-api/model"
 	"one-api/router"
 	"os"
@@ -84,10 +85,12 @@ func main() {
 	controller.InitTokenEncoders()
 
 	// Initialize HTTP server
-	server := gin.Default()
+	server := gin.New()
+	server.Use(gin.Recovery())
 	// This will cause SSE not to work!!!
 	//server.Use(gzip.Gzip(gzip.DefaultCompression))
-
+	server.Use(middleware.RequestId())
+	middleware.SetUpLogger(server)
 	// Initialize session store
 	store := cookie.NewStore([]byte(common.SessionSecret))
 	server.Use(sessions.Sessions("session", store))

+ 4 - 28
middleware/auth.go

@@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) {
 		key = parts[0]
 		token, err := model.ValidateUserToken(key)
 		if err != nil {
-			c.JSON(http.StatusUnauthorized, gin.H{
-				"error": gin.H{
-					"message": err.Error(),
-					"type":    "one_api_error",
-				},
-			})
-			c.Abort()
+			abortWithMessage(c, http.StatusUnauthorized, err.Error())
 			return
 		}
 		userEnabled, err := model.IsUserEnabled(token.UserId)
 		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{
-				"error": gin.H{
-					"message": err.Error(),
-					"type":    "one_api_error",
-				},
-			})
-			c.Abort()
+			abortWithMessage(c, http.StatusInternalServerError, err.Error())
 			return
 		}
 		if !userEnabled {
-			c.JSON(http.StatusForbidden, gin.H{
-				"error": gin.H{
-					"message": "用户已被封禁",
-					"type":    "one_api_error",
-				},
-			})
-			c.Abort()
+			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 			return
 		}
 		c.Set("id", token.UserId)
@@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
 			if model.IsAdmin(token.UserId) {
 				c.Set("channelId", parts[1])
 			} else {
-				c.JSON(http.StatusForbidden, gin.H{
-					"error": gin.H{
-						"message": "普通用户不支持指定渠道",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 				return
 			}
 		}

+ 5 - 35
middleware/distributor.go

@@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
-				c.JSON(http.StatusBadRequest, gin.H{
-					"error": gin.H{
-						"message": "无效的渠道 ID",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 				return
 			}
 			channel, err = model.GetChannelById(id, true)
 			if err != nil {
-				c.JSON(http.StatusBadRequest, gin.H{
-					"error": gin.H{
-						"message": "无效的渠道 ID",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 				return
 			}
 			if channel.Status != common.ChannelStatusEnabled {
-				c.JSON(http.StatusForbidden, gin.H{
-					"error": gin.H{
-						"message": "该渠道已被禁用",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
 				return
 			}
 		} else {
@@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
 				err = common.UnmarshalBodyReusable(c, &modelRequest)
 			}
 			if err != nil {
-				c.JSON(http.StatusBadRequest, gin.H{
-					"error": gin.H{
-						"message": "无效的请求",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 				return
 			}
 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) {
 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 					message = "数据库一致性已被破坏,请联系管理员"
 				}
-				c.JSON(http.StatusServiceUnavailable, gin.H{
-					"error": gin.H{
-						"message": message,
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusServiceUnavailable, message)
 				return
 			}
 		}

+ 25 - 0
middleware/logger.go

@@ -0,0 +1,25 @@
+package middleware
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+)
+
+func SetUpLogger(server *gin.Engine) {
+	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
+		var requestID string
+		if param.Keys != nil {
+			requestID = param.Keys[common.RequestIdKey].(string)
+		}
+		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
+			param.TimeStamp.Format("2006/01/02 - 15:04:05"),
+			requestID,
+			param.StatusCode,
+			param.Latency,
+			param.ClientIP,
+			param.Method,
+			param.Path,
+		)
+	}))
+}

+ 18 - 0
middleware/request-id.go

@@ -0,0 +1,18 @@
+package middleware
+
+import (
+	"context"
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+)
+
+func RequestId() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		id := common.GetTimeString() + common.GetRandomString(8)
+		c.Set(common.RequestIdKey, id)
+		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
+		c.Request = c.Request.WithContext(ctx)
+		c.Header(common.RequestIdKey, id)
+		c.Next()
+	}
+}

+ 17 - 0
middleware/utils.go

@@ -0,0 +1,17 @@
+package middleware
+
+import (
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+)
+
+func abortWithMessage(c *gin.Context, statusCode int, message string) {
+	c.JSON(statusCode, gin.H{
+		"error": gin.H{
+			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
+			"type":    "one_api_error",
+		},
+	})
+	c.Abort()
+	common.LogError(c.Request.Context(), message)
+}

+ 5 - 2
model/log.go

@@ -1,6 +1,8 @@
 package model
 
 import (
+	"context"
+	"fmt"
 	"gorm.io/gorm"
 	"one-api/common"
 )
@@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) {
 	}
 }
 
-func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
+func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
+	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 	if !common.LogConsumeEnabled {
 		return
 	}
@@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
 	}
 	err := DB.Create(log).Error
 	if err != nil {
-		common.SysError("failed to record log: " + err.Error())
+		common.LogError(ctx, "failed to record log: "+err.Error())
 	}
 }