channel.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. package controller
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/model"
  9. "strconv"
  10. "strings"
  11. "github.com/gin-gonic/gin"
  12. )
  type (
	// OpenAIModel is one entry of an OpenAI-compatible "list models" payload.
	// Handlers in this file only read ID; the remaining fields mirror the
	// upstream object so it can be decoded as-is.
	OpenAIModel struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
		// Permission mirrors the legacy OpenAI permission objects; it is
		// decoded but not inspected here.
		Permission []struct {
			ID                 string `json:"id"`
			Object             string `json:"object"`
			Created            int64  `json:"created"`
			AllowCreateEngine  bool   `json:"allow_create_engine"`
			AllowSampling      bool   `json:"allow_sampling"`
			AllowLogprobs      bool   `json:"allow_logprobs"`
			AllowSearchIndices bool   `json:"allow_search_indices"`
			AllowView          bool   `json:"allow_view"`
			AllowFineTuning    bool   `json:"allow_fine_tuning"`
			Organization       string `json:"organization"`
			Group              string `json:"group"`
			IsBlocking         bool   `json:"is_blocking"`
		} `json:"permission"`
		Root   string `json:"root"`
		Parent string `json:"parent"`
	}

	// OpenAIModelsResponse is the envelope around a list of OpenAIModel
	// entries as returned by an upstream models endpoint.
	OpenAIModelsResponse struct {
		Data    []OpenAIModel `json:"data"`
		Success bool          `json:"success"`
	}
)
  39. func parseStatusFilter(statusParam string) int {
  40. switch strings.ToLower(statusParam) {
  41. case "enabled", "1":
  42. return common.ChannelStatusEnabled
  43. case "disabled", "0":
  44. return 0
  45. default:
  46. return -1
  47. }
  48. }
  49. func GetAllChannels(c *gin.Context) {
  50. p, _ := strconv.Atoi(c.Query("p"))
  51. pageSize, _ := strconv.Atoi(c.Query("page_size"))
  52. if p < 1 {
  53. p = 1
  54. }
  55. if pageSize < 1 {
  56. pageSize = common.ItemsPerPage
  57. }
  58. channelData := make([]*model.Channel, 0)
  59. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  60. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  61. statusParam := c.Query("status")
  62. // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
  63. statusFilter := parseStatusFilter(statusParam)
  64. // type filter
  65. typeStr := c.Query("type")
  66. typeFilter := -1
  67. if typeStr != "" {
  68. if t, err := strconv.Atoi(typeStr); err == nil {
  69. typeFilter = t
  70. }
  71. }
  72. var total int64
  73. if enableTagMode {
  74. tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
  75. if err != nil {
  76. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  77. return
  78. }
  79. for _, tag := range tags {
  80. if tag == nil || *tag == "" {
  81. continue
  82. }
  83. tagChannels, err := model.GetChannelsByTag(*tag, idSort)
  84. if err != nil {
  85. continue
  86. }
  87. filtered := make([]*model.Channel, 0)
  88. for _, ch := range tagChannels {
  89. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  90. continue
  91. }
  92. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  93. continue
  94. }
  95. if typeFilter >= 0 && ch.Type != typeFilter {
  96. continue
  97. }
  98. filtered = append(filtered, ch)
  99. }
  100. channelData = append(channelData, filtered...)
  101. }
  102. total, _ = model.CountAllTags()
  103. } else {
  104. baseQuery := model.DB.Model(&model.Channel{})
  105. if typeFilter >= 0 {
  106. baseQuery = baseQuery.Where("type = ?", typeFilter)
  107. }
  108. if statusFilter == common.ChannelStatusEnabled {
  109. baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
  110. } else if statusFilter == 0 {
  111. baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
  112. }
  113. baseQuery.Count(&total)
  114. order := "priority desc"
  115. if idSort {
  116. order = "id desc"
  117. }
  118. err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
  119. if err != nil {
  120. c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
  121. return
  122. }
  123. }
  124. countQuery := model.DB.Model(&model.Channel{})
  125. if statusFilter == common.ChannelStatusEnabled {
  126. countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
  127. } else if statusFilter == 0 {
  128. countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
  129. }
  130. var results []struct {
  131. Type int64
  132. Count int64
  133. }
  134. _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
  135. typeCounts := make(map[int64]int64)
  136. for _, r := range results {
  137. typeCounts[r.Type] = r.Count
  138. }
  139. c.JSON(http.StatusOK, gin.H{
  140. "success": true,
  141. "message": "",
  142. "data": gin.H{
  143. "items": channelData,
  144. "total": total,
  145. "page": p,
  146. "page_size": pageSize,
  147. "type_counts": typeCounts,
  148. },
  149. })
  150. return
  151. }
  152. func FetchUpstreamModels(c *gin.Context) {
  153. id, err := strconv.Atoi(c.Param("id"))
  154. if err != nil {
  155. c.JSON(http.StatusOK, gin.H{
  156. "success": false,
  157. "message": err.Error(),
  158. })
  159. return
  160. }
  161. channel, err := model.GetChannelById(id, true)
  162. if err != nil {
  163. c.JSON(http.StatusOK, gin.H{
  164. "success": false,
  165. "message": err.Error(),
  166. })
  167. return
  168. }
  169. baseURL := constant.ChannelBaseURLs[channel.Type]
  170. if channel.GetBaseURL() != "" {
  171. baseURL = channel.GetBaseURL()
  172. }
  173. url := fmt.Sprintf("%s/v1/models", baseURL)
  174. switch channel.Type {
  175. case constant.ChannelTypeGemini:
  176. url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
  177. case constant.ChannelTypeAli:
  178. url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
  179. }
  180. body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
  181. if err != nil {
  182. c.JSON(http.StatusOK, gin.H{
  183. "success": false,
  184. "message": err.Error(),
  185. })
  186. return
  187. }
  188. var result OpenAIModelsResponse
  189. if err = json.Unmarshal(body, &result); err != nil {
  190. c.JSON(http.StatusOK, gin.H{
  191. "success": false,
  192. "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
  193. })
  194. return
  195. }
  196. var ids []string
  197. for _, model := range result.Data {
  198. id := model.ID
  199. if channel.Type == constant.ChannelTypeGemini {
  200. id = strings.TrimPrefix(id, "models/")
  201. }
  202. ids = append(ids, id)
  203. }
  204. c.JSON(http.StatusOK, gin.H{
  205. "success": true,
  206. "message": "",
  207. "data": ids,
  208. })
  209. }
  210. func FixChannelsAbilities(c *gin.Context) {
  211. count, err := model.FixAbility()
  212. if err != nil {
  213. c.JSON(http.StatusOK, gin.H{
  214. "success": false,
  215. "message": err.Error(),
  216. })
  217. return
  218. }
  219. c.JSON(http.StatusOK, gin.H{
  220. "success": true,
  221. "message": "",
  222. "data": count,
  223. })
  224. }
  225. func SearchChannels(c *gin.Context) {
  226. keyword := c.Query("keyword")
  227. group := c.Query("group")
  228. modelKeyword := c.Query("model")
  229. statusParam := c.Query("status")
  230. statusFilter := parseStatusFilter(statusParam)
  231. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  232. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  233. channelData := make([]*model.Channel, 0)
  234. if enableTagMode {
  235. tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
  236. if err != nil {
  237. c.JSON(http.StatusOK, gin.H{
  238. "success": false,
  239. "message": err.Error(),
  240. })
  241. return
  242. }
  243. for _, tag := range tags {
  244. if tag != nil && *tag != "" {
  245. tagChannel, err := model.GetChannelsByTag(*tag, idSort)
  246. if err == nil {
  247. channelData = append(channelData, tagChannel...)
  248. }
  249. }
  250. }
  251. } else {
  252. channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
  253. if err != nil {
  254. c.JSON(http.StatusOK, gin.H{
  255. "success": false,
  256. "message": err.Error(),
  257. })
  258. return
  259. }
  260. channelData = channels
  261. }
  262. if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
  263. filtered := make([]*model.Channel, 0, len(channelData))
  264. for _, ch := range channelData {
  265. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  266. continue
  267. }
  268. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  269. continue
  270. }
  271. filtered = append(filtered, ch)
  272. }
  273. channelData = filtered
  274. }
  275. // calculate type counts for search results
  276. typeCounts := make(map[int64]int64)
  277. for _, channel := range channelData {
  278. typeCounts[int64(channel.Type)]++
  279. }
  280. typeParam := c.Query("type")
  281. typeFilter := -1
  282. if typeParam != "" {
  283. if tp, err := strconv.Atoi(typeParam); err == nil {
  284. typeFilter = tp
  285. }
  286. }
  287. if typeFilter >= 0 {
  288. filtered := make([]*model.Channel, 0, len(channelData))
  289. for _, ch := range channelData {
  290. if ch.Type == typeFilter {
  291. filtered = append(filtered, ch)
  292. }
  293. }
  294. channelData = filtered
  295. }
  296. page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
  297. pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
  298. if page < 1 {
  299. page = 1
  300. }
  301. if pageSize <= 0 {
  302. pageSize = 20
  303. }
  304. total := len(channelData)
  305. startIdx := (page - 1) * pageSize
  306. if startIdx > total {
  307. startIdx = total
  308. }
  309. endIdx := startIdx + pageSize
  310. if endIdx > total {
  311. endIdx = total
  312. }
  313. pagedData := channelData[startIdx:endIdx]
  314. c.JSON(http.StatusOK, gin.H{
  315. "success": true,
  316. "message": "",
  317. "data": gin.H{
  318. "items": pagedData,
  319. "total": total,
  320. "type_counts": typeCounts,
  321. },
  322. })
  323. return
  324. }
  325. func GetChannel(c *gin.Context) {
  326. id, err := strconv.Atoi(c.Param("id"))
  327. if err != nil {
  328. c.JSON(http.StatusOK, gin.H{
  329. "success": false,
  330. "message": err.Error(),
  331. })
  332. return
  333. }
  334. channel, err := model.GetChannelById(id, false)
  335. if err != nil {
  336. c.JSON(http.StatusOK, gin.H{
  337. "success": false,
  338. "message": err.Error(),
  339. })
  340. return
  341. }
  342. c.JSON(http.StatusOK, gin.H{
  343. "success": true,
  344. "message": "",
  345. "data": channel,
  346. })
  347. return
  348. }
  349. type AddChannelRequest struct {
  350. Mode string `json:"mode"`
  351. Channel *model.Channel `json:"channel"`
  352. }
  353. func AddChannel(c *gin.Context) {
  354. addChannelRequest := AddChannelRequest{}
  355. err := c.ShouldBindJSON(&addChannelRequest)
  356. if err != nil {
  357. c.JSON(http.StatusOK, gin.H{
  358. "success": false,
  359. "message": err.Error(),
  360. })
  361. return
  362. }
  363. if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
  364. c.JSON(http.StatusOK, gin.H{
  365. "success": false,
  366. "message": "channel cannot be empty",
  367. })
  368. return
  369. }
  370. // Validate the length of the model name
  371. for _, m := range addChannelRequest.Channel.GetModels() {
  372. if len(m) > 255 {
  373. c.JSON(http.StatusOK, gin.H{
  374. "success": false,
  375. "message": fmt.Sprintf("模型名称过长: %s", m),
  376. })
  377. return
  378. }
  379. }
  380. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
  381. if addChannelRequest.Channel.Other == "" {
  382. c.JSON(http.StatusOK, gin.H{
  383. "success": false,
  384. "message": "部署地区不能为空",
  385. })
  386. return
  387. } else {
  388. if common.IsJsonStr(addChannelRequest.Channel.Other) {
  389. // must have default
  390. regionMap := common.StrToMap(addChannelRequest.Channel.Other)
  391. if regionMap["default"] == nil {
  392. c.JSON(http.StatusOK, gin.H{
  393. "success": false,
  394. "message": "部署地区必须包含default字段",
  395. })
  396. return
  397. }
  398. }
  399. }
  400. }
  401. addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
  402. keys := make([]string, 0)
  403. switch addChannelRequest.Mode {
  404. case "multi_to_single":
  405. addChannelRequest.Channel.ChannelInfo.MultiKeyMode = true
  406. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
  407. if !common.IsJsonStr(addChannelRequest.Channel.Key) {
  408. c.JSON(http.StatusOK, gin.H{
  409. "success": false,
  410. "message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
  411. })
  412. return
  413. }
  414. toMap := common.StrToMap(addChannelRequest.Channel.Key)
  415. if toMap != nil {
  416. addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(toMap)
  417. } else {
  418. addChannelRequest.Channel.ChannelInfo.MultiKeySize = 0
  419. }
  420. } else {
  421. cleanKeys := make([]string, 0)
  422. for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
  423. if key == "" {
  424. continue
  425. }
  426. cleanKeys = append(cleanKeys, key)
  427. }
  428. addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
  429. addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
  430. }
  431. keys = []string{addChannelRequest.Channel.Key}
  432. case "batch":
  433. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
  434. // multi json
  435. toMap := common.StrToMap(addChannelRequest.Channel.Key)
  436. if toMap == nil {
  437. c.JSON(http.StatusOK, gin.H{
  438. "success": false,
  439. "message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
  440. })
  441. return
  442. }
  443. keys = make([]string, 0, len(toMap))
  444. for k := range toMap {
  445. if k == "" {
  446. continue
  447. }
  448. keys = append(keys, k)
  449. }
  450. } else {
  451. keys = strings.Split(addChannelRequest.Channel.Key, "\n")
  452. }
  453. case "single":
  454. keys = []string{addChannelRequest.Channel.Key}
  455. default:
  456. c.JSON(http.StatusOK, gin.H{
  457. "success": false,
  458. "message": "不支持的添加模式",
  459. })
  460. return
  461. }
  462. channels := make([]model.Channel, 0, len(keys))
  463. for _, key := range keys {
  464. if key == "" {
  465. continue
  466. }
  467. localChannel := addChannelRequest.Channel
  468. localChannel.Key = key
  469. channels = append(channels, *localChannel)
  470. }
  471. err = model.BatchInsertChannels(channels)
  472. if err != nil {
  473. c.JSON(http.StatusOK, gin.H{
  474. "success": false,
  475. "message": err.Error(),
  476. })
  477. return
  478. }
  479. c.JSON(http.StatusOK, gin.H{
  480. "success": true,
  481. "message": "",
  482. })
  483. return
  484. }
  485. func DeleteChannel(c *gin.Context) {
  486. id, _ := strconv.Atoi(c.Param("id"))
  487. channel := model.Channel{Id: id}
  488. err := channel.Delete()
  489. if err != nil {
  490. c.JSON(http.StatusOK, gin.H{
  491. "success": false,
  492. "message": err.Error(),
  493. })
  494. return
  495. }
  496. c.JSON(http.StatusOK, gin.H{
  497. "success": true,
  498. "message": "",
  499. })
  500. return
  501. }
  502. func DeleteDisabledChannel(c *gin.Context) {
  503. rows, err := model.DeleteDisabledChannel()
  504. if err != nil {
  505. c.JSON(http.StatusOK, gin.H{
  506. "success": false,
  507. "message": err.Error(),
  508. })
  509. return
  510. }
  511. c.JSON(http.StatusOK, gin.H{
  512. "success": true,
  513. "message": "",
  514. "data": rows,
  515. })
  516. return
  517. }
  // ChannelTag is the JSON body of the tag-wide channel operations. Tag
// selects the channels. The pointer fields are optional: each is nil when
// absent from the body. EditTagChannels passes them to
// model.EditChannelByTag, which presumably leaves nil fields unchanged —
// confirm there.
type ChannelTag struct {
	Tag          string  `json:"tag"`
	NewTag       *string `json:"new_tag"`
	Priority     *int64  `json:"priority"`
	Weight       *uint   `json:"weight"`
	ModelMapping *string `json:"model_mapping"`
	Models       *string `json:"models"`
	Groups       *string `json:"groups"`
}
  527. func DisableTagChannels(c *gin.Context) {
  528. channelTag := ChannelTag{}
  529. err := c.ShouldBindJSON(&channelTag)
  530. if err != nil || channelTag.Tag == "" {
  531. c.JSON(http.StatusOK, gin.H{
  532. "success": false,
  533. "message": "参数错误",
  534. })
  535. return
  536. }
  537. err = model.DisableChannelByTag(channelTag.Tag)
  538. if err != nil {
  539. c.JSON(http.StatusOK, gin.H{
  540. "success": false,
  541. "message": err.Error(),
  542. })
  543. return
  544. }
  545. c.JSON(http.StatusOK, gin.H{
  546. "success": true,
  547. "message": "",
  548. })
  549. return
  550. }
  551. func EnableTagChannels(c *gin.Context) {
  552. channelTag := ChannelTag{}
  553. err := c.ShouldBindJSON(&channelTag)
  554. if err != nil || channelTag.Tag == "" {
  555. c.JSON(http.StatusOK, gin.H{
  556. "success": false,
  557. "message": "参数错误",
  558. })
  559. return
  560. }
  561. err = model.EnableChannelByTag(channelTag.Tag)
  562. if err != nil {
  563. c.JSON(http.StatusOK, gin.H{
  564. "success": false,
  565. "message": err.Error(),
  566. })
  567. return
  568. }
  569. c.JSON(http.StatusOK, gin.H{
  570. "success": true,
  571. "message": "",
  572. })
  573. return
  574. }
  575. func EditTagChannels(c *gin.Context) {
  576. channelTag := ChannelTag{}
  577. err := c.ShouldBindJSON(&channelTag)
  578. if err != nil {
  579. c.JSON(http.StatusOK, gin.H{
  580. "success": false,
  581. "message": "参数错误",
  582. })
  583. return
  584. }
  585. if channelTag.Tag == "" {
  586. c.JSON(http.StatusOK, gin.H{
  587. "success": false,
  588. "message": "tag不能为空",
  589. })
  590. return
  591. }
  592. err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
  593. if err != nil {
  594. c.JSON(http.StatusOK, gin.H{
  595. "success": false,
  596. "message": err.Error(),
  597. })
  598. return
  599. }
  600. c.JSON(http.StatusOK, gin.H{
  601. "success": true,
  602. "message": "",
  603. })
  604. return
  605. }
  // ChannelBatch is the JSON body of the batch channel operations: Ids lists
// the target channel ids, and Tag (nil when absent from the body) is the tag
// passed to model.BatchSetChannelTag.
type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}
  610. func DeleteChannelBatch(c *gin.Context) {
  611. channelBatch := ChannelBatch{}
  612. err := c.ShouldBindJSON(&channelBatch)
  613. if err != nil || len(channelBatch.Ids) == 0 {
  614. c.JSON(http.StatusOK, gin.H{
  615. "success": false,
  616. "message": "参数错误",
  617. })
  618. return
  619. }
  620. err = model.BatchDeleteChannels(channelBatch.Ids)
  621. if err != nil {
  622. c.JSON(http.StatusOK, gin.H{
  623. "success": false,
  624. "message": err.Error(),
  625. })
  626. return
  627. }
  628. c.JSON(http.StatusOK, gin.H{
  629. "success": true,
  630. "message": "",
  631. "data": len(channelBatch.Ids),
  632. })
  633. return
  634. }
  635. func UpdateChannel(c *gin.Context) {
  636. channel := model.Channel{}
  637. err := c.ShouldBindJSON(&channel)
  638. if err != nil {
  639. c.JSON(http.StatusOK, gin.H{
  640. "success": false,
  641. "message": err.Error(),
  642. })
  643. return
  644. }
  645. if channel.Type == constant.ChannelTypeVertexAi {
  646. if channel.Other == "" {
  647. c.JSON(http.StatusOK, gin.H{
  648. "success": false,
  649. "message": "部署地区不能为空",
  650. })
  651. return
  652. } else {
  653. if common.IsJsonStr(channel.Other) {
  654. // must have default
  655. regionMap := common.StrToMap(channel.Other)
  656. if regionMap["default"] == nil {
  657. c.JSON(http.StatusOK, gin.H{
  658. "success": false,
  659. "message": "部署地区必须包含default字段",
  660. })
  661. return
  662. }
  663. }
  664. }
  665. }
  666. err = channel.Update()
  667. if err != nil {
  668. c.JSON(http.StatusOK, gin.H{
  669. "success": false,
  670. "message": err.Error(),
  671. })
  672. return
  673. }
  674. channel.Key = ""
  675. c.JSON(http.StatusOK, gin.H{
  676. "success": true,
  677. "message": "",
  678. "data": channel,
  679. })
  680. return
  681. }
  682. func FetchModels(c *gin.Context) {
  683. var req struct {
  684. BaseURL string `json:"base_url"`
  685. Type int `json:"type"`
  686. Key string `json:"key"`
  687. }
  688. if err := c.ShouldBindJSON(&req); err != nil {
  689. c.JSON(http.StatusBadRequest, gin.H{
  690. "success": false,
  691. "message": "Invalid request",
  692. })
  693. return
  694. }
  695. baseURL := req.BaseURL
  696. if baseURL == "" {
  697. baseURL = constant.ChannelBaseURLs[req.Type]
  698. }
  699. client := &http.Client{}
  700. url := fmt.Sprintf("%s/v1/models", baseURL)
  701. request, err := http.NewRequest("GET", url, nil)
  702. if err != nil {
  703. c.JSON(http.StatusInternalServerError, gin.H{
  704. "success": false,
  705. "message": err.Error(),
  706. })
  707. return
  708. }
  709. // remove line breaks and extra spaces.
  710. key := strings.TrimSpace(req.Key)
  711. // If the key contains a line break, only take the first part.
  712. key = strings.Split(key, "\n")[0]
  713. request.Header.Set("Authorization", "Bearer "+key)
  714. response, err := client.Do(request)
  715. if err != nil {
  716. c.JSON(http.StatusInternalServerError, gin.H{
  717. "success": false,
  718. "message": err.Error(),
  719. })
  720. return
  721. }
  722. //check status code
  723. if response.StatusCode != http.StatusOK {
  724. c.JSON(http.StatusInternalServerError, gin.H{
  725. "success": false,
  726. "message": "Failed to fetch models",
  727. })
  728. return
  729. }
  730. defer response.Body.Close()
  731. var result struct {
  732. Data []struct {
  733. ID string `json:"id"`
  734. } `json:"data"`
  735. }
  736. if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
  737. c.JSON(http.StatusInternalServerError, gin.H{
  738. "success": false,
  739. "message": err.Error(),
  740. })
  741. return
  742. }
  743. var models []string
  744. for _, model := range result.Data {
  745. models = append(models, model.ID)
  746. }
  747. c.JSON(http.StatusOK, gin.H{
  748. "success": true,
  749. "data": models,
  750. })
  751. }
  752. func BatchSetChannelTag(c *gin.Context) {
  753. channelBatch := ChannelBatch{}
  754. err := c.ShouldBindJSON(&channelBatch)
  755. if err != nil || len(channelBatch.Ids) == 0 {
  756. c.JSON(http.StatusOK, gin.H{
  757. "success": false,
  758. "message": "参数错误",
  759. })
  760. return
  761. }
  762. err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
  763. if err != nil {
  764. c.JSON(http.StatusOK, gin.H{
  765. "success": false,
  766. "message": err.Error(),
  767. })
  768. return
  769. }
  770. c.JSON(http.StatusOK, gin.H{
  771. "success": true,
  772. "message": "",
  773. "data": len(channelBatch.Ids),
  774. })
  775. return
  776. }
  777. func GetTagModels(c *gin.Context) {
  778. tag := c.Query("tag")
  779. if tag == "" {
  780. c.JSON(http.StatusBadRequest, gin.H{
  781. "success": false,
  782. "message": "tag不能为空",
  783. })
  784. return
  785. }
  786. channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
  787. if err != nil {
  788. c.JSON(http.StatusInternalServerError, gin.H{
  789. "success": false,
  790. "message": err.Error(),
  791. })
  792. return
  793. }
  794. var longestModels string
  795. maxLength := 0
  796. // Find the longest models string among all channels with the given tag
  797. for _, channel := range channels {
  798. if channel.Models != "" {
  799. currentModels := strings.Split(channel.Models, ",")
  800. if len(currentModels) > maxLength {
  801. maxLength = len(currentModels)
  802. longestModels = channel.Models
  803. }
  804. }
  805. }
  806. c.JSON(http.StatusOK, gin.H{
  807. "success": true,
  808. "message": "",
  809. "data": longestModels,
  810. })
  811. return
  812. }