channel.go 50 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/relay/channel/gemini"
  15. "github.com/QuantumNous/new-api/relay/channel/ollama"
  16. "github.com/QuantumNous/new-api/service"
  17. "github.com/gin-gonic/gin"
  18. )
  19. type OpenAIModel struct {
  20. ID string `json:"id"`
  21. Object string `json:"object"`
  22. Created int64 `json:"created"`
  23. OwnedBy string `json:"owned_by"`
  24. Metadata map[string]any `json:"metadata,omitempty"`
  25. Permission []struct {
  26. ID string `json:"id"`
  27. Object string `json:"object"`
  28. Created int64 `json:"created"`
  29. AllowCreateEngine bool `json:"allow_create_engine"`
  30. AllowSampling bool `json:"allow_sampling"`
  31. AllowLogprobs bool `json:"allow_logprobs"`
  32. AllowSearchIndices bool `json:"allow_search_indices"`
  33. AllowView bool `json:"allow_view"`
  34. AllowFineTuning bool `json:"allow_fine_tuning"`
  35. Organization string `json:"organization"`
  36. Group string `json:"group"`
  37. IsBlocking bool `json:"is_blocking"`
  38. } `json:"permission"`
  39. Root string `json:"root"`
  40. Parent string `json:"parent"`
  41. }
  42. type OpenAIModelsResponse struct {
  43. Data []OpenAIModel `json:"data"`
  44. Success bool `json:"success"`
  45. }
  46. func parseStatusFilter(statusParam string) int {
  47. switch strings.ToLower(statusParam) {
  48. case "enabled", "1":
  49. return common.ChannelStatusEnabled
  50. case "disabled", "0":
  51. return 0
  52. default:
  53. return -1
  54. }
  55. }
  56. func clearChannelInfo(channel *model.Channel) {
  57. if channel.ChannelInfo.IsMultiKey {
  58. channel.ChannelInfo.MultiKeyDisabledReason = nil
  59. channel.ChannelInfo.MultiKeyDisabledTime = nil
  60. }
  61. }
  62. func GetAllChannels(c *gin.Context) {
  63. pageInfo := common.GetPageQuery(c)
  64. channelData := make([]*model.Channel, 0)
  65. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  66. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  67. statusParam := c.Query("status")
  68. // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
  69. statusFilter := parseStatusFilter(statusParam)
  70. // type filter
  71. typeStr := c.Query("type")
  72. typeFilter := -1
  73. if typeStr != "" {
  74. if t, err := strconv.Atoi(typeStr); err == nil {
  75. typeFilter = t
  76. }
  77. }
  78. var total int64
  79. if enableTagMode {
  80. tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  81. if err != nil {
  82. common.SysError("failed to get paginated tags: " + err.Error())
  83. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
  84. return
  85. }
  86. for _, tag := range tags {
  87. if tag == nil || *tag == "" {
  88. continue
  89. }
  90. tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
  91. if err != nil {
  92. continue
  93. }
  94. filtered := make([]*model.Channel, 0)
  95. for _, ch := range tagChannels {
  96. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  97. continue
  98. }
  99. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  100. continue
  101. }
  102. if typeFilter >= 0 && ch.Type != typeFilter {
  103. continue
  104. }
  105. filtered = append(filtered, ch)
  106. }
  107. channelData = append(channelData, filtered...)
  108. }
  109. total, _ = model.CountAllTags()
  110. } else {
  111. baseQuery := model.DB.Model(&model.Channel{})
  112. if typeFilter >= 0 {
  113. baseQuery = baseQuery.Where("type = ?", typeFilter)
  114. }
  115. if statusFilter == common.ChannelStatusEnabled {
  116. baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
  117. } else if statusFilter == 0 {
  118. baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
  119. }
  120. baseQuery.Count(&total)
  121. order := "priority desc"
  122. if idSort {
  123. order = "id desc"
  124. }
  125. err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
  126. if err != nil {
  127. common.SysError("failed to get channels: " + err.Error())
  128. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
  129. return
  130. }
  131. }
  132. for _, datum := range channelData {
  133. clearChannelInfo(datum)
  134. }
  135. countQuery := model.DB.Model(&model.Channel{})
  136. if statusFilter == common.ChannelStatusEnabled {
  137. countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
  138. } else if statusFilter == 0 {
  139. countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
  140. }
  141. var results []struct {
  142. Type int64
  143. Count int64
  144. }
  145. _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
  146. typeCounts := make(map[int64]int64)
  147. for _, r := range results {
  148. typeCounts[r.Type] = r.Count
  149. }
  150. common.ApiSuccess(c, gin.H{
  151. "items": channelData,
  152. "total": total,
  153. "page": pageInfo.GetPage(),
  154. "page_size": pageInfo.GetPageSize(),
  155. "type_counts": typeCounts,
  156. })
  157. return
  158. }
  159. func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
  160. var headers http.Header
  161. switch channel.Type {
  162. case constant.ChannelTypeAnthropic:
  163. headers = GetClaudeAuthHeader(key)
  164. default:
  165. headers = GetAuthHeader(key)
  166. }
  167. headerOverride := channel.GetHeaderOverride()
  168. for k, v := range headerOverride {
  169. str, ok := v.(string)
  170. if !ok {
  171. return nil, fmt.Errorf("invalid header override for key %s", k)
  172. }
  173. if strings.Contains(str, "{api_key}") {
  174. str = strings.ReplaceAll(str, "{api_key}", key)
  175. }
  176. headers.Set(k, str)
  177. }
  178. return headers, nil
  179. }
  180. func FetchUpstreamModels(c *gin.Context) {
  181. id, err := strconv.Atoi(c.Param("id"))
  182. if err != nil {
  183. common.ApiError(c, err)
  184. return
  185. }
  186. channel, err := model.GetChannelById(id, true)
  187. if err != nil {
  188. common.ApiError(c, err)
  189. return
  190. }
  191. ids, err := fetchChannelUpstreamModelIDs(channel)
  192. if err != nil {
  193. c.JSON(http.StatusOK, gin.H{
  194. "success": false,
  195. "message": fmt.Sprintf("获取模型列表失败: %s", err.Error()),
  196. })
  197. return
  198. }
  199. c.JSON(http.StatusOK, gin.H{
  200. "success": true,
  201. "message": "",
  202. "data": ids,
  203. })
  204. }
  205. func FixChannelsAbilities(c *gin.Context) {
  206. success, fails, err := model.FixAbility()
  207. if err != nil {
  208. common.ApiError(c, err)
  209. return
  210. }
  211. c.JSON(http.StatusOK, gin.H{
  212. "success": true,
  213. "message": "",
  214. "data": gin.H{
  215. "success": success,
  216. "fails": fails,
  217. },
  218. })
  219. }
  220. func SearchChannels(c *gin.Context) {
  221. keyword := c.Query("keyword")
  222. group := c.Query("group")
  223. modelKeyword := c.Query("model")
  224. statusParam := c.Query("status")
  225. statusFilter := parseStatusFilter(statusParam)
  226. idSort, _ := strconv.ParseBool(c.Query("id_sort"))
  227. enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
  228. channelData := make([]*model.Channel, 0)
  229. if enableTagMode {
  230. tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
  231. if err != nil {
  232. c.JSON(http.StatusOK, gin.H{
  233. "success": false,
  234. "message": err.Error(),
  235. })
  236. return
  237. }
  238. for _, tag := range tags {
  239. if tag != nil && *tag != "" {
  240. tagChannel, err := model.GetChannelsByTag(*tag, idSort, false)
  241. if err == nil {
  242. channelData = append(channelData, tagChannel...)
  243. }
  244. }
  245. }
  246. } else {
  247. channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
  248. if err != nil {
  249. c.JSON(http.StatusOK, gin.H{
  250. "success": false,
  251. "message": err.Error(),
  252. })
  253. return
  254. }
  255. channelData = channels
  256. }
  257. if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
  258. filtered := make([]*model.Channel, 0, len(channelData))
  259. for _, ch := range channelData {
  260. if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
  261. continue
  262. }
  263. if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
  264. continue
  265. }
  266. filtered = append(filtered, ch)
  267. }
  268. channelData = filtered
  269. }
  270. // calculate type counts for search results
  271. typeCounts := make(map[int64]int64)
  272. for _, channel := range channelData {
  273. typeCounts[int64(channel.Type)]++
  274. }
  275. typeParam := c.Query("type")
  276. typeFilter := -1
  277. if typeParam != "" {
  278. if tp, err := strconv.Atoi(typeParam); err == nil {
  279. typeFilter = tp
  280. }
  281. }
  282. if typeFilter >= 0 {
  283. filtered := make([]*model.Channel, 0, len(channelData))
  284. for _, ch := range channelData {
  285. if ch.Type == typeFilter {
  286. filtered = append(filtered, ch)
  287. }
  288. }
  289. channelData = filtered
  290. }
  291. page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
  292. pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
  293. if page < 1 {
  294. page = 1
  295. }
  296. if pageSize <= 0 {
  297. pageSize = 20
  298. }
  299. total := len(channelData)
  300. startIdx := (page - 1) * pageSize
  301. if startIdx > total {
  302. startIdx = total
  303. }
  304. endIdx := startIdx + pageSize
  305. if endIdx > total {
  306. endIdx = total
  307. }
  308. pagedData := channelData[startIdx:endIdx]
  309. for _, datum := range pagedData {
  310. clearChannelInfo(datum)
  311. }
  312. c.JSON(http.StatusOK, gin.H{
  313. "success": true,
  314. "message": "",
  315. "data": gin.H{
  316. "items": pagedData,
  317. "total": total,
  318. "type_counts": typeCounts,
  319. },
  320. })
  321. return
  322. }
  323. func GetChannel(c *gin.Context) {
  324. id, err := strconv.Atoi(c.Param("id"))
  325. if err != nil {
  326. common.ApiError(c, err)
  327. return
  328. }
  329. channel, err := model.GetChannelById(id, false)
  330. if err != nil {
  331. common.ApiError(c, err)
  332. return
  333. }
  334. if channel != nil {
  335. clearChannelInfo(channel)
  336. }
  337. c.JSON(http.StatusOK, gin.H{
  338. "success": true,
  339. "message": "",
  340. "data": channel,
  341. })
  342. return
  343. }
  344. // GetChannelKey 获取渠道密钥(需要通过安全验证中间件)
  345. // 此函数依赖 SecureVerificationRequired 中间件,确保用户已通过安全验证
  346. func GetChannelKey(c *gin.Context) {
  347. userId := c.GetInt("id")
  348. channelId, err := strconv.Atoi(c.Param("id"))
  349. if err != nil {
  350. common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
  351. return
  352. }
  353. // 获取渠道信息(包含密钥)
  354. channel, err := model.GetChannelById(channelId, true)
  355. if err != nil {
  356. common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
  357. return
  358. }
  359. if channel == nil {
  360. common.ApiError(c, fmt.Errorf("渠道不存在"))
  361. return
  362. }
  363. // 记录操作日志
  364. model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
  365. // 返回渠道密钥
  366. c.JSON(http.StatusOK, gin.H{
  367. "success": true,
  368. "message": "获取成功",
  369. "data": map[string]interface{}{
  370. "key": channel.Key,
  371. },
  372. })
  373. }
  374. // validateTwoFactorAuth 统一的2FA验证函数
  375. func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
  376. // 尝试验证TOTP
  377. if cleanCode, err := common.ValidateNumericCode(code); err == nil {
  378. if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid {
  379. return true
  380. }
  381. }
  382. // 尝试验证备用码
  383. if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid {
  384. return true
  385. }
  386. return false
  387. }
  388. // validateChannel 通用的渠道校验函数
  389. func validateChannel(channel *model.Channel, isAdd bool) error {
  390. // 校验 channel settings
  391. if err := channel.ValidateSettings(); err != nil {
  392. return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
  393. }
  394. // 如果是添加操作,检查 channel 和 key 是否为空
  395. if isAdd {
  396. if channel == nil || channel.Key == "" {
  397. return fmt.Errorf("channel cannot be empty")
  398. }
  399. // 检查模型名称长度是否超过 255
  400. for _, m := range channel.GetModels() {
  401. if len(m) > 255 {
  402. return fmt.Errorf("模型名称过长: %s", m)
  403. }
  404. }
  405. }
  406. // VertexAI 特殊校验
  407. if channel.Type == constant.ChannelTypeVertexAi {
  408. if channel.Other == "" {
  409. return fmt.Errorf("部署地区不能为空")
  410. }
  411. regionMap, err := common.StrToMap(channel.Other)
  412. if err != nil {
  413. return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
  414. }
  415. if regionMap["default"] == nil {
  416. return fmt.Errorf("部署地区必须包含default字段")
  417. }
  418. }
  419. // Codex OAuth key validation (optional, only when JSON object is provided)
  420. if channel.Type == constant.ChannelTypeCodex {
  421. trimmedKey := strings.TrimSpace(channel.Key)
  422. if isAdd || trimmedKey != "" {
  423. if !strings.HasPrefix(trimmedKey, "{") {
  424. return fmt.Errorf("Codex key must be a valid JSON object")
  425. }
  426. var keyMap map[string]any
  427. if err := common.Unmarshal([]byte(trimmedKey), &keyMap); err != nil {
  428. return fmt.Errorf("Codex key must be a valid JSON object")
  429. }
  430. if v, ok := keyMap["access_token"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
  431. return fmt.Errorf("Codex key JSON must include access_token")
  432. }
  433. if v, ok := keyMap["account_id"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
  434. return fmt.Errorf("Codex key JSON must include account_id")
  435. }
  436. }
  437. }
  438. return nil
  439. }
  440. func RefreshCodexChannelCredential(c *gin.Context) {
  441. channelId, err := strconv.Atoi(c.Param("id"))
  442. if err != nil {
  443. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  444. return
  445. }
  446. ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
  447. defer cancel()
  448. oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true})
  449. if err != nil {
  450. common.SysError("failed to refresh codex channel credential: " + err.Error())
  451. c.JSON(http.StatusOK, gin.H{"success": false, "message": "刷新凭证失败,请稍后重试"})
  452. return
  453. }
  454. c.JSON(http.StatusOK, gin.H{
  455. "success": true,
  456. "message": "refreshed",
  457. "data": gin.H{
  458. "expires_at": oauthKey.Expired,
  459. "last_refresh": oauthKey.LastRefresh,
  460. "account_id": oauthKey.AccountID,
  461. "email": oauthKey.Email,
  462. "channel_id": ch.Id,
  463. "channel_type": ch.Type,
  464. "channel_name": ch.Name,
  465. },
  466. })
  467. }
  468. type AddChannelRequest struct {
  469. Mode string `json:"mode"`
  470. MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
  471. BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
  472. Channel *model.Channel `json:"channel"`
  473. }
  474. func getVertexArrayKeys(keys string) ([]string, error) {
  475. if keys == "" {
  476. return nil, nil
  477. }
  478. var keyArray []interface{}
  479. err := common.Unmarshal([]byte(keys), &keyArray)
  480. if err != nil {
  481. return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
  482. }
  483. cleanKeys := make([]string, 0, len(keyArray))
  484. for _, key := range keyArray {
  485. var keyStr string
  486. switch v := key.(type) {
  487. case string:
  488. keyStr = strings.TrimSpace(v)
  489. default:
  490. bytes, err := json.Marshal(v)
  491. if err != nil {
  492. return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
  493. }
  494. keyStr = string(bytes)
  495. }
  496. if keyStr != "" {
  497. cleanKeys = append(cleanKeys, keyStr)
  498. }
  499. }
  500. if len(cleanKeys) == 0 {
  501. return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
  502. }
  503. return cleanKeys, nil
  504. }
  505. func AddChannel(c *gin.Context) {
  506. addChannelRequest := AddChannelRequest{}
  507. err := c.ShouldBindJSON(&addChannelRequest)
  508. if err != nil {
  509. common.ApiError(c, err)
  510. return
  511. }
  512. // 使用统一的校验函数
  513. if err := validateChannel(addChannelRequest.Channel, true); err != nil {
  514. c.JSON(http.StatusOK, gin.H{
  515. "success": false,
  516. "message": err.Error(),
  517. })
  518. return
  519. }
  520. addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
  521. keys := make([]string, 0)
  522. switch addChannelRequest.Mode {
  523. case "multi_to_single":
  524. addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
  525. addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
  526. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  527. array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
  528. if err != nil {
  529. c.JSON(http.StatusOK, gin.H{
  530. "success": false,
  531. "message": err.Error(),
  532. })
  533. return
  534. }
  535. addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
  536. addChannelRequest.Channel.Key = strings.Join(array, "\n")
  537. } else {
  538. cleanKeys := make([]string, 0)
  539. for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
  540. if key == "" {
  541. continue
  542. }
  543. key = strings.TrimSpace(key)
  544. cleanKeys = append(cleanKeys, key)
  545. }
  546. addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
  547. addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
  548. }
  549. keys = []string{addChannelRequest.Channel.Key}
  550. case "batch":
  551. if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  552. // multi json
  553. keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
  554. if err != nil {
  555. c.JSON(http.StatusOK, gin.H{
  556. "success": false,
  557. "message": err.Error(),
  558. })
  559. return
  560. }
  561. } else {
  562. keys = strings.Split(addChannelRequest.Channel.Key, "\n")
  563. }
  564. case "single":
  565. keys = []string{addChannelRequest.Channel.Key}
  566. default:
  567. c.JSON(http.StatusOK, gin.H{
  568. "success": false,
  569. "message": "不支持的添加模式",
  570. })
  571. return
  572. }
  573. channels := make([]model.Channel, 0, len(keys))
  574. for _, key := range keys {
  575. if key == "" {
  576. continue
  577. }
  578. localChannel := addChannelRequest.Channel
  579. localChannel.Key = key
  580. if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
  581. keyPrefix := localChannel.Key
  582. if len(localChannel.Key) > 8 {
  583. keyPrefix = localChannel.Key[:8]
  584. }
  585. localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
  586. }
  587. channels = append(channels, *localChannel)
  588. }
  589. err = model.BatchInsertChannels(channels)
  590. if err != nil {
  591. common.ApiError(c, err)
  592. return
  593. }
  594. service.ResetProxyClientCache()
  595. c.JSON(http.StatusOK, gin.H{
  596. "success": true,
  597. "message": "",
  598. })
  599. return
  600. }
  601. func DeleteChannel(c *gin.Context) {
  602. id, _ := strconv.Atoi(c.Param("id"))
  603. channel := model.Channel{Id: id}
  604. err := channel.Delete()
  605. if err != nil {
  606. common.ApiError(c, err)
  607. return
  608. }
  609. model.InitChannelCache()
  610. c.JSON(http.StatusOK, gin.H{
  611. "success": true,
  612. "message": "",
  613. })
  614. return
  615. }
  616. func DeleteDisabledChannel(c *gin.Context) {
  617. rows, err := model.DeleteDisabledChannel()
  618. if err != nil {
  619. common.ApiError(c, err)
  620. return
  621. }
  622. model.InitChannelCache()
  623. c.JSON(http.StatusOK, gin.H{
  624. "success": true,
  625. "message": "",
  626. "data": rows,
  627. })
  628. return
  629. }
  630. type ChannelTag struct {
  631. Tag string `json:"tag"`
  632. NewTag *string `json:"new_tag"`
  633. Priority *int64 `json:"priority"`
  634. Weight *uint `json:"weight"`
  635. ModelMapping *string `json:"model_mapping"`
  636. Models *string `json:"models"`
  637. Groups *string `json:"groups"`
  638. ParamOverride *string `json:"param_override"`
  639. HeaderOverride *string `json:"header_override"`
  640. }
  641. func DisableTagChannels(c *gin.Context) {
  642. channelTag := ChannelTag{}
  643. err := c.ShouldBindJSON(&channelTag)
  644. if err != nil || channelTag.Tag == "" {
  645. c.JSON(http.StatusOK, gin.H{
  646. "success": false,
  647. "message": "参数错误",
  648. })
  649. return
  650. }
  651. err = model.DisableChannelByTag(channelTag.Tag)
  652. if err != nil {
  653. common.ApiError(c, err)
  654. return
  655. }
  656. model.InitChannelCache()
  657. c.JSON(http.StatusOK, gin.H{
  658. "success": true,
  659. "message": "",
  660. })
  661. return
  662. }
  663. func EnableTagChannels(c *gin.Context) {
  664. channelTag := ChannelTag{}
  665. err := c.ShouldBindJSON(&channelTag)
  666. if err != nil || channelTag.Tag == "" {
  667. c.JSON(http.StatusOK, gin.H{
  668. "success": false,
  669. "message": "参数错误",
  670. })
  671. return
  672. }
  673. err = model.EnableChannelByTag(channelTag.Tag)
  674. if err != nil {
  675. common.ApiError(c, err)
  676. return
  677. }
  678. model.InitChannelCache()
  679. c.JSON(http.StatusOK, gin.H{
  680. "success": true,
  681. "message": "",
  682. })
  683. return
  684. }
  685. func EditTagChannels(c *gin.Context) {
  686. channelTag := ChannelTag{}
  687. err := c.ShouldBindJSON(&channelTag)
  688. if err != nil {
  689. c.JSON(http.StatusOK, gin.H{
  690. "success": false,
  691. "message": "参数错误",
  692. })
  693. return
  694. }
  695. if channelTag.Tag == "" {
  696. c.JSON(http.StatusOK, gin.H{
  697. "success": false,
  698. "message": "tag不能为空",
  699. })
  700. return
  701. }
  702. if channelTag.ParamOverride != nil {
  703. trimmed := strings.TrimSpace(*channelTag.ParamOverride)
  704. if trimmed != "" && !json.Valid([]byte(trimmed)) {
  705. c.JSON(http.StatusOK, gin.H{
  706. "success": false,
  707. "message": "参数覆盖必须是合法的 JSON 格式",
  708. })
  709. return
  710. }
  711. channelTag.ParamOverride = common.GetPointer[string](trimmed)
  712. }
  713. if channelTag.HeaderOverride != nil {
  714. trimmed := strings.TrimSpace(*channelTag.HeaderOverride)
  715. if trimmed != "" && !json.Valid([]byte(trimmed)) {
  716. c.JSON(http.StatusOK, gin.H{
  717. "success": false,
  718. "message": "请求头覆盖必须是合法的 JSON 格式",
  719. })
  720. return
  721. }
  722. channelTag.HeaderOverride = common.GetPointer[string](trimmed)
  723. }
  724. err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight, channelTag.ParamOverride, channelTag.HeaderOverride)
  725. if err != nil {
  726. common.ApiError(c, err)
  727. return
  728. }
  729. model.InitChannelCache()
  730. c.JSON(http.StatusOK, gin.H{
  731. "success": true,
  732. "message": "",
  733. })
  734. return
  735. }
  736. type ChannelBatch struct {
  737. Ids []int `json:"ids"`
  738. Tag *string `json:"tag"`
  739. }
  740. func DeleteChannelBatch(c *gin.Context) {
  741. channelBatch := ChannelBatch{}
  742. err := c.ShouldBindJSON(&channelBatch)
  743. if err != nil || len(channelBatch.Ids) == 0 {
  744. c.JSON(http.StatusOK, gin.H{
  745. "success": false,
  746. "message": "参数错误",
  747. })
  748. return
  749. }
  750. err = model.BatchDeleteChannels(channelBatch.Ids)
  751. if err != nil {
  752. common.ApiError(c, err)
  753. return
  754. }
  755. model.InitChannelCache()
  756. c.JSON(http.StatusOK, gin.H{
  757. "success": true,
  758. "message": "",
  759. "data": len(channelBatch.Ids),
  760. })
  761. return
  762. }
  763. type PatchChannel struct {
  764. model.Channel
  765. MultiKeyMode *string `json:"multi_key_mode"`
  766. KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
  767. }
  768. func UpdateChannel(c *gin.Context) {
  769. channel := PatchChannel{}
  770. err := c.ShouldBindJSON(&channel)
  771. if err != nil {
  772. common.ApiError(c, err)
  773. return
  774. }
  775. // 使用统一的校验函数
  776. if err := validateChannel(&channel.Channel, false); err != nil {
  777. c.JSON(http.StatusOK, gin.H{
  778. "success": false,
  779. "message": err.Error(),
  780. })
  781. return
  782. }
  783. // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
  784. originChannel, err := model.GetChannelById(channel.Id, true)
  785. if err != nil {
  786. c.JSON(http.StatusOK, gin.H{
  787. "success": false,
  788. "message": err.Error(),
  789. })
  790. return
  791. }
  792. // Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
  793. channel.ChannelInfo = originChannel.ChannelInfo
  794. // If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
  795. if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
  796. channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
  797. }
  798. // 处理多key模式下的密钥追加/覆盖逻辑
  799. if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
  800. switch *channel.KeyMode {
  801. case "append":
  802. // 追加模式:将新密钥添加到现有密钥列表
  803. if originChannel.Key != "" {
  804. var newKeys []string
  805. var existingKeys []string
  806. // 解析现有密钥
  807. if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
  808. // JSON数组格式
  809. var arr []json.RawMessage
  810. if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
  811. existingKeys = make([]string, len(arr))
  812. for i, v := range arr {
  813. existingKeys[i] = string(v)
  814. }
  815. }
  816. } else {
  817. // 换行分隔格式
  818. existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
  819. }
  820. // 处理 Vertex AI 的特殊情况
  821. if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
  822. // 尝试解析新密钥为JSON数组
  823. if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
  824. array, err := getVertexArrayKeys(channel.Key)
  825. if err != nil {
  826. c.JSON(http.StatusOK, gin.H{
  827. "success": false,
  828. "message": "追加密钥解析失败: " + err.Error(),
  829. })
  830. return
  831. }
  832. newKeys = array
  833. } else {
  834. // 单个JSON密钥
  835. newKeys = []string{channel.Key}
  836. }
  837. } else {
  838. // 普通渠道的处理
  839. inputKeys := strings.Split(channel.Key, "\n")
  840. for _, key := range inputKeys {
  841. key = strings.TrimSpace(key)
  842. if key != "" {
  843. newKeys = append(newKeys, key)
  844. }
  845. }
  846. }
  847. seen := make(map[string]struct{}, len(existingKeys)+len(newKeys))
  848. for _, key := range existingKeys {
  849. normalized := strings.TrimSpace(key)
  850. if normalized == "" {
  851. continue
  852. }
  853. seen[normalized] = struct{}{}
  854. }
  855. dedupedNewKeys := make([]string, 0, len(newKeys))
  856. for _, key := range newKeys {
  857. normalized := strings.TrimSpace(key)
  858. if normalized == "" {
  859. continue
  860. }
  861. if _, ok := seen[normalized]; ok {
  862. continue
  863. }
  864. seen[normalized] = struct{}{}
  865. dedupedNewKeys = append(dedupedNewKeys, normalized)
  866. }
  867. allKeys := append(existingKeys, dedupedNewKeys...)
  868. channel.Key = strings.Join(allKeys, "\n")
  869. }
  870. case "replace":
  871. // 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
  872. }
  873. }
  874. err = channel.Update()
  875. if err != nil {
  876. common.ApiError(c, err)
  877. return
  878. }
  879. model.InitChannelCache()
  880. service.ResetProxyClientCache()
  881. channel.Key = ""
  882. clearChannelInfo(&channel.Channel)
  883. c.JSON(http.StatusOK, gin.H{
  884. "success": true,
  885. "message": "",
  886. "data": channel,
  887. })
  888. return
  889. }
  890. func FetchModels(c *gin.Context) {
  891. var req struct {
  892. BaseURL string `json:"base_url"`
  893. Type int `json:"type"`
  894. Key string `json:"key"`
  895. }
  896. if err := c.ShouldBindJSON(&req); err != nil {
  897. c.JSON(http.StatusBadRequest, gin.H{
  898. "success": false,
  899. "message": "Invalid request",
  900. })
  901. return
  902. }
  903. baseURL := req.BaseURL
  904. if baseURL == "" {
  905. baseURL = constant.ChannelBaseURLs[req.Type]
  906. }
  907. // remove line breaks and extra spaces.
  908. key := strings.TrimSpace(req.Key)
  909. key = strings.Split(key, "\n")[0]
  910. if req.Type == constant.ChannelTypeOllama {
  911. models, err := ollama.FetchOllamaModels(baseURL, key)
  912. if err != nil {
  913. c.JSON(http.StatusOK, gin.H{
  914. "success": false,
  915. "message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
  916. })
  917. return
  918. }
  919. names := make([]string, 0, len(models))
  920. for _, modelInfo := range models {
  921. names = append(names, modelInfo.Name)
  922. }
  923. c.JSON(http.StatusOK, gin.H{
  924. "success": true,
  925. "data": names,
  926. })
  927. return
  928. }
  929. if req.Type == constant.ChannelTypeGemini {
  930. models, err := gemini.FetchGeminiModels(baseURL, key, "")
  931. if err != nil {
  932. c.JSON(http.StatusOK, gin.H{
  933. "success": false,
  934. "message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
  935. })
  936. return
  937. }
  938. c.JSON(http.StatusOK, gin.H{
  939. "success": true,
  940. "data": models,
  941. })
  942. return
  943. }
  944. client := &http.Client{}
  945. url := fmt.Sprintf("%s/v1/models", baseURL)
  946. request, err := http.NewRequest("GET", url, nil)
  947. if err != nil {
  948. c.JSON(http.StatusInternalServerError, gin.H{
  949. "success": false,
  950. "message": err.Error(),
  951. })
  952. return
  953. }
  954. request.Header.Set("Authorization", "Bearer "+key)
  955. response, err := client.Do(request)
  956. if err != nil {
  957. c.JSON(http.StatusInternalServerError, gin.H{
  958. "success": false,
  959. "message": err.Error(),
  960. })
  961. return
  962. }
  963. //check status code
  964. if response.StatusCode != http.StatusOK {
  965. c.JSON(http.StatusInternalServerError, gin.H{
  966. "success": false,
  967. "message": "Failed to fetch models",
  968. })
  969. return
  970. }
  971. defer response.Body.Close()
  972. var result struct {
  973. Data []struct {
  974. ID string `json:"id"`
  975. } `json:"data"`
  976. }
  977. if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
  978. c.JSON(http.StatusInternalServerError, gin.H{
  979. "success": false,
  980. "message": err.Error(),
  981. })
  982. return
  983. }
  984. var models []string
  985. for _, model := range result.Data {
  986. models = append(models, model.ID)
  987. }
  988. c.JSON(http.StatusOK, gin.H{
  989. "success": true,
  990. "data": models,
  991. })
  992. }
  993. func BatchSetChannelTag(c *gin.Context) {
  994. channelBatch := ChannelBatch{}
  995. err := c.ShouldBindJSON(&channelBatch)
  996. if err != nil || len(channelBatch.Ids) == 0 {
  997. c.JSON(http.StatusOK, gin.H{
  998. "success": false,
  999. "message": "参数错误",
  1000. })
  1001. return
  1002. }
  1003. err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
  1004. if err != nil {
  1005. common.ApiError(c, err)
  1006. return
  1007. }
  1008. model.InitChannelCache()
  1009. c.JSON(http.StatusOK, gin.H{
  1010. "success": true,
  1011. "message": "",
  1012. "data": len(channelBatch.Ids),
  1013. })
  1014. return
  1015. }
  1016. func GetTagModels(c *gin.Context) {
  1017. tag := c.Query("tag")
  1018. if tag == "" {
  1019. c.JSON(http.StatusBadRequest, gin.H{
  1020. "success": false,
  1021. "message": "tag不能为空",
  1022. })
  1023. return
  1024. }
  1025. channels, err := model.GetChannelsByTag(tag, false, false) // idSort=false, selectAll=false
  1026. if err != nil {
  1027. c.JSON(http.StatusInternalServerError, gin.H{
  1028. "success": false,
  1029. "message": err.Error(),
  1030. })
  1031. return
  1032. }
  1033. var longestModels string
  1034. maxLength := 0
  1035. // Find the longest models string among all channels with the given tag
  1036. for _, channel := range channels {
  1037. if channel.Models != "" {
  1038. currentModels := strings.Split(channel.Models, ",")
  1039. if len(currentModels) > maxLength {
  1040. maxLength = len(currentModels)
  1041. longestModels = channel.Models
  1042. }
  1043. }
  1044. }
  1045. c.JSON(http.StatusOK, gin.H{
  1046. "success": true,
  1047. "message": "",
  1048. "data": longestModels,
  1049. })
  1050. return
  1051. }
  1052. // CopyChannel handles cloning an existing channel with its key.
  1053. // POST /api/channel/copy/:id
  1054. // Optional query params:
  1055. //
  1056. // suffix - string appended to the original name (default "_复制")
  1057. // reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
  1058. func CopyChannel(c *gin.Context) {
  1059. id, err := strconv.Atoi(c.Param("id"))
  1060. if err != nil {
  1061. c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
  1062. return
  1063. }
  1064. suffix := c.DefaultQuery("suffix", "_复制")
  1065. resetBalance := true
  1066. if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
  1067. if v, err := strconv.ParseBool(rbStr); err == nil {
  1068. resetBalance = v
  1069. }
  1070. }
  1071. // fetch original channel with key
  1072. origin, err := model.GetChannelById(id, true)
  1073. if err != nil {
  1074. common.SysError("failed to get channel by id: " + err.Error())
  1075. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道信息失败,请稍后重试"})
  1076. return
  1077. }
  1078. // clone channel
  1079. clone := *origin // shallow copy is sufficient as we will overwrite primitives
  1080. clone.Id = 0 // let DB auto-generate
  1081. clone.CreatedTime = common.GetTimestamp()
  1082. clone.Name = origin.Name + suffix
  1083. clone.TestTime = 0
  1084. clone.ResponseTime = 0
  1085. if resetBalance {
  1086. clone.Balance = 0
  1087. clone.UsedQuota = 0
  1088. }
  1089. // insert
  1090. if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
  1091. common.SysError("failed to clone channel: " + err.Error())
  1092. c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
  1093. return
  1094. }
  1095. model.InitChannelCache()
  1096. // success
  1097. c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
  1098. }
  1099. // MultiKeyManageRequest represents the request for multi-key management operations
  1100. type MultiKeyManageRequest struct {
  1101. ChannelId int `json:"channel_id"`
  1102. Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"
  1103. KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
  1104. Page int `json:"page,omitempty"` // for get_key_status pagination
  1105. PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
  1106. Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
  1107. }
  1108. // MultiKeyStatusResponse represents the response for key status query
  1109. type MultiKeyStatusResponse struct {
  1110. Keys []KeyStatus `json:"keys"`
  1111. Total int `json:"total"`
  1112. Page int `json:"page"`
  1113. PageSize int `json:"page_size"`
  1114. TotalPages int `json:"total_pages"`
  1115. // Statistics
  1116. EnabledCount int `json:"enabled_count"`
  1117. ManualDisabledCount int `json:"manual_disabled_count"`
  1118. AutoDisabledCount int `json:"auto_disabled_count"`
  1119. }
  1120. type KeyStatus struct {
  1121. Index int `json:"index"`
  1122. Status int `json:"status"` // 1: enabled, 2: disabled
  1123. DisabledTime int64 `json:"disabled_time,omitempty"`
  1124. Reason string `json:"reason,omitempty"`
  1125. KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
  1126. }
  1127. // ManageMultiKeys handles multi-key management operations
  1128. func ManageMultiKeys(c *gin.Context) {
  1129. request := MultiKeyManageRequest{}
  1130. err := c.ShouldBindJSON(&request)
  1131. if err != nil {
  1132. common.ApiError(c, err)
  1133. return
  1134. }
  1135. channel, err := model.GetChannelById(request.ChannelId, true)
  1136. if err != nil {
  1137. c.JSON(http.StatusOK, gin.H{
  1138. "success": false,
  1139. "message": "渠道不存在",
  1140. })
  1141. return
  1142. }
  1143. if !channel.ChannelInfo.IsMultiKey {
  1144. c.JSON(http.StatusOK, gin.H{
  1145. "success": false,
  1146. "message": "该渠道不是多密钥模式",
  1147. })
  1148. return
  1149. }
  1150. lock := model.GetChannelPollingLock(channel.Id)
  1151. lock.Lock()
  1152. defer lock.Unlock()
  1153. switch request.Action {
  1154. case "get_key_status":
  1155. keys := channel.GetKeys()
  1156. // Default pagination parameters
  1157. page := request.Page
  1158. pageSize := request.PageSize
  1159. if page <= 0 {
  1160. page = 1
  1161. }
  1162. if pageSize <= 0 {
  1163. pageSize = 50 // Default page size
  1164. }
  1165. // Statistics for all keys (unchanged by filtering)
  1166. var enabledCount, manualDisabledCount, autoDisabledCount int
  1167. // Build all key status data first
  1168. var allKeyStatusList []KeyStatus
  1169. for i, key := range keys {
  1170. status := 1 // default enabled
  1171. var disabledTime int64
  1172. var reason string
  1173. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1174. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1175. status = s
  1176. }
  1177. }
  1178. // Count for statistics (all keys)
  1179. switch status {
  1180. case 1:
  1181. enabledCount++
  1182. case 2:
  1183. manualDisabledCount++
  1184. case 3:
  1185. autoDisabledCount++
  1186. }
  1187. if status != 1 {
  1188. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1189. disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
  1190. }
  1191. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1192. reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
  1193. }
  1194. }
  1195. // Create key preview (first 10 chars)
  1196. keyPreview := key
  1197. if len(key) > 10 {
  1198. keyPreview = key[:10] + "..."
  1199. }
  1200. allKeyStatusList = append(allKeyStatusList, KeyStatus{
  1201. Index: i,
  1202. Status: status,
  1203. DisabledTime: disabledTime,
  1204. Reason: reason,
  1205. KeyPreview: keyPreview,
  1206. })
  1207. }
  1208. // Apply status filter if specified
  1209. var filteredKeyStatusList []KeyStatus
  1210. if request.Status != nil {
  1211. for _, keyStatus := range allKeyStatusList {
  1212. if keyStatus.Status == *request.Status {
  1213. filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
  1214. }
  1215. }
  1216. } else {
  1217. filteredKeyStatusList = allKeyStatusList
  1218. }
  1219. // Calculate pagination based on filtered results
  1220. filteredTotal := len(filteredKeyStatusList)
  1221. totalPages := (filteredTotal + pageSize - 1) / pageSize
  1222. if totalPages == 0 {
  1223. totalPages = 1
  1224. }
  1225. if page > totalPages {
  1226. page = totalPages
  1227. }
  1228. // Calculate range for current page
  1229. start := (page - 1) * pageSize
  1230. end := start + pageSize
  1231. if end > filteredTotal {
  1232. end = filteredTotal
  1233. }
  1234. // Get the page data
  1235. var pageKeyStatusList []KeyStatus
  1236. if start < filteredTotal {
  1237. pageKeyStatusList = filteredKeyStatusList[start:end]
  1238. }
  1239. c.JSON(http.StatusOK, gin.H{
  1240. "success": true,
  1241. "message": "",
  1242. "data": MultiKeyStatusResponse{
  1243. Keys: pageKeyStatusList,
  1244. Total: filteredTotal, // Total of filtered results
  1245. Page: page,
  1246. PageSize: pageSize,
  1247. TotalPages: totalPages,
  1248. EnabledCount: enabledCount, // Overall statistics
  1249. ManualDisabledCount: manualDisabledCount, // Overall statistics
  1250. AutoDisabledCount: autoDisabledCount, // Overall statistics
  1251. },
  1252. })
  1253. return
  1254. case "disable_key":
  1255. if request.KeyIndex == nil {
  1256. c.JSON(http.StatusOK, gin.H{
  1257. "success": false,
  1258. "message": "未指定要禁用的密钥索引",
  1259. })
  1260. return
  1261. }
  1262. keyIndex := *request.KeyIndex
  1263. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1264. c.JSON(http.StatusOK, gin.H{
  1265. "success": false,
  1266. "message": "密钥索引超出范围",
  1267. })
  1268. return
  1269. }
  1270. if channel.ChannelInfo.MultiKeyStatusList == nil {
  1271. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1272. }
  1273. if channel.ChannelInfo.MultiKeyDisabledTime == nil {
  1274. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1275. }
  1276. if channel.ChannelInfo.MultiKeyDisabledReason == nil {
  1277. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1278. }
  1279. channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
  1280. err = channel.Update()
  1281. if err != nil {
  1282. common.ApiError(c, err)
  1283. return
  1284. }
  1285. model.InitChannelCache()
  1286. c.JSON(http.StatusOK, gin.H{
  1287. "success": true,
  1288. "message": "密钥已禁用",
  1289. })
  1290. return
  1291. case "enable_key":
  1292. if request.KeyIndex == nil {
  1293. c.JSON(http.StatusOK, gin.H{
  1294. "success": false,
  1295. "message": "未指定要启用的密钥索引",
  1296. })
  1297. return
  1298. }
  1299. keyIndex := *request.KeyIndex
  1300. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1301. c.JSON(http.StatusOK, gin.H{
  1302. "success": false,
  1303. "message": "密钥索引超出范围",
  1304. })
  1305. return
  1306. }
  1307. // 从状态列表中删除该密钥的记录,使其回到默认启用状态
  1308. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1309. delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
  1310. }
  1311. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1312. delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
  1313. }
  1314. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1315. delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
  1316. }
  1317. err = channel.Update()
  1318. if err != nil {
  1319. common.ApiError(c, err)
  1320. return
  1321. }
  1322. model.InitChannelCache()
  1323. c.JSON(http.StatusOK, gin.H{
  1324. "success": true,
  1325. "message": "密钥已启用",
  1326. })
  1327. return
  1328. case "enable_all_keys":
  1329. // 清空所有禁用状态,使所有密钥回到默认启用状态
  1330. var enabledCount int
  1331. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1332. enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
  1333. }
  1334. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1335. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1336. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1337. err = channel.Update()
  1338. if err != nil {
  1339. common.ApiError(c, err)
  1340. return
  1341. }
  1342. model.InitChannelCache()
  1343. c.JSON(http.StatusOK, gin.H{
  1344. "success": true,
  1345. "message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
  1346. })
  1347. return
  1348. case "disable_all_keys":
  1349. // 禁用所有启用的密钥
  1350. if channel.ChannelInfo.MultiKeyStatusList == nil {
  1351. channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
  1352. }
  1353. if channel.ChannelInfo.MultiKeyDisabledTime == nil {
  1354. channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
  1355. }
  1356. if channel.ChannelInfo.MultiKeyDisabledReason == nil {
  1357. channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
  1358. }
  1359. var disabledCount int
  1360. for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
  1361. status := 1 // default enabled
  1362. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1363. status = s
  1364. }
  1365. // 只禁用当前启用的密钥
  1366. if status == 1 {
  1367. channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
  1368. disabledCount++
  1369. }
  1370. }
  1371. if disabledCount == 0 {
  1372. c.JSON(http.StatusOK, gin.H{
  1373. "success": false,
  1374. "message": "没有可禁用的密钥",
  1375. })
  1376. return
  1377. }
  1378. err = channel.Update()
  1379. if err != nil {
  1380. common.ApiError(c, err)
  1381. return
  1382. }
  1383. model.InitChannelCache()
  1384. c.JSON(http.StatusOK, gin.H{
  1385. "success": true,
  1386. "message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
  1387. })
  1388. return
  1389. case "delete_key":
  1390. if request.KeyIndex == nil {
  1391. c.JSON(http.StatusOK, gin.H{
  1392. "success": false,
  1393. "message": "未指定要删除的密钥索引",
  1394. })
  1395. return
  1396. }
  1397. keyIndex := *request.KeyIndex
  1398. if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
  1399. c.JSON(http.StatusOK, gin.H{
  1400. "success": false,
  1401. "message": "密钥索引超出范围",
  1402. })
  1403. return
  1404. }
  1405. keys := channel.GetKeys()
  1406. var remainingKeys []string
  1407. var newStatusList = make(map[int]int)
  1408. var newDisabledTime = make(map[int]int64)
  1409. var newDisabledReason = make(map[int]string)
  1410. newIndex := 0
  1411. for i, key := range keys {
  1412. // 跳过要删除的密钥
  1413. if i == keyIndex {
  1414. continue
  1415. }
  1416. remainingKeys = append(remainingKeys, key)
  1417. // 保留其他密钥的状态信息,重新索引
  1418. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1419. if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
  1420. newStatusList[newIndex] = status
  1421. }
  1422. }
  1423. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1424. if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
  1425. newDisabledTime[newIndex] = t
  1426. }
  1427. }
  1428. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1429. if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
  1430. newDisabledReason[newIndex] = r
  1431. }
  1432. }
  1433. newIndex++
  1434. }
  1435. if len(remainingKeys) == 0 {
  1436. c.JSON(http.StatusOK, gin.H{
  1437. "success": false,
  1438. "message": "不能删除最后一个密钥",
  1439. })
  1440. return
  1441. }
  1442. // Update channel with remaining keys
  1443. channel.Key = strings.Join(remainingKeys, "\n")
  1444. channel.ChannelInfo.MultiKeySize = len(remainingKeys)
  1445. channel.ChannelInfo.MultiKeyStatusList = newStatusList
  1446. channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
  1447. channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
  1448. err = channel.Update()
  1449. if err != nil {
  1450. common.ApiError(c, err)
  1451. return
  1452. }
  1453. model.InitChannelCache()
  1454. c.JSON(http.StatusOK, gin.H{
  1455. "success": true,
  1456. "message": "密钥已删除",
  1457. })
  1458. return
  1459. case "delete_disabled_keys":
  1460. keys := channel.GetKeys()
  1461. var remainingKeys []string
  1462. var deletedCount int
  1463. var newStatusList = make(map[int]int)
  1464. var newDisabledTime = make(map[int]int64)
  1465. var newDisabledReason = make(map[int]string)
  1466. newIndex := 0
  1467. for i, key := range keys {
  1468. status := 1 // default enabled
  1469. if channel.ChannelInfo.MultiKeyStatusList != nil {
  1470. if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
  1471. status = s
  1472. }
  1473. }
  1474. // 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
  1475. if status == 3 {
  1476. deletedCount++
  1477. } else {
  1478. remainingKeys = append(remainingKeys, key)
  1479. // 保留非自动禁用密钥的状态信息,重新索引
  1480. if status != 1 {
  1481. newStatusList[newIndex] = status
  1482. if channel.ChannelInfo.MultiKeyDisabledTime != nil {
  1483. if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
  1484. newDisabledTime[newIndex] = t
  1485. }
  1486. }
  1487. if channel.ChannelInfo.MultiKeyDisabledReason != nil {
  1488. if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
  1489. newDisabledReason[newIndex] = r
  1490. }
  1491. }
  1492. }
  1493. newIndex++
  1494. }
  1495. }
  1496. if deletedCount == 0 {
  1497. c.JSON(http.StatusOK, gin.H{
  1498. "success": false,
  1499. "message": "没有需要删除的自动禁用密钥",
  1500. })
  1501. return
  1502. }
  1503. // Update channel with remaining keys
  1504. channel.Key = strings.Join(remainingKeys, "\n")
  1505. channel.ChannelInfo.MultiKeySize = len(remainingKeys)
  1506. channel.ChannelInfo.MultiKeyStatusList = newStatusList
  1507. channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
  1508. channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
  1509. err = channel.Update()
  1510. if err != nil {
  1511. common.ApiError(c, err)
  1512. return
  1513. }
  1514. model.InitChannelCache()
  1515. c.JSON(http.StatusOK, gin.H{
  1516. "success": true,
  1517. "message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
  1518. "data": deletedCount,
  1519. })
  1520. return
  1521. default:
  1522. c.JSON(http.StatusOK, gin.H{
  1523. "success": false,
  1524. "message": "不支持的操作",
  1525. })
  1526. return
  1527. }
  1528. }
  1529. // OllamaPullModel 拉取 Ollama 模型
  1530. func OllamaPullModel(c *gin.Context) {
  1531. var req struct {
  1532. ChannelID int `json:"channel_id"`
  1533. ModelName string `json:"model_name"`
  1534. }
  1535. if err := c.ShouldBindJSON(&req); err != nil {
  1536. c.JSON(http.StatusBadRequest, gin.H{
  1537. "success": false,
  1538. "message": "Invalid request parameters",
  1539. })
  1540. return
  1541. }
  1542. if req.ChannelID == 0 || req.ModelName == "" {
  1543. c.JSON(http.StatusBadRequest, gin.H{
  1544. "success": false,
  1545. "message": "Channel ID and model name are required",
  1546. })
  1547. return
  1548. }
  1549. // 获取渠道信息
  1550. channel, err := model.GetChannelById(req.ChannelID, true)
  1551. if err != nil {
  1552. c.JSON(http.StatusNotFound, gin.H{
  1553. "success": false,
  1554. "message": "Channel not found",
  1555. })
  1556. return
  1557. }
  1558. // 检查是否是 Ollama 渠道
  1559. if channel.Type != constant.ChannelTypeOllama {
  1560. c.JSON(http.StatusBadRequest, gin.H{
  1561. "success": false,
  1562. "message": "This operation is only supported for Ollama channels",
  1563. })
  1564. return
  1565. }
  1566. baseURL := constant.ChannelBaseURLs[channel.Type]
  1567. if channel.GetBaseURL() != "" {
  1568. baseURL = channel.GetBaseURL()
  1569. }
  1570. key := strings.Split(channel.Key, "\n")[0]
  1571. err = ollama.PullOllamaModel(baseURL, key, req.ModelName)
  1572. if err != nil {
  1573. c.JSON(http.StatusInternalServerError, gin.H{
  1574. "success": false,
  1575. "message": fmt.Sprintf("Failed to pull model: %s", err.Error()),
  1576. })
  1577. return
  1578. }
  1579. c.JSON(http.StatusOK, gin.H{
  1580. "success": true,
  1581. "message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
  1582. })
  1583. }
  1584. // OllamaPullModelStream 流式拉取 Ollama 模型
  1585. func OllamaPullModelStream(c *gin.Context) {
  1586. var req struct {
  1587. ChannelID int `json:"channel_id"`
  1588. ModelName string `json:"model_name"`
  1589. }
  1590. if err := c.ShouldBindJSON(&req); err != nil {
  1591. c.JSON(http.StatusBadRequest, gin.H{
  1592. "success": false,
  1593. "message": "Invalid request parameters",
  1594. })
  1595. return
  1596. }
  1597. if req.ChannelID == 0 || req.ModelName == "" {
  1598. c.JSON(http.StatusBadRequest, gin.H{
  1599. "success": false,
  1600. "message": "Channel ID and model name are required",
  1601. })
  1602. return
  1603. }
  1604. // 获取渠道信息
  1605. channel, err := model.GetChannelById(req.ChannelID, true)
  1606. if err != nil {
  1607. c.JSON(http.StatusNotFound, gin.H{
  1608. "success": false,
  1609. "message": "Channel not found",
  1610. })
  1611. return
  1612. }
  1613. // 检查是否是 Ollama 渠道
  1614. if channel.Type != constant.ChannelTypeOllama {
  1615. c.JSON(http.StatusBadRequest, gin.H{
  1616. "success": false,
  1617. "message": "This operation is only supported for Ollama channels",
  1618. })
  1619. return
  1620. }
  1621. baseURL := constant.ChannelBaseURLs[channel.Type]
  1622. if channel.GetBaseURL() != "" {
  1623. baseURL = channel.GetBaseURL()
  1624. }
  1625. // 设置 SSE 头部
  1626. c.Header("Content-Type", "text/event-stream")
  1627. c.Header("Cache-Control", "no-cache")
  1628. c.Header("Connection", "keep-alive")
  1629. c.Header("Access-Control-Allow-Origin", "*")
  1630. key := strings.Split(channel.Key, "\n")[0]
  1631. // 创建进度回调函数
  1632. progressCallback := func(progress ollama.OllamaPullResponse) {
  1633. data, _ := json.Marshal(progress)
  1634. fmt.Fprintf(c.Writer, "data: %s\n\n", string(data))
  1635. c.Writer.Flush()
  1636. }
  1637. // 执行拉取
  1638. err = ollama.PullOllamaModelStream(baseURL, key, req.ModelName, progressCallback)
  1639. if err != nil {
  1640. errorData, _ := json.Marshal(gin.H{
  1641. "error": err.Error(),
  1642. })
  1643. fmt.Fprintf(c.Writer, "data: %s\n\n", string(errorData))
  1644. } else {
  1645. successData, _ := json.Marshal(gin.H{
  1646. "message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
  1647. })
  1648. fmt.Fprintf(c.Writer, "data: %s\n\n", string(successData))
  1649. }
  1650. // 发送结束标志
  1651. fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
  1652. c.Writer.Flush()
  1653. }
  1654. // OllamaDeleteModel 删除 Ollama 模型
  1655. func OllamaDeleteModel(c *gin.Context) {
  1656. var req struct {
  1657. ChannelID int `json:"channel_id"`
  1658. ModelName string `json:"model_name"`
  1659. }
  1660. if err := c.ShouldBindJSON(&req); err != nil {
  1661. c.JSON(http.StatusBadRequest, gin.H{
  1662. "success": false,
  1663. "message": "Invalid request parameters",
  1664. })
  1665. return
  1666. }
  1667. if req.ChannelID == 0 || req.ModelName == "" {
  1668. c.JSON(http.StatusBadRequest, gin.H{
  1669. "success": false,
  1670. "message": "Channel ID and model name are required",
  1671. })
  1672. return
  1673. }
  1674. // 获取渠道信息
  1675. channel, err := model.GetChannelById(req.ChannelID, true)
  1676. if err != nil {
  1677. c.JSON(http.StatusNotFound, gin.H{
  1678. "success": false,
  1679. "message": "Channel not found",
  1680. })
  1681. return
  1682. }
  1683. // 检查是否是 Ollama 渠道
  1684. if channel.Type != constant.ChannelTypeOllama {
  1685. c.JSON(http.StatusBadRequest, gin.H{
  1686. "success": false,
  1687. "message": "This operation is only supported for Ollama channels",
  1688. })
  1689. return
  1690. }
  1691. baseURL := constant.ChannelBaseURLs[channel.Type]
  1692. if channel.GetBaseURL() != "" {
  1693. baseURL = channel.GetBaseURL()
  1694. }
  1695. key := strings.Split(channel.Key, "\n")[0]
  1696. err = ollama.DeleteOllamaModel(baseURL, key, req.ModelName)
  1697. if err != nil {
  1698. c.JSON(http.StatusInternalServerError, gin.H{
  1699. "success": false,
  1700. "message": fmt.Sprintf("Failed to delete model: %s", err.Error()),
  1701. })
  1702. return
  1703. }
  1704. c.JSON(http.StatusOK, gin.H{
  1705. "success": true,
  1706. "message": fmt.Sprintf("Model %s deleted successfully", req.ModelName),
  1707. })
  1708. }
  1709. // OllamaVersion 获取 Ollama 服务版本信息
  1710. func OllamaVersion(c *gin.Context) {
  1711. id, err := strconv.Atoi(c.Param("id"))
  1712. if err != nil {
  1713. c.JSON(http.StatusBadRequest, gin.H{
  1714. "success": false,
  1715. "message": "Invalid channel id",
  1716. })
  1717. return
  1718. }
  1719. channel, err := model.GetChannelById(id, true)
  1720. if err != nil {
  1721. c.JSON(http.StatusNotFound, gin.H{
  1722. "success": false,
  1723. "message": "Channel not found",
  1724. })
  1725. return
  1726. }
  1727. if channel.Type != constant.ChannelTypeOllama {
  1728. c.JSON(http.StatusBadRequest, gin.H{
  1729. "success": false,
  1730. "message": "This operation is only supported for Ollama channels",
  1731. })
  1732. return
  1733. }
  1734. baseURL := constant.ChannelBaseURLs[channel.Type]
  1735. if channel.GetBaseURL() != "" {
  1736. baseURL = channel.GetBaseURL()
  1737. }
  1738. key := strings.Split(channel.Key, "\n")[0]
  1739. version, err := ollama.FetchOllamaVersion(baseURL, key)
  1740. if err != nil {
  1741. c.JSON(http.StatusOK, gin.H{
  1742. "success": false,
  1743. "message": fmt.Sprintf("获取Ollama版本失败: %s", err.Error()),
  1744. })
  1745. return
  1746. }
  1747. c.JSON(http.StatusOK, gin.H{
  1748. "success": true,
  1749. "data": gin.H{
  1750. "version": version,
  1751. },
  1752. })
  1753. }