error.go 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package service
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"

	"one-api/common"
	"one-api/dto"
)
  12. // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
  13. func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
  14. text := err.Error()
  15. // 定义一个正则表达式匹配URL
  16. if strings.Contains(text, "Post") {
  17. common.SysLog(fmt.Sprintf("error: %s", text))
  18. text = "请求上游地址失败"
  19. }
  20. //避免暴露内部错误
  21. openAIError := dto.OpenAIError{
  22. Message: text,
  23. Type: "new_api_error",
  24. Code: code,
  25. }
  26. return &dto.OpenAIErrorWithStatusCode{
  27. Error: openAIError,
  28. StatusCode: statusCode,
  29. }
  30. }
  31. func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
  32. errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
  33. StatusCode: resp.StatusCode,
  34. Error: dto.OpenAIError{
  35. Message: "",
  36. Type: "upstream_error",
  37. Code: "bad_response_status_code",
  38. Param: strconv.Itoa(resp.StatusCode),
  39. },
  40. }
  41. responseBody, err := io.ReadAll(resp.Body)
  42. if err != nil {
  43. return
  44. }
  45. err = resp.Body.Close()
  46. if err != nil {
  47. return
  48. }
  49. var errResponse dto.GeneralErrorResponse
  50. err = json.Unmarshal(responseBody, &errResponse)
  51. if err != nil {
  52. return
  53. }
  54. if errResponse.Error.Message != "" {
  55. // OpenAI format error, so we override the default one
  56. errWithStatusCode.Error = errResponse.Error
  57. } else {
  58. errWithStatusCode.Error.Message = errResponse.ToMessage()
  59. }
  60. if errWithStatusCode.Error.Message == "" {
  61. errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
  62. }
  63. return
  64. }