channel.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "strings"
  13. "time"
  14. )
  15. func GetAllChannels(c *gin.Context) {
  16. p, _ := strconv.Atoi(c.Query("p"))
  17. if p < 0 {
  18. p = 0
  19. }
  20. channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage)
  21. if err != nil {
  22. c.JSON(http.StatusOK, gin.H{
  23. "success": false,
  24. "message": err.Error(),
  25. })
  26. return
  27. }
  28. c.JSON(http.StatusOK, gin.H{
  29. "success": true,
  30. "message": "",
  31. "data": channels,
  32. })
  33. return
  34. }
  35. func SearchChannels(c *gin.Context) {
  36. keyword := c.Query("keyword")
  37. channels, err := model.SearchChannels(keyword)
  38. if err != nil {
  39. c.JSON(http.StatusOK, gin.H{
  40. "success": false,
  41. "message": err.Error(),
  42. })
  43. return
  44. }
  45. c.JSON(http.StatusOK, gin.H{
  46. "success": true,
  47. "message": "",
  48. "data": channels,
  49. })
  50. return
  51. }
  52. func GetChannel(c *gin.Context) {
  53. id, err := strconv.Atoi(c.Param("id"))
  54. if err != nil {
  55. c.JSON(http.StatusOK, gin.H{
  56. "success": false,
  57. "message": err.Error(),
  58. })
  59. return
  60. }
  61. channel, err := model.GetChannelById(id, false)
  62. if err != nil {
  63. c.JSON(http.StatusOK, gin.H{
  64. "success": false,
  65. "message": err.Error(),
  66. })
  67. return
  68. }
  69. c.JSON(http.StatusOK, gin.H{
  70. "success": true,
  71. "message": "",
  72. "data": channel,
  73. })
  74. return
  75. }
  76. func AddChannel(c *gin.Context) {
  77. channel := model.Channel{}
  78. err := c.ShouldBindJSON(&channel)
  79. if err != nil {
  80. c.JSON(http.StatusOK, gin.H{
  81. "success": false,
  82. "message": err.Error(),
  83. })
  84. return
  85. }
  86. channel.CreatedTime = common.GetTimestamp()
  87. keys := strings.Split(channel.Key, "\n")
  88. channels := make([]model.Channel, 0)
  89. for _, key := range keys {
  90. if key == "" {
  91. continue
  92. }
  93. localChannel := channel
  94. localChannel.Key = key
  95. channels = append(channels, localChannel)
  96. }
  97. err = model.BatchInsertChannels(channels)
  98. if err != nil {
  99. c.JSON(http.StatusOK, gin.H{
  100. "success": false,
  101. "message": err.Error(),
  102. })
  103. return
  104. }
  105. c.JSON(http.StatusOK, gin.H{
  106. "success": true,
  107. "message": "",
  108. })
  109. return
  110. }
  111. func DeleteChannel(c *gin.Context) {
  112. id, _ := strconv.Atoi(c.Param("id"))
  113. channel := model.Channel{Id: id}
  114. err := channel.Delete()
  115. if err != nil {
  116. c.JSON(http.StatusOK, gin.H{
  117. "success": false,
  118. "message": err.Error(),
  119. })
  120. return
  121. }
  122. c.JSON(http.StatusOK, gin.H{
  123. "success": true,
  124. "message": "",
  125. })
  126. return
  127. }
  128. func UpdateChannel(c *gin.Context) {
  129. channel := model.Channel{}
  130. err := c.ShouldBindJSON(&channel)
  131. if err != nil {
  132. c.JSON(http.StatusOK, gin.H{
  133. "success": false,
  134. "message": err.Error(),
  135. })
  136. return
  137. }
  138. err = channel.Update()
  139. if err != nil {
  140. c.JSON(http.StatusOK, gin.H{
  141. "success": false,
  142. "message": err.Error(),
  143. })
  144. return
  145. }
  146. c.JSON(http.StatusOK, gin.H{
  147. "success": true,
  148. "message": "",
  149. "data": channel,
  150. })
  151. return
  152. }
  153. func testChannel(channel *model.Channel, request *ChatRequest) error {
  154. if request.Model == "" {
  155. request.Model = "gpt-3.5-turbo"
  156. if channel.Type == common.ChannelTypeAzure {
  157. request.Model = "gpt-35-turbo"
  158. }
  159. }
  160. requestURL := common.ChannelBaseURLs[channel.Type]
  161. if channel.Type == common.ChannelTypeAzure {
  162. requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
  163. } else {
  164. if channel.Type == common.ChannelTypeCustom {
  165. requestURL = channel.BaseURL
  166. }
  167. requestURL += "/v1/chat/completions"
  168. }
  169. jsonData, err := json.Marshal(request)
  170. if err != nil {
  171. return err
  172. }
  173. req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
  174. if err != nil {
  175. return err
  176. }
  177. if channel.Type == common.ChannelTypeAzure {
  178. req.Header.Set("api-key", channel.Key)
  179. } else {
  180. req.Header.Set("Authorization", "Bearer "+channel.Key)
  181. }
  182. req.Header.Set("Content-Type", "application/json")
  183. client := &http.Client{}
  184. resp, err := client.Do(req)
  185. if err != nil {
  186. return err
  187. }
  188. defer resp.Body.Close()
  189. var response TextResponse
  190. err = json.NewDecoder(resp.Body).Decode(&response)
  191. if err != nil {
  192. return err
  193. }
  194. if response.Error.Type != "" {
  195. return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
  196. }
  197. return nil
  198. }
  199. func TestChannel(c *gin.Context) {
  200. id, err := strconv.Atoi(c.Param("id"))
  201. if err != nil {
  202. c.JSON(http.StatusOK, gin.H{
  203. "success": false,
  204. "message": err.Error(),
  205. })
  206. return
  207. }
  208. channel, err := model.GetChannelById(id, true)
  209. if err != nil {
  210. c.JSON(http.StatusOK, gin.H{
  211. "success": false,
  212. "message": err.Error(),
  213. })
  214. return
  215. }
  216. model_ := c.Query("model")
  217. chatRequest := &ChatRequest{
  218. Model: model_,
  219. }
  220. testMessage := Message{
  221. Role: "user",
  222. Content: "echo hi",
  223. }
  224. chatRequest.Messages = append(chatRequest.Messages, testMessage)
  225. tik := time.Now()
  226. err = testChannel(channel, chatRequest)
  227. tok := time.Now()
  228. milliseconds := tok.Sub(tik).Milliseconds()
  229. go channel.UpdateResponseTime(milliseconds)
  230. consumedTime := float64(milliseconds) / 1000.0
  231. if err != nil {
  232. c.JSON(http.StatusOK, gin.H{
  233. "success": false,
  234. "message": err.Error(),
  235. "time": consumedTime,
  236. })
  237. return
  238. }
  239. c.JSON(http.StatusOK, gin.H{
  240. "success": true,
  241. "message": "",
  242. "time": consumedTime,
  243. })
  244. return
  245. }