Просмотр исходного кода

refactor: improve relay's implementation

JustSong 2 лет назад
Родитель
Сommit
23ec541ba6
1 измененных файлов с 36 добавлено и 71 удалено
  1. 36 71
      controller/relay.go

+ 36 - 71
controller/relay.go

@@ -45,6 +45,18 @@ type StreamResponse struct {
 }
 }
 
 
 func Relay(c *gin.Context) {
 func Relay(c *gin.Context) {
+	err := relayHelper(c)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+				"type":    "one_api_error",
+			},
+		})
+	}
+}
+
+func relayHelper(c *gin.Context) error {
 	channelType := c.GetInt("channel")
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
 	consumeQuota := c.GetBool("consume_quota")
@@ -54,47 +66,27 @@ func Relay(c *gin.Context) {
 	}
 	}
 	requestBody, err := io.ReadAll(c.Request.Body)
 	requestBody, err := io.ReadAll(c.Request.Body)
 	if err != nil {
 	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"error": gin.H{
-				"message": err.Error(),
-				"type":    "one_api_error",
-			},
-		})
-		return
+		return err
 	}
 	}
 	err = c.Request.Body.Close()
 	err = c.Request.Body.Close()
 	if err != nil {
 	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"error": gin.H{
-				"message": err.Error(),
-				"type":    "one_api_error",
-			},
-		})
-		return
+		return err
 	}
 	}
 	var textRequest TextRequest
 	var textRequest TextRequest
 	err = json.Unmarshal(requestBody, &textRequest)
 	err = json.Unmarshal(requestBody, &textRequest)
 	if err != nil {
 	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"error": gin.H{
-				"message": err.Error(),
-				"type":    "one_api_error",
-			},
-		})
-		return
+		return err
 	}
 	}
 	// Reset request body
 	// Reset request body
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	requestURL := c.Request.URL.String()
 	requestURL := c.Request.URL.String()
 	req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
 	req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
 	if err != nil {
 	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"error": gin.H{
-				"message": err.Error(),
-				"type":    "one_api_error",
-			},
-		})
-		return
+		return err
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		return err
 	}
 	}
 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
@@ -103,23 +95,11 @@ func Relay(c *gin.Context) {
 	client := &http.Client{}
 	client := &http.Client{}
 	resp, err := client.Do(req)
 	resp, err := client.Do(req)
 	if err != nil {
 	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"error": gin.H{
-				"message": err.Error(),
-				"type":    "one_api_error",
-			},
-		})
-		return
+		return err
 	}
 	}
 	err = req.Body.Close()
 	err = req.Body.Close()
 	if err != nil {
 	if err != nil {
-		c.JSON(http.StatusOK, gin.H{
-			"error": gin.H{
-				"message": err.Error(),
-				"type":    "one_api_error",
-			},
-		})
-		return
+		return err
 	}
 	}
 
 
 	var textResponse TextResponse
 	var textResponse TextResponse
@@ -192,53 +172,38 @@ func Relay(c *gin.Context) {
 				return false
 				return false
 			}
 			}
 		})
 		})
-		return
+		err = resp.Body.Close()
+		if err != nil {
+			return err
+		}
+		return nil
 	} else {
 	} else {
 		for k, v := range resp.Header {
 		for k, v := range resp.Header {
 			c.Writer.Header().Set(k, v[0])
 			c.Writer.Header().Set(k, v[0])
 		}
 		}
 		responseBody, err := io.ReadAll(resp.Body)
 		responseBody, err := io.ReadAll(resp.Body)
 		if err != nil {
 		if err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"error": gin.H{
-					"message": err.Error(),
-					"type":    "one_api_error",
-				},
-			})
-			return
+			return err
 		}
 		}
 		err = resp.Body.Close()
 		err = resp.Body.Close()
 		if err != nil {
 		if err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"error": gin.H{
-					"message": err.Error(),
-					"type":    "one_api_error",
-				},
-			})
-			return
+			return err
 		}
 		}
 		err = json.Unmarshal(responseBody, &textResponse)
 		err = json.Unmarshal(responseBody, &textResponse)
 		if err != nil {
 		if err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"error": gin.H{
-					"message": err.Error(),
-					"type":    "one_api_error",
-				},
-			})
-			return
+			return err
 		}
 		}
 		// Reset response body
 		// Reset response body
 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 		_, err = io.Copy(c.Writer, resp.Body)
 		_, err = io.Copy(c.Writer, resp.Body)
 		if err != nil {
 		if err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"error": gin.H{
-					"message": err.Error(),
-					"type":    "one_api_error",
-				},
-			})
-			return
+			return err
+		}
+		err = resp.Body.Close()
+		if err != nil {
+			return err
 		}
 		}
+		return nil
 	}
 	}
 }
 }