relay_utils.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. package common
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. _ "image/gif"
  7. _ "image/jpeg"
  8. _ "image/png"
  9. "io"
  10. "net/http"
  11. "one-api/common"
  12. "one-api/dto"
  13. "strconv"
  14. "strings"
  15. )
  // StopFinishReason holds the literal string "stop".
  // NOTE(review): presumably the finish-reason value reported for a normally
  // completed response — confirm against callers.
  var StopFinishReason = "stop"
  17. func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
  18. OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
  19. StatusCode: resp.StatusCode,
  20. Error: dto.OpenAIError{
  21. Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
  22. Type: "upstream_error",
  23. Code: "bad_response_status_code",
  24. Param: strconv.Itoa(resp.StatusCode),
  25. },
  26. }
  27. responseBody, err := io.ReadAll(resp.Body)
  28. if err != nil {
  29. return
  30. }
  31. err = resp.Body.Close()
  32. if err != nil {
  33. return
  34. }
  35. var textResponse dto.TextResponseWithError
  36. err = json.Unmarshal(responseBody, &textResponse)
  37. if err != nil {
  38. OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
  39. return
  40. }
  41. OpenAIErrorWithStatusCode.Error = textResponse.Error
  42. return
  43. }
  44. func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
  45. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  46. if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
  47. switch channelType {
  48. case common.ChannelTypeOpenAI:
  49. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
  50. case common.ChannelTypeAzure:
  51. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
  52. }
  53. }
  54. return fullRequestURL
  55. }
  56. func GetAPIVersion(c *gin.Context) string {
  57. query := c.Request.URL.Query()
  58. apiVersion := query.Get("api-version")
  59. if apiVersion == "" {
  60. apiVersion = c.GetString("api_version")
  61. }
  62. return apiVersion
  63. }