channel.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. func GetAllChannels(c *gin.Context) {
  17. p, _ := strconv.Atoi(c.Query("p"))
  18. if p < 0 {
  19. p = 0
  20. }
  21. channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
  22. if err != nil {
  23. c.JSON(http.StatusOK, gin.H{
  24. "success": false,
  25. "message": err.Error(),
  26. })
  27. return
  28. }
  29. c.JSON(http.StatusOK, gin.H{
  30. "success": true,
  31. "message": "",
  32. "data": channels,
  33. })
  34. return
  35. }
  36. func SearchChannels(c *gin.Context) {
  37. keyword := c.Query("keyword")
  38. channels, err := model.SearchChannels(keyword)
  39. if err != nil {
  40. c.JSON(http.StatusOK, gin.H{
  41. "success": false,
  42. "message": err.Error(),
  43. })
  44. return
  45. }
  46. c.JSON(http.StatusOK, gin.H{
  47. "success": true,
  48. "message": "",
  49. "data": channels,
  50. })
  51. return
  52. }
  53. func GetChannel(c *gin.Context) {
  54. id, err := strconv.Atoi(c.Param("id"))
  55. if err != nil {
  56. c.JSON(http.StatusOK, gin.H{
  57. "success": false,
  58. "message": err.Error(),
  59. })
  60. return
  61. }
  62. channel, err := model.GetChannelById(id, false)
  63. if err != nil {
  64. c.JSON(http.StatusOK, gin.H{
  65. "success": false,
  66. "message": err.Error(),
  67. })
  68. return
  69. }
  70. c.JSON(http.StatusOK, gin.H{
  71. "success": true,
  72. "message": "",
  73. "data": channel,
  74. })
  75. return
  76. }
  77. func AddChannel(c *gin.Context) {
  78. channel := model.Channel{}
  79. err := c.ShouldBindJSON(&channel)
  80. if err != nil {
  81. c.JSON(http.StatusOK, gin.H{
  82. "success": false,
  83. "message": err.Error(),
  84. })
  85. return
  86. }
  87. channel.CreatedTime = common.GetTimestamp()
  88. keys := strings.Split(channel.Key, "\n")
  89. channels := make([]model.Channel, 0)
  90. for _, key := range keys {
  91. if key == "" {
  92. continue
  93. }
  94. localChannel := channel
  95. localChannel.Key = key
  96. channels = append(channels, localChannel)
  97. }
  98. err = model.BatchInsertChannels(channels)
  99. if err != nil {
  100. c.JSON(http.StatusOK, gin.H{
  101. "success": false,
  102. "message": err.Error(),
  103. })
  104. return
  105. }
  106. c.JSON(http.StatusOK, gin.H{
  107. "success": true,
  108. "message": "",
  109. })
  110. return
  111. }
  112. func DeleteChannel(c *gin.Context) {
  113. id, _ := strconv.Atoi(c.Param("id"))
  114. channel := model.Channel{Id: id}
  115. err := channel.Delete()
  116. if err != nil {
  117. c.JSON(http.StatusOK, gin.H{
  118. "success": false,
  119. "message": err.Error(),
  120. })
  121. return
  122. }
  123. c.JSON(http.StatusOK, gin.H{
  124. "success": true,
  125. "message": "",
  126. })
  127. return
  128. }
  129. func UpdateChannel(c *gin.Context) {
  130. channel := model.Channel{}
  131. err := c.ShouldBindJSON(&channel)
  132. if err != nil {
  133. c.JSON(http.StatusOK, gin.H{
  134. "success": false,
  135. "message": err.Error(),
  136. })
  137. return
  138. }
  139. err = channel.Update()
  140. if err != nil {
  141. c.JSON(http.StatusOK, gin.H{
  142. "success": false,
  143. "message": err.Error(),
  144. })
  145. return
  146. }
  147. c.JSON(http.StatusOK, gin.H{
  148. "success": true,
  149. "message": "",
  150. "data": channel,
  151. })
  152. return
  153. }
  154. func testChannel(channel *model.Channel, request *ChatRequest) error {
  155. if request.Model == "" {
  156. request.Model = "gpt-3.5-turbo"
  157. if channel.Type == common.ChannelTypeAzure {
  158. request.Model = "gpt-35-turbo"
  159. }
  160. }
  161. requestURL := common.ChannelBaseURLs[channel.Type]
  162. if channel.Type == common.ChannelTypeAzure {
  163. requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
  164. } else {
  165. if channel.Type == common.ChannelTypeCustom {
  166. requestURL = channel.BaseURL
  167. }
  168. requestURL += "/v1/chat/completions"
  169. }
  170. jsonData, err := json.Marshal(request)
  171. if err != nil {
  172. return err
  173. }
  174. req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
  175. if err != nil {
  176. return err
  177. }
  178. if channel.Type == common.ChannelTypeAzure {
  179. req.Header.Set("api-key", channel.Key)
  180. } else {
  181. req.Header.Set("Authorization", "Bearer "+channel.Key)
  182. }
  183. req.Header.Set("Content-Type", "application/json")
  184. client := &http.Client{}
  185. resp, err := client.Do(req)
  186. if err != nil {
  187. return err
  188. }
  189. defer resp.Body.Close()
  190. var response TextResponse
  191. err = json.NewDecoder(resp.Body).Decode(&response)
  192. if err != nil {
  193. return err
  194. }
  195. if response.Error.Type != "" {
  196. return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
  197. }
  198. return nil
  199. }
  200. func buildTestRequest(c *gin.Context) *ChatRequest {
  201. model_ := c.Query("model")
  202. testRequest := &ChatRequest{
  203. Model: model_,
  204. }
  205. testMessage := Message{
  206. Role: "user",
  207. Content: "echo hi",
  208. }
  209. testRequest.Messages = append(testRequest.Messages, testMessage)
  210. return testRequest
  211. }
  212. func TestChannel(c *gin.Context) {
  213. id, err := strconv.Atoi(c.Param("id"))
  214. if err != nil {
  215. c.JSON(http.StatusOK, gin.H{
  216. "success": false,
  217. "message": err.Error(),
  218. })
  219. return
  220. }
  221. channel, err := model.GetChannelById(id, true)
  222. if err != nil {
  223. c.JSON(http.StatusOK, gin.H{
  224. "success": false,
  225. "message": err.Error(),
  226. })
  227. return
  228. }
  229. testRequest := buildTestRequest(c)
  230. tik := time.Now()
  231. err = testChannel(channel, testRequest)
  232. tok := time.Now()
  233. milliseconds := tok.Sub(tik).Milliseconds()
  234. go channel.UpdateResponseTime(milliseconds)
  235. consumedTime := float64(milliseconds) / 1000.0
  236. if err != nil {
  237. c.JSON(http.StatusOK, gin.H{
  238. "success": false,
  239. "message": err.Error(),
  240. "time": consumedTime,
  241. })
  242. return
  243. }
  244. c.JSON(http.StatusOK, gin.H{
  245. "success": true,
  246. "message": "",
  247. "time": consumedTime,
  248. })
  249. return
  250. }
  251. var testAllChannelsLock sync.Mutex
  252. func testAllChannels(c *gin.Context) error {
  253. ok := testAllChannelsLock.TryLock()
  254. if !ok {
  255. return errors.New("测试已在运行")
  256. }
  257. defer testAllChannelsLock.Unlock()
  258. channels, err := model.GetAllChannels(0, 0, true)
  259. if err != nil {
  260. c.JSON(http.StatusOK, gin.H{
  261. "success": false,
  262. "message": err.Error(),
  263. })
  264. return err
  265. }
  266. testRequest := buildTestRequest(c)
  267. var disableThreshold int64 = 5000 // TODO: make it configurable
  268. email := model.GetRootUserEmail()
  269. go func() {
  270. for _, channel := range channels {
  271. if channel.Status != common.ChannelStatusEnabled {
  272. continue
  273. }
  274. tik := time.Now()
  275. err := testChannel(channel, testRequest)
  276. tok := time.Now()
  277. milliseconds := tok.Sub(tik).Milliseconds()
  278. if err != nil || milliseconds > disableThreshold {
  279. if milliseconds > disableThreshold {
  280. err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
  281. }
  282. // disable & notify
  283. channel.UpdateStatus(common.ChannelStatusDisabled)
  284. subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id)
  285. content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error())
  286. err = common.SendEmail(subject, email, content)
  287. if err != nil {
  288. common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
  289. }
  290. }
  291. channel.UpdateResponseTime(milliseconds)
  292. }
  293. err := common.SendEmail("通道测试完成", email, "通道测试完成")
  294. if err != nil {
  295. common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
  296. }
  297. }()
  298. return nil
  299. }
  300. func TestAllChannels(c *gin.Context) {
  301. err := testAllChannels(c)
  302. if err != nil {
  303. c.JSON(http.StatusOK, gin.H{
  304. "success": false,
  305. "message": err.Error(),
  306. })
  307. return
  308. }
  309. c.JSON(http.StatusOK, gin.H{
  310. "success": true,
  311. "message": "",
  312. })
  313. return
  314. }