channel.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "strings"
  13. "time"
  14. )
  15. func GetAllChannels(c *gin.Context) {
  16. p, _ := strconv.Atoi(c.Query("p"))
  17. if p < 0 {
  18. p = 0
  19. }
  20. channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage)
  21. if err != nil {
  22. c.JSON(http.StatusOK, gin.H{
  23. "success": false,
  24. "message": err.Error(),
  25. })
  26. return
  27. }
  28. c.JSON(http.StatusOK, gin.H{
  29. "success": true,
  30. "message": "",
  31. "data": channels,
  32. })
  33. return
  34. }
  35. func SearchChannels(c *gin.Context) {
  36. keyword := c.Query("keyword")
  37. channels, err := model.SearchChannels(keyword)
  38. if err != nil {
  39. c.JSON(http.StatusOK, gin.H{
  40. "success": false,
  41. "message": err.Error(),
  42. })
  43. return
  44. }
  45. c.JSON(http.StatusOK, gin.H{
  46. "success": true,
  47. "message": "",
  48. "data": channels,
  49. })
  50. return
  51. }
  52. func GetChannel(c *gin.Context) {
  53. id, err := strconv.Atoi(c.Param("id"))
  54. if err != nil {
  55. c.JSON(http.StatusOK, gin.H{
  56. "success": false,
  57. "message": err.Error(),
  58. })
  59. return
  60. }
  61. channel, err := model.GetChannelById(id, false)
  62. if err != nil {
  63. c.JSON(http.StatusOK, gin.H{
  64. "success": false,
  65. "message": err.Error(),
  66. })
  67. return
  68. }
  69. c.JSON(http.StatusOK, gin.H{
  70. "success": true,
  71. "message": "",
  72. "data": channel,
  73. })
  74. return
  75. }
  76. func AddChannel(c *gin.Context) {
  77. channel := model.Channel{}
  78. err := c.ShouldBindJSON(&channel)
  79. if err != nil {
  80. c.JSON(http.StatusOK, gin.H{
  81. "success": false,
  82. "message": err.Error(),
  83. })
  84. return
  85. }
  86. channel.CreatedTime = common.GetTimestamp()
  87. channel.AccessedTime = common.GetTimestamp()
  88. keys := strings.Split(channel.Key, "\n")
  89. channels := make([]model.Channel, 0)
  90. for _, key := range keys {
  91. if key == "" {
  92. continue
  93. }
  94. localChannel := channel
  95. localChannel.Key = key
  96. channels = append(channels, localChannel)
  97. }
  98. err = model.BatchInsertChannels(channels)
  99. if err != nil {
  100. c.JSON(http.StatusOK, gin.H{
  101. "success": false,
  102. "message": err.Error(),
  103. })
  104. return
  105. }
  106. c.JSON(http.StatusOK, gin.H{
  107. "success": true,
  108. "message": "",
  109. })
  110. return
  111. }
  112. func DeleteChannel(c *gin.Context) {
  113. id, _ := strconv.Atoi(c.Param("id"))
  114. channel := model.Channel{Id: id}
  115. err := channel.Delete()
  116. if err != nil {
  117. c.JSON(http.StatusOK, gin.H{
  118. "success": false,
  119. "message": err.Error(),
  120. })
  121. return
  122. }
  123. c.JSON(http.StatusOK, gin.H{
  124. "success": true,
  125. "message": "",
  126. })
  127. return
  128. }
  129. func UpdateChannel(c *gin.Context) {
  130. channel := model.Channel{}
  131. err := c.ShouldBindJSON(&channel)
  132. if err != nil {
  133. c.JSON(http.StatusOK, gin.H{
  134. "success": false,
  135. "message": err.Error(),
  136. })
  137. return
  138. }
  139. err = channel.Update()
  140. if err != nil {
  141. c.JSON(http.StatusOK, gin.H{
  142. "success": false,
  143. "message": err.Error(),
  144. })
  145. return
  146. }
  147. c.JSON(http.StatusOK, gin.H{
  148. "success": true,
  149. "message": "",
  150. "data": channel,
  151. })
  152. return
  153. }
  154. func testChannel(channel *model.Channel, request *ChatRequest) error {
  155. if request.Model == "" {
  156. request.Model = "gpt-3.5-turbo"
  157. if channel.Type == common.ChannelTypeAzure {
  158. request.Model = "gpt-35-turbo"
  159. }
  160. }
  161. requestURL := common.ChannelBaseURLs[channel.Type]
  162. if channel.Type == common.ChannelTypeAzure {
  163. requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
  164. } else {
  165. if channel.Type == common.ChannelTypeCustom {
  166. requestURL = channel.BaseURL
  167. }
  168. requestURL += "/v1/chat/completions"
  169. }
  170. jsonData, err := json.Marshal(request)
  171. if err != nil {
  172. return err
  173. }
  174. req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
  175. if err != nil {
  176. return err
  177. }
  178. if channel.Type == common.ChannelTypeAzure {
  179. req.Header.Set("api-key", channel.Key)
  180. } else {
  181. req.Header.Set("Authorization", "Bearer "+channel.Key)
  182. }
  183. req.Header.Set("Content-Type", "application/json")
  184. client := &http.Client{}
  185. resp, err := client.Do(req)
  186. if err != nil {
  187. return err
  188. }
  189. defer resp.Body.Close()
  190. var response TextResponse
  191. err = json.NewDecoder(resp.Body).Decode(&response)
  192. if err != nil {
  193. return err
  194. }
  195. if response.Error.Type != "" {
  196. return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
  197. }
  198. return nil
  199. }
  200. func TestChannel(c *gin.Context) {
  201. id, err := strconv.Atoi(c.Param("id"))
  202. if err != nil {
  203. c.JSON(http.StatusOK, gin.H{
  204. "success": false,
  205. "message": err.Error(),
  206. })
  207. return
  208. }
  209. channel, err := model.GetChannelById(id, true)
  210. if err != nil {
  211. c.JSON(http.StatusOK, gin.H{
  212. "success": false,
  213. "message": err.Error(),
  214. })
  215. return
  216. }
  217. model_ := c.Query("model")
  218. chatRequest := &ChatRequest{
  219. Model: model_,
  220. }
  221. testMessage := Message{
  222. Role: "user",
  223. Content: "echo hi",
  224. }
  225. chatRequest.Messages = append(chatRequest.Messages, testMessage)
  226. tik := time.Now()
  227. err = testChannel(channel, chatRequest)
  228. tok := time.Now()
  229. consumedTime := float64(tok.Sub(tik).Milliseconds()) / 1000.0
  230. if err != nil {
  231. c.JSON(http.StatusOK, gin.H{
  232. "success": false,
  233. "message": err.Error(),
  234. "time": consumedTime,
  235. })
  236. return
  237. }
  238. c.JSON(http.StatusOK, gin.H{
  239. "success": true,
  240. "message": "",
  241. "time": consumedTime,
  242. })
  243. return
  244. }