api_request.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. package channel
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. common2 "one-api/common"
  9. "one-api/logger"
  10. "one-api/relay/common"
  11. "one-api/relay/constant"
  12. "one-api/relay/helper"
  13. "one-api/service"
  14. "one-api/setting/operation_setting"
  15. "sync"
  16. "time"
  17. "github.com/bytedance/gopkg/util/gopool"
  18. "github.com/gin-gonic/gin"
  19. "github.com/gorilla/websocket"
  20. )
  21. func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
  22. if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
  23. // multipart/form-data
  24. } else if info.RelayMode == constant.RelayModeRealtime {
  25. // websocket
  26. } else {
  27. req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
  28. req.Set("Accept", c.Request.Header.Get("Accept"))
  29. if info.IsStream && c.Request.Header.Get("Accept") == "" {
  30. req.Set("Accept", "text/event-stream")
  31. }
  32. }
  33. }
  34. func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
  35. fullRequestURL, err := a.GetRequestURL(info)
  36. if err != nil {
  37. return nil, fmt.Errorf("get request url failed: %w", err)
  38. }
  39. if common2.DebugEnabled {
  40. println("fullRequestURL:", fullRequestURL)
  41. }
  42. req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
  43. if err != nil {
  44. return nil, fmt.Errorf("new request failed: %w", err)
  45. }
  46. err = a.SetupRequestHeader(c, &req.Header, info)
  47. if err != nil {
  48. return nil, fmt.Errorf("setup request header failed: %w", err)
  49. }
  50. resp, err := doRequest(c, req, info)
  51. if err != nil {
  52. return nil, fmt.Errorf("do request failed: %w", err)
  53. }
  54. return resp, nil
  55. }
  56. func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
  57. fullRequestURL, err := a.GetRequestURL(info)
  58. if err != nil {
  59. return nil, fmt.Errorf("get request url failed: %w", err)
  60. }
  61. if common2.DebugEnabled {
  62. println("fullRequestURL:", fullRequestURL)
  63. }
  64. req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
  65. if err != nil {
  66. return nil, fmt.Errorf("new request failed: %w", err)
  67. }
  68. // set form data
  69. req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
  70. err = a.SetupRequestHeader(c, &req.Header, info)
  71. if err != nil {
  72. return nil, fmt.Errorf("setup request header failed: %w", err)
  73. }
  74. resp, err := doRequest(c, req, info)
  75. if err != nil {
  76. return nil, fmt.Errorf("do request failed: %w", err)
  77. }
  78. return resp, nil
  79. }
  80. func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
  81. fullRequestURL, err := a.GetRequestURL(info)
  82. if err != nil {
  83. return nil, fmt.Errorf("get request url failed: %w", err)
  84. }
  85. targetHeader := http.Header{}
  86. err = a.SetupRequestHeader(c, &targetHeader, info)
  87. if err != nil {
  88. return nil, fmt.Errorf("setup request header failed: %w", err)
  89. }
  90. targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
  91. targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
  92. if err != nil {
  93. return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
  94. }
  95. // send request body
  96. //all, err := io.ReadAll(requestBody)
  97. //err = service.WssString(c, targetConn, string(all))
  98. return targetConn, nil
  99. }
  100. func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
  101. pingerCtx, stopPinger := context.WithCancel(context.Background())
  102. gopool.Go(func() {
  103. defer func() {
  104. // 增加panic恢复处理
  105. if r := recover(); r != nil {
  106. if common2.DebugEnabled {
  107. println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
  108. }
  109. }
  110. if common2.DebugEnabled {
  111. println("SSE ping goroutine stopped.")
  112. }
  113. }()
  114. if pingInterval <= 0 {
  115. pingInterval = helper.DefaultPingInterval
  116. }
  117. ticker := time.NewTicker(pingInterval)
  118. // 确保在任何情况下都清理ticker
  119. defer func() {
  120. ticker.Stop()
  121. if common2.DebugEnabled {
  122. println("SSE ping ticker stopped")
  123. }
  124. }()
  125. var pingMutex sync.Mutex
  126. if common2.DebugEnabled {
  127. println("SSE ping goroutine started")
  128. }
  129. // 增加超时控制,防止goroutine长时间运行
  130. maxPingDuration := 120 * time.Minute // 最大ping持续时间
  131. pingTimeout := time.NewTimer(maxPingDuration)
  132. defer pingTimeout.Stop()
  133. for {
  134. select {
  135. // 发送 ping 数据
  136. case <-ticker.C:
  137. if err := sendPingData(c, &pingMutex); err != nil {
  138. if common2.DebugEnabled {
  139. println("SSE ping error, stopping goroutine:", err.Error())
  140. }
  141. return
  142. }
  143. // 收到退出信号
  144. case <-pingerCtx.Done():
  145. return
  146. // request 结束
  147. case <-c.Request.Context().Done():
  148. return
  149. // 超时保护,防止goroutine无限运行
  150. case <-pingTimeout.C:
  151. if common2.DebugEnabled {
  152. println("SSE ping goroutine timeout, stopping")
  153. }
  154. return
  155. }
  156. }
  157. })
  158. return stopPinger
  159. }
  160. func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
  161. // 增加超时控制,防止锁死等待
  162. done := make(chan error, 1)
  163. go func() {
  164. mutex.Lock()
  165. defer mutex.Unlock()
  166. err := helper.PingData(c)
  167. if err != nil {
  168. logger.LogError(c, "SSE ping error: "+err.Error())
  169. done <- err
  170. return
  171. }
  172. if common2.DebugEnabled {
  173. println("SSE ping data sent.")
  174. }
  175. done <- nil
  176. }()
  177. // 设置发送ping数据的超时时间
  178. select {
  179. case err := <-done:
  180. return err
  181. case <-time.After(10 * time.Second):
  182. return errors.New("SSE ping data send timeout")
  183. case <-c.Request.Context().Done():
  184. return errors.New("request context cancelled during ping")
  185. }
  186. }
  187. func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
  188. return doRequest(c, req, info)
  189. }
  190. func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
  191. var client *http.Client
  192. var err error
  193. if info.ChannelSetting.Proxy != "" {
  194. client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
  195. if err != nil {
  196. return nil, fmt.Errorf("new proxy http client failed: %w", err)
  197. }
  198. } else {
  199. client = service.GetHttpClient()
  200. }
  201. var stopPinger context.CancelFunc
  202. if info.IsStream {
  203. helper.SetEventStreamHeaders(c)
  204. // 处理流式请求的 ping 保活
  205. generalSettings := operation_setting.GetGeneralSetting()
  206. if generalSettings.PingIntervalEnabled && !info.DisablePing {
  207. pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
  208. stopPinger = startPingKeepAlive(c, pingInterval)
  209. // 使用defer确保在任何情况下都能停止ping goroutine
  210. defer func() {
  211. if stopPinger != nil {
  212. stopPinger()
  213. if common2.DebugEnabled {
  214. println("SSE ping goroutine stopped by defer")
  215. }
  216. }
  217. }()
  218. }
  219. }
  220. resp, err := client.Do(req)
  221. if err != nil {
  222. return nil, err
  223. }
  224. if resp == nil {
  225. return nil, errors.New("resp is nil")
  226. }
  227. _ = req.Body.Close()
  228. _ = c.Request.Body.Close()
  229. return resp, nil
  230. }
  231. func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
  232. fullRequestURL, err := a.BuildRequestURL(info)
  233. if err != nil {
  234. return nil, err
  235. }
  236. req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
  237. if err != nil {
  238. return nil, fmt.Errorf("new request failed: %w", err)
  239. }
  240. req.GetBody = func() (io.ReadCloser, error) {
  241. return io.NopCloser(requestBody), nil
  242. }
  243. err = a.BuildRequestHeader(c, req, info)
  244. if err != nil {
  245. return nil, fmt.Errorf("setup request header failed: %w", err)
  246. }
  247. resp, err := doRequest(c, req, info.RelayInfo)
  248. if err != nil {
  249. return nil, fmt.Errorf("do request failed: %w", err)
  250. }
  251. return resp, nil
  252. }