channel.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. package controller
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strconv"
  9. "strings"
  10. "github.com/gin-gonic/gin"
  11. )
  // OpenAIModel is one model entry of an OpenAI-compatible model listing, i.e.
// an element of the "data" array of a "/v1/models" response.
type OpenAIModel struct {
	ID     string `json:"id"`
	Object string `json:"object"`
	// Created is presumably a Unix timestamp (seconds) — TODO confirm per upstream.
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
	// Permission mirrors the JSON "permission" array; it stays nil when the
	// upstream omits the key.
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}
  34. type OpenAIModelsResponse struct {
  35. Data []OpenAIModel `json:"data"`
  36. Success bool `json:"success"`
  37. }
  38. func GetAllChannels(c *gin.Context) {
  39. p, _ := strconv.Atoi(c.Query("p"))
  40. pageSize, _ := strconv.Atoi(c.Query("page_size"))
  41. if p < 0 {
  42. p = 0
  43. }
  44. if pageSize < 0 {
  45. pageSize = common.ItemsPerPage
  46. }
  47. channelData := make([]*model.Channel, 0)
  48. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  49. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  50. if enableTagMode {
  51. tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
  52. if err != nil {
  53. c.JSON(http.StatusOK, gin.H{
  54. "success": false,
  55. "message": err.Error(),
  56. })
  57. return
  58. }
  59. for _, tag := range tags {
  60. if tag != nil && *tag != "" {
  61. tagChannel, err := model.GetChannelsByTag(*tag)
  62. if err == nil {
  63. channelData = append(channelData, tagChannel...)
  64. }
  65. }
  66. }
  67. } else {
  68. channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
  69. if err != nil {
  70. c.JSON(http.StatusOK, gin.H{
  71. "success": false,
  72. "message": err.Error(),
  73. })
  74. return
  75. }
  76. channelData = channels
  77. }
  78. c.JSON(http.StatusOK, gin.H{
  79. "success": true,
  80. "message": "",
  81. "data": channelData,
  82. })
  83. return
  84. }
  85. func FetchUpstreamModels(c *gin.Context) {
  86. id, err := strconv.Atoi(c.Param("id"))
  87. if err != nil {
  88. c.JSON(http.StatusOK, gin.H{
  89. "success": false,
  90. "message": err.Error(),
  91. })
  92. return
  93. }
  94. channel, err := model.GetChannelById(id, true)
  95. if err != nil {
  96. c.JSON(http.StatusOK, gin.H{
  97. "success": false,
  98. "message": err.Error(),
  99. })
  100. return
  101. }
  102. if channel.Type != common.ChannelTypeOpenAI {
  103. c.JSON(http.StatusOK, gin.H{
  104. "success": false,
  105. "message": "仅支持 OpenAI 类型渠道",
  106. })
  107. return
  108. }
  109. url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
  110. body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
  111. if err != nil {
  112. c.JSON(http.StatusOK, gin.H{
  113. "success": false,
  114. "message": err.Error(),
  115. })
  116. }
  117. result := OpenAIModelsResponse{}
  118. err = json.Unmarshal(body, &result)
  119. if err != nil {
  120. c.JSON(http.StatusOK, gin.H{
  121. "success": false,
  122. "message": err.Error(),
  123. })
  124. }
  125. if !result.Success {
  126. c.JSON(http.StatusOK, gin.H{
  127. "success": false,
  128. "message": "上游返回错误",
  129. })
  130. }
  131. var ids []string
  132. for _, model := range result.Data {
  133. ids = append(ids, model.ID)
  134. }
  135. c.JSON(http.StatusOK, gin.H{
  136. "success": true,
  137. "message": "",
  138. "data": ids,
  139. })
  140. }
  141. func FixChannelsAbilities(c *gin.Context) {
  142. count, err := model.FixAbility()
  143. if err != nil {
  144. c.JSON(http.StatusOK, gin.H{
  145. "success": false,
  146. "message": err.Error(),
  147. })
  148. return
  149. }
  150. c.JSON(http.StatusOK, gin.H{
  151. "success": true,
  152. "message": "",
  153. "data": count,
  154. })
  155. }
  156. func SearchChannels(c *gin.Context) {
  157. keyword := c.Query("keyword")
  158. group := c.Query("group")
  159. modelKeyword := c.Query("model")
  160. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  161. channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
  162. if err != nil {
  163. c.JSON(http.StatusOK, gin.H{
  164. "success": false,
  165. "message": err.Error(),
  166. })
  167. return
  168. }
  169. c.JSON(http.StatusOK, gin.H{
  170. "success": true,
  171. "message": "",
  172. "data": channels,
  173. })
  174. return
  175. }
  176. func GetChannel(c *gin.Context) {
  177. id, err := strconv.Atoi(c.Param("id"))
  178. if err != nil {
  179. c.JSON(http.StatusOK, gin.H{
  180. "success": false,
  181. "message": err.Error(),
  182. })
  183. return
  184. }
  185. channel, err := model.GetChannelById(id, false)
  186. if err != nil {
  187. c.JSON(http.StatusOK, gin.H{
  188. "success": false,
  189. "message": err.Error(),
  190. })
  191. return
  192. }
  193. c.JSON(http.StatusOK, gin.H{
  194. "success": true,
  195. "message": "",
  196. "data": channel,
  197. })
  198. return
  199. }
  200. func AddChannel(c *gin.Context) {
  201. channel := model.Channel{}
  202. err := c.ShouldBindJSON(&channel)
  203. if err != nil {
  204. c.JSON(http.StatusOK, gin.H{
  205. "success": false,
  206. "message": err.Error(),
  207. })
  208. return
  209. }
  210. channel.CreatedTime = common.GetTimestamp()
  211. keys := strings.Split(channel.Key, "\n")
  212. if channel.Type == common.ChannelTypeVertexAi {
  213. if channel.Other == "" {
  214. c.JSON(http.StatusOK, gin.H{
  215. "success": false,
  216. "message": "部署地区不能为空",
  217. })
  218. return
  219. } else {
  220. if common.IsJsonStr(channel.Other) {
  221. // must have default
  222. regionMap := common.StrToMap(channel.Other)
  223. if regionMap["default"] == nil {
  224. c.JSON(http.StatusOK, gin.H{
  225. "success": false,
  226. "message": "部署地区必须包含default字段",
  227. })
  228. return
  229. }
  230. }
  231. }
  232. keys = []string{channel.Key}
  233. }
  234. channels := make([]model.Channel, 0, len(keys))
  235. for _, key := range keys {
  236. if key == "" {
  237. continue
  238. }
  239. localChannel := channel
  240. localChannel.Key = key
  241. channels = append(channels, localChannel)
  242. }
  243. err = model.BatchInsertChannels(channels)
  244. if err != nil {
  245. c.JSON(http.StatusOK, gin.H{
  246. "success": false,
  247. "message": err.Error(),
  248. })
  249. return
  250. }
  251. c.JSON(http.StatusOK, gin.H{
  252. "success": true,
  253. "message": "",
  254. })
  255. return
  256. }
  257. func DeleteChannel(c *gin.Context) {
  258. id, _ := strconv.Atoi(c.Param("id"))
  259. channel := model.Channel{Id: id}
  260. err := channel.Delete()
  261. if err != nil {
  262. c.JSON(http.StatusOK, gin.H{
  263. "success": false,
  264. "message": err.Error(),
  265. })
  266. return
  267. }
  268. c.JSON(http.StatusOK, gin.H{
  269. "success": true,
  270. "message": "",
  271. })
  272. return
  273. }
  274. func DeleteDisabledChannel(c *gin.Context) {
  275. rows, err := model.DeleteDisabledChannel()
  276. if err != nil {
  277. c.JSON(http.StatusOK, gin.H{
  278. "success": false,
  279. "message": err.Error(),
  280. })
  281. return
  282. }
  283. c.JSON(http.StatusOK, gin.H{
  284. "success": true,
  285. "message": "",
  286. "data": rows,
  287. })
  288. return
  289. }
  // ChannelTag is the JSON request body of the tag-level channel handlers.
//
// Tag selects the channels to act on. The pointer fields are optional: they
// stay nil when the key is absent (or null) in the request. This lets a
// handler tell "not provided" apart from a zero value.
type ChannelTag struct {
	Tag          string  `json:"tag"`
	NewTag       *string `json:"new_tag"`
	Priority     *int64  `json:"priority"`
	Weight       *uint   `json:"weight"`
	ModelMapping *string `json:"model_mapping"`
	Models       *string `json:"models"`
	Groups       *string `json:"groups"`
}
  299. func DisableTagChannels(c *gin.Context) {
  300. channelTag := ChannelTag{}
  301. err := c.ShouldBindJSON(&channelTag)
  302. if err != nil || channelTag.Tag == "" {
  303. c.JSON(http.StatusOK, gin.H{
  304. "success": false,
  305. "message": "参数错误",
  306. })
  307. return
  308. }
  309. err = model.DisableChannelByTag(channelTag.Tag)
  310. if err != nil {
  311. c.JSON(http.StatusOK, gin.H{
  312. "success": false,
  313. "message": err.Error(),
  314. })
  315. return
  316. }
  317. c.JSON(http.StatusOK, gin.H{
  318. "success": true,
  319. "message": "",
  320. })
  321. return
  322. }
  323. func EnableTagChannels(c *gin.Context) {
  324. channelTag := ChannelTag{}
  325. err := c.ShouldBindJSON(&channelTag)
  326. if err != nil || channelTag.Tag == "" {
  327. c.JSON(http.StatusOK, gin.H{
  328. "success": false,
  329. "message": "参数错误",
  330. })
  331. return
  332. }
  333. err = model.EnableChannelByTag(channelTag.Tag)
  334. if err != nil {
  335. c.JSON(http.StatusOK, gin.H{
  336. "success": false,
  337. "message": err.Error(),
  338. })
  339. return
  340. }
  341. c.JSON(http.StatusOK, gin.H{
  342. "success": true,
  343. "message": "",
  344. })
  345. return
  346. }
  347. func EditTagChannels(c *gin.Context) {
  348. channelTag := ChannelTag{}
  349. err := c.ShouldBindJSON(&channelTag)
  350. if err != nil {
  351. c.JSON(http.StatusOK, gin.H{
  352. "success": false,
  353. "message": "参数错误",
  354. })
  355. return
  356. }
  357. if channelTag.Tag == "" {
  358. c.JSON(http.StatusOK, gin.H{
  359. "success": false,
  360. "message": "tag不能为空",
  361. })
  362. return
  363. }
  364. err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
  365. if err != nil {
  366. c.JSON(http.StatusOK, gin.H{
  367. "success": false,
  368. "message": err.Error(),
  369. })
  370. return
  371. }
  372. c.JSON(http.StatusOK, gin.H{
  373. "success": true,
  374. "message": "",
  375. })
  376. return
  377. }
  // ChannelBatch is the JSON request body that lists channel IDs for batch
// operations such as DeleteChannelBatch.
type ChannelBatch struct {
	Ids []int `json:"ids"`
}
  381. func DeleteChannelBatch(c *gin.Context) {
  382. channelBatch := ChannelBatch{}
  383. err := c.ShouldBindJSON(&channelBatch)
  384. if err != nil || len(channelBatch.Ids) == 0 {
  385. c.JSON(http.StatusOK, gin.H{
  386. "success": false,
  387. "message": "参数错误",
  388. })
  389. return
  390. }
  391. err = model.BatchDeleteChannels(channelBatch.Ids)
  392. if err != nil {
  393. c.JSON(http.StatusOK, gin.H{
  394. "success": false,
  395. "message": err.Error(),
  396. })
  397. return
  398. }
  399. c.JSON(http.StatusOK, gin.H{
  400. "success": true,
  401. "message": "",
  402. "data": len(channelBatch.Ids),
  403. })
  404. return
  405. }
  406. func UpdateChannel(c *gin.Context) {
  407. channel := model.Channel{}
  408. err := c.ShouldBindJSON(&channel)
  409. if err != nil {
  410. c.JSON(http.StatusOK, gin.H{
  411. "success": false,
  412. "message": err.Error(),
  413. })
  414. return
  415. }
  416. if channel.Type == common.ChannelTypeVertexAi {
  417. if channel.Other == "" {
  418. c.JSON(http.StatusOK, gin.H{
  419. "success": false,
  420. "message": "部署地区不能为空",
  421. })
  422. return
  423. } else {
  424. if common.IsJsonStr(channel.Other) {
  425. // must have default
  426. regionMap := common.StrToMap(channel.Other)
  427. if regionMap["default"] == nil {
  428. c.JSON(http.StatusOK, gin.H{
  429. "success": false,
  430. "message": "部署地区必须包含default字段",
  431. })
  432. return
  433. }
  434. }
  435. }
  436. }
  437. err = channel.Update()
  438. if err != nil {
  439. c.JSON(http.StatusOK, gin.H{
  440. "success": false,
  441. "message": err.Error(),
  442. })
  443. return
  444. }
  445. c.JSON(http.StatusOK, gin.H{
  446. "success": true,
  447. "message": "",
  448. "data": channel,
  449. })
  450. return
  451. }