relay_utils.go 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package common
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. _ "image/gif"
  7. _ "image/jpeg"
  8. _ "image/png"
  9. "io"
  10. "net/http"
  11. "one-api/common"
  12. "one-api/dto"
  13. "strconv"
  14. "strings"
  15. )
  // StopFinishReason holds the literal finish-reason string "stop".
  // NOTE(review): presumably compared against / written into the finish_reason
  // field of relayed completions; no use site is visible here, so confirm.
  var (
  	StopFinishReason = "stop"
  )
  17. func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
  18. OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
  19. StatusCode: resp.StatusCode,
  20. Error: dto.OpenAIError{
  21. Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
  22. Type: "upstream_error",
  23. Code: "bad_response_status_code",
  24. Param: strconv.Itoa(resp.StatusCode),
  25. },
  26. }
  27. responseBody, err := io.ReadAll(resp.Body)
  28. if err != nil {
  29. return
  30. }
  31. err = resp.Body.Close()
  32. if err != nil {
  33. return
  34. }
  35. var textResponse dto.TextResponse
  36. err = json.Unmarshal(responseBody, &textResponse)
  37. if err != nil {
  38. return
  39. }
  40. OpenAIErrorWithStatusCode.Error = *textResponse.Error
  41. return
  42. }
  43. func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
  44. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  45. if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
  46. switch channelType {
  47. case common.ChannelTypeOpenAI:
  48. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
  49. case common.ChannelTypeAzure:
  50. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
  51. }
  52. }
  53. return fullRequestURL
  54. }
  55. func GetAPIVersion(c *gin.Context) string {
  56. query := c.Request.URL.Query()
  57. apiVersion := query.Get("api-version")
  58. if apiVersion == "" {
  59. apiVersion = c.GetString("api_version")
  60. }
  61. return apiVersion
  62. }