Ver código fonte

Merge pull request #271 from p3psi-boo/main

feat: 添加同步上游模型列表按钮
Calcium-Ion 1 ano atrás
pai
commit
c86bff38ac
3 arquivos alterados com 104 adições e 0 exclusões
  1. 89 0
      controller/channel.go
  2. 2 0
      router/api-router.go
  3. 13 0
      web/src/pages/Channel/EditChannel.js

+ 89 - 0
controller/channel.go

@@ -1,6 +1,8 @@
 package controller
 package controller
 
 
 import (
 import (
+	"encoding/json"
+	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
@@ -9,6 +11,34 @@ import (
 	"strings"
 	"strings"
 )
 )
 
 
+type OpenAIModel struct {
+	ID         string `json:"id"`
+	Object     string `json:"object"`
+	Created    int64  `json:"created"`
+	OwnedBy    string `json:"owned_by"`
+	Permission []struct {
+		ID                 string `json:"id"`
+		Object             string `json:"object"`
+		Created            int64  `json:"created"`
+		AllowCreateEngine  bool   `json:"allow_create_engine"`
+		AllowSampling      bool   `json:"allow_sampling"`
+		AllowLogprobs      bool   `json:"allow_logprobs"`
+		AllowSearchIndices bool   `json:"allow_search_indices"`
+		AllowView          bool   `json:"allow_view"`
+		AllowFineTuning    bool   `json:"allow_fine_tuning"`
+		Organization       string `json:"organization"`
+		Group              string `json:"group"`
+		IsBlocking         bool   `json:"is_blocking"`
+	} `json:"permission"`
+	Root   string `json:"root"`
+	Parent string `json:"parent"`
+}
+
+type OpenAIModelsResponse struct {
+	Data    []OpenAIModel `json:"data"`
+	Success bool          `json:"success"`
+}
+
 func GetAllChannels(c *gin.Context) {
 func GetAllChannels(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	p, _ := strconv.Atoi(c.Query("p"))
 	pageSize, _ := strconv.Atoi(c.Query("page_size"))
 	pageSize, _ := strconv.Atoi(c.Query("page_size"))
@@ -35,6 +65,65 @@ func GetAllChannels(c *gin.Context) {
 	return
 	return
 }
 }
 
 
+func FetchUpstreamModels(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	channel, err := model.GetChannelById(id, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	if channel.Type != common.ChannelTypeOpenAI {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "仅支持 OpenAI 类型渠道",
+		})
+		return
+	}
+	url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+	}
+	result := OpenAIModelsResponse{}
+	err = json.Unmarshal(body, &result)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+	}
+	if !result.Success {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "上游返回错误",
+		})
+	}
+
+	var ids []string
+	for _, model := range result.Data {
+		ids = append(ids, model.ID)
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    ids,
+	})
+}
+
 func FixChannelsAbilities(c *gin.Context) {
 func FixChannelsAbilities(c *gin.Context) {
 	count, err := model.FixAbility()
 	count, err := model.FixAbility()
 	if err != nil {
 	if err != nil {

+ 2 - 0
router/api-router.go

@@ -90,6 +90,8 @@ func SetApiRouter(router *gin.Engine) {
 			channelRoute.DELETE("/:id", controller.DeleteChannel)
 			channelRoute.DELETE("/:id", controller.DeleteChannel)
 			channelRoute.POST("/batch", controller.DeleteChannelBatch)
 			channelRoute.POST("/batch", controller.DeleteChannelBatch)
 			channelRoute.POST("/fix", controller.FixChannelsAbilities)
 			channelRoute.POST("/fix", controller.FixChannelsAbilities)
+			channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
+
 		}
 		}
 		tokenRoute := apiRouter.Group("/token")
 		tokenRoute := apiRouter.Group("/token")
 		tokenRoute.Use(middleware.UserAuth())
 		tokenRoute.Use(middleware.UserAuth())

+ 13 - 0
web/src/pages/Channel/EditChannel.js

@@ -15,6 +15,7 @@ import {
   Space,
   Space,
   Spin,
   Spin,
   Button,
   Button,
+  Tooltip,
   Input,
   Input,
   Typography,
   Typography,
   Select,
   Select,
@@ -24,6 +25,7 @@ import {
 } from '@douyinfe/semi-ui';
 } from '@douyinfe/semi-ui';
 import { Divider } from 'semantic-ui-react';
 import { Divider } from 'semantic-ui-react';
 import { getChannelModels, loadChannelModels } from '../../components/utils.js';
 import { getChannelModels, loadChannelModels } from '../../components/utils.js';
+import axios from 'axios';
 
 
 const MODEL_MAPPING_EXAMPLE = {
 const MODEL_MAPPING_EXAMPLE = {
   'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
   'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
@@ -331,6 +333,7 @@ const EditChannel = (props) => {
     handleInputChange('models', localModels);
     handleInputChange('models', localModels);
   };
   };
 
 
+
   return (
   return (
     <>
     <>
       <SideSheet
       <SideSheet
@@ -550,6 +553,16 @@ const EditChannel = (props) => {
               >
               >
                 填入所有模型
                 填入所有模型
               </Button>
               </Button>
+              <Tooltip content={fetchButtonTips}>
+                <Button
+                  type='tertiary'
+                  onClick={() => {
+                    fetchUpstreamModelList('models');
+                  }}
+                >
+                  获取模型列表
+                </Button>
+              </Tooltip>
               <Button
               <Button
                 type='warning'
                 type='warning'
                 onClick={() => {
                 onClick={() => {