channel.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. func GetAllChannels(c *gin.Context) {
  17. p, _ := strconv.Atoi(c.Query("p"))
  18. if p < 0 {
  19. p = 0
  20. }
  21. channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
  22. if err != nil {
  23. c.JSON(http.StatusOK, gin.H{
  24. "success": false,
  25. "message": err.Error(),
  26. })
  27. return
  28. }
  29. c.JSON(http.StatusOK, gin.H{
  30. "success": true,
  31. "message": "",
  32. "data": channels,
  33. })
  34. return
  35. }
  36. func SearchChannels(c *gin.Context) {
  37. keyword := c.Query("keyword")
  38. channels, err := model.SearchChannels(keyword)
  39. if err != nil {
  40. c.JSON(http.StatusOK, gin.H{
  41. "success": false,
  42. "message": err.Error(),
  43. })
  44. return
  45. }
  46. c.JSON(http.StatusOK, gin.H{
  47. "success": true,
  48. "message": "",
  49. "data": channels,
  50. })
  51. return
  52. }
  53. func GetChannel(c *gin.Context) {
  54. id, err := strconv.Atoi(c.Param("id"))
  55. if err != nil {
  56. c.JSON(http.StatusOK, gin.H{
  57. "success": false,
  58. "message": err.Error(),
  59. })
  60. return
  61. }
  62. channel, err := model.GetChannelById(id, false)
  63. if err != nil {
  64. c.JSON(http.StatusOK, gin.H{
  65. "success": false,
  66. "message": err.Error(),
  67. })
  68. return
  69. }
  70. c.JSON(http.StatusOK, gin.H{
  71. "success": true,
  72. "message": "",
  73. "data": channel,
  74. })
  75. return
  76. }
  77. func AddChannel(c *gin.Context) {
  78. channel := model.Channel{}
  79. err := c.ShouldBindJSON(&channel)
  80. if err != nil {
  81. c.JSON(http.StatusOK, gin.H{
  82. "success": false,
  83. "message": err.Error(),
  84. })
  85. return
  86. }
  87. channel.CreatedTime = common.GetTimestamp()
  88. keys := strings.Split(channel.Key, "\n")
  89. channels := make([]model.Channel, 0)
  90. for _, key := range keys {
  91. if key == "" {
  92. continue
  93. }
  94. localChannel := channel
  95. localChannel.Key = key
  96. channels = append(channels, localChannel)
  97. }
  98. err = model.BatchInsertChannels(channels)
  99. if err != nil {
  100. c.JSON(http.StatusOK, gin.H{
  101. "success": false,
  102. "message": err.Error(),
  103. })
  104. return
  105. }
  106. c.JSON(http.StatusOK, gin.H{
  107. "success": true,
  108. "message": "",
  109. })
  110. return
  111. }
  112. func DeleteChannel(c *gin.Context) {
  113. id, _ := strconv.Atoi(c.Param("id"))
  114. channel := model.Channel{Id: id}
  115. err := channel.Delete()
  116. if err != nil {
  117. c.JSON(http.StatusOK, gin.H{
  118. "success": false,
  119. "message": err.Error(),
  120. })
  121. return
  122. }
  123. c.JSON(http.StatusOK, gin.H{
  124. "success": true,
  125. "message": "",
  126. })
  127. return
  128. }
  129. func UpdateChannel(c *gin.Context) {
  130. channel := model.Channel{}
  131. err := c.ShouldBindJSON(&channel)
  132. if err != nil {
  133. c.JSON(http.StatusOK, gin.H{
  134. "success": false,
  135. "message": err.Error(),
  136. })
  137. return
  138. }
  139. err = channel.Update()
  140. if err != nil {
  141. c.JSON(http.StatusOK, gin.H{
  142. "success": false,
  143. "message": err.Error(),
  144. })
  145. return
  146. }
  147. c.JSON(http.StatusOK, gin.H{
  148. "success": true,
  149. "message": "",
  150. "data": channel,
  151. })
  152. return
  153. }
  154. func testChannel(channel *model.Channel, request *ChatRequest) error {
  155. if request.Model == "" {
  156. request.Model = "gpt-3.5-turbo"
  157. if channel.Type == common.ChannelTypeAzure {
  158. request.Model = "gpt-35-turbo"
  159. }
  160. }
  161. requestURL := common.ChannelBaseURLs[channel.Type]
  162. if channel.Type == common.ChannelTypeAzure {
  163. requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
  164. } else {
  165. if channel.Type == common.ChannelTypeCustom {
  166. requestURL = channel.BaseURL
  167. }
  168. requestURL += "/v1/chat/completions"
  169. }
  170. jsonData, err := json.Marshal(request)
  171. if err != nil {
  172. return err
  173. }
  174. req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
  175. if err != nil {
  176. return err
  177. }
  178. if channel.Type == common.ChannelTypeAzure {
  179. req.Header.Set("api-key", channel.Key)
  180. } else {
  181. req.Header.Set("Authorization", "Bearer "+channel.Key)
  182. }
  183. req.Header.Set("Content-Type", "application/json")
  184. client := &http.Client{}
  185. resp, err := client.Do(req)
  186. if err != nil {
  187. return err
  188. }
  189. defer resp.Body.Close()
  190. var response TextResponse
  191. err = json.NewDecoder(resp.Body).Decode(&response)
  192. if err != nil {
  193. return err
  194. }
  195. if response.Error.Type != "" {
  196. return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
  197. }
  198. return nil
  199. }
  200. func buildTestRequest(c *gin.Context) *ChatRequest {
  201. model_ := c.Query("model")
  202. testRequest := &ChatRequest{
  203. Model: model_,
  204. MaxTokens: 1,
  205. }
  206. testMessage := Message{
  207. Role: "user",
  208. Content: "hi",
  209. }
  210. testRequest.Messages = append(testRequest.Messages, testMessage)
  211. return testRequest
  212. }
  213. func TestChannel(c *gin.Context) {
  214. id, err := strconv.Atoi(c.Param("id"))
  215. if err != nil {
  216. c.JSON(http.StatusOK, gin.H{
  217. "success": false,
  218. "message": err.Error(),
  219. })
  220. return
  221. }
  222. channel, err := model.GetChannelById(id, true)
  223. if err != nil {
  224. c.JSON(http.StatusOK, gin.H{
  225. "success": false,
  226. "message": err.Error(),
  227. })
  228. return
  229. }
  230. testRequest := buildTestRequest(c)
  231. tik := time.Now()
  232. err = testChannel(channel, testRequest)
  233. tok := time.Now()
  234. milliseconds := tok.Sub(tik).Milliseconds()
  235. go channel.UpdateResponseTime(milliseconds)
  236. consumedTime := float64(milliseconds) / 1000.0
  237. if err != nil {
  238. c.JSON(http.StatusOK, gin.H{
  239. "success": false,
  240. "message": err.Error(),
  241. "time": consumedTime,
  242. })
  243. return
  244. }
  245. c.JSON(http.StatusOK, gin.H{
  246. "success": true,
  247. "message": "",
  248. "time": consumedTime,
  249. })
  250. return
  251. }
  // State for the bulk channel test: testAllChannelsRunning is true while a
  // run is in progress and must only be read or written while holding
  // testAllChannelsLock.
  var (
  	testAllChannelsLock    sync.Mutex
  	testAllChannelsRunning bool
  )
  254. // disable & notify
  255. func disableChannel(channelId int, channelName string, reason string) {
  256. if common.RootUserEmail == "" {
  257. common.RootUserEmail = model.GetRootUserEmail()
  258. }
  259. model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
  260. subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
  261. content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
  262. err := common.SendEmail(subject, common.RootUserEmail, content)
  263. if err != nil {
  264. common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
  265. }
  266. }
  267. func testAllChannels(c *gin.Context) error {
  268. testAllChannelsLock.Lock()
  269. if testAllChannelsRunning {
  270. testAllChannelsLock.Unlock()
  271. return errors.New("测试已在运行中")
  272. }
  273. testAllChannelsRunning = true
  274. testAllChannelsLock.Unlock()
  275. channels, err := model.GetAllChannels(0, 0, true)
  276. if err != nil {
  277. c.JSON(http.StatusOK, gin.H{
  278. "success": false,
  279. "message": err.Error(),
  280. })
  281. return err
  282. }
  283. testRequest := buildTestRequest(c)
  284. var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
  285. if disableThreshold == 0 {
  286. disableThreshold = 10000000 // a impossible value
  287. }
  288. go func() {
  289. for _, channel := range channels {
  290. if channel.Status != common.ChannelStatusEnabled {
  291. continue
  292. }
  293. tik := time.Now()
  294. err := testChannel(channel, testRequest)
  295. tok := time.Now()
  296. milliseconds := tok.Sub(tik).Milliseconds()
  297. if err != nil || milliseconds > disableThreshold {
  298. if milliseconds > disableThreshold {
  299. err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
  300. }
  301. disableChannel(channel.Id, channel.Name, err.Error())
  302. }
  303. channel.UpdateResponseTime(milliseconds)
  304. }
  305. err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
  306. if err != nil {
  307. common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
  308. }
  309. testAllChannelsLock.Lock()
  310. testAllChannelsRunning = false
  311. testAllChannelsLock.Unlock()
  312. }()
  313. return nil
  314. }
  315. func TestAllChannels(c *gin.Context) {
  316. err := testAllChannels(c)
  317. if err != nil {
  318. c.JSON(http.StatusOK, gin.H{
  319. "success": false,
  320. "message": err.Error(),
  321. })
  322. return
  323. }
  324. c.JSON(http.StatusOK, gin.H{
  325. "success": true,
  326. "message": "",
  327. })
  328. return
  329. }