relay.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package controller
  2. import (
  3. "bufio"
  4. "bytes"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strings"
  11. )
  12. func Relay(c *gin.Context) {
  13. channelType := c.GetInt("channel")
  14. baseURL := common.ChannelBaseURLs[channelType]
  15. if channelType == common.ChannelTypeCustom {
  16. baseURL = c.GetString("base_url")
  17. }
  18. req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, c.Request.URL.String()), c.Request.Body)
  19. if err != nil {
  20. c.JSON(http.StatusOK, gin.H{
  21. "error": gin.H{
  22. "message": err.Error(),
  23. "type": "one_api_error",
  24. },
  25. })
  26. return
  27. }
  28. //req.Header = c.Request.Header.Clone()
  29. // Fix HTTP Decompression failed
  30. // https://github.com/stoplightio/prism/issues/1064#issuecomment-824682360
  31. //req.Header.Del("Accept-Encoding")
  32. req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
  33. req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
  34. req.Header.Set("Accept", c.Request.Header.Get("Accept"))
  35. req.Header.Set("Connection", c.Request.Header.Get("Connection"))
  36. client := &http.Client{}
  37. resp, err := client.Do(req)
  38. if err != nil {
  39. c.JSON(http.StatusOK, gin.H{
  40. "error": gin.H{
  41. "message": err.Error(),
  42. "type": "one_api_error",
  43. },
  44. })
  45. return
  46. }
  47. defer resp.Body.Close()
  48. scanner := bufio.NewScanner(resp.Body)
  49. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  50. if atEOF && len(data) == 0 {
  51. return 0, nil, nil
  52. }
  53. if i := strings.Index(string(data), "\n\n"); i >= 0 {
  54. return i + 2, data[0:i], nil
  55. }
  56. if atEOF {
  57. return len(data), data, nil
  58. }
  59. return 0, nil, nil
  60. })
  61. dataChan := make(chan string)
  62. stopChan := make(chan bool)
  63. go func() {
  64. for scanner.Scan() {
  65. data := scanner.Text()
  66. dataChan <- data
  67. }
  68. stopChan <- true
  69. }()
  70. for k, v := range resp.Header {
  71. c.Writer.Header().Set(k, v[0])
  72. }
  73. c.Stream(func(w io.Writer) bool {
  74. select {
  75. case data := <-dataChan:
  76. //fmt.Println(data)
  77. //c.Data(http.StatusOK, "text/event-stream", []byte(data))
  78. //c.Render(-1, common.Event{Data: data})
  79. //c.SSEvent("", data)
  80. //w.Write([]byte(data))
  81. //w.(http.Flusher).Flush()
  82. //c.Writer.Write(append([]byte(data), []byte("\n\n")...))
  83. outputBytes := bytes.NewBufferString(data)
  84. w.Write(outputBytes.Bytes())
  85. if strings.HasPrefix(data, "data: ") {
  86. w.Write([]byte("\n\n"))
  87. }
  88. //w.Write(append(outputBytes.Bytes(), []byte("\n\n")...))
  89. w.(http.Flusher).Flush()
  90. //fmt.Println(data)
  91. return true
  92. case <-stopChan:
  93. return false
  94. }
  95. })
  96. return
  97. }