Преглед изворни кода

feat: able to test channels now (#59)

JustSong пре 2 година
родитељ
комит
443a22b75d
4 измењених фајлова са 133 додато и 1 уклоњено
  1. 99 0
      controller/channel.go
  2. 13 0
      controller/relay.go
  3. 1 0
      router/api-router.go
  4. 20 1
      web/src/components/ChannelsTable.js

+ 99 - 0
controller/channel.go

@@ -1,12 +1,17 @@
 package controller
 
 import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strconv"
 	"strings"
+	"time"
 )
 
 func GetAllChannels(c *gin.Context) {
@@ -153,3 +158,97 @@ func UpdateChannel(c *gin.Context) {
 	})
 	return
 }
+
+func testChannel(channel *model.Channel, request *ChatRequest) error {
+	if request.Model == "" {
+		request.Model = "gpt-3.5-turbo"
+		if channel.Type == common.ChannelTypeAzure {
+			request.Model = "gpt-35-turbo"
+		}
+	}
+	requestURL := common.ChannelBaseURLs[channel.Type]
+	if channel.Type == common.ChannelTypeAzure {
+		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
+	} else {
+		if channel.Type == common.ChannelTypeCustom {
+			requestURL = channel.BaseURL
+		}
+		requestURL += "/v1/chat/completions"
+	}
+
+	jsonData, err := json.Marshal(request)
+	if err != nil {
+		return err
+	}
+	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return err
+	}
+	if channel.Type == common.ChannelTypeAzure {
+		req.Header.Set("api-key", channel.Key)
+	} else {
+		req.Header.Set("Authorization", "Bearer "+channel.Key)
+	}
+	req.Header.Set("Content-Type", "application/json")
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	var response TextResponse
+	err = json.NewDecoder(resp.Body).Decode(&response)
+	if err != nil {
+		return err
+	}
+	if response.Error.Type != "" {
+		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
+	}
+	return nil
+}
+
+func TestChannel(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	channel, err := model.GetChannelById(id, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	model_ := c.Query("model")
+	chatRequest := &ChatRequest{
+		Model: model_,
+	}
+	testMessage := Message{
+		Role:    "user",
+		Content: "echo hi",
+	}
+	chatRequest.Messages = append(chatRequest.Messages, testMessage)
+	tik := time.Now()
+	err = testChannel(channel, chatRequest)
+	tok := time.Now()
+	consumedTime := float64(tok.Sub(tik).Milliseconds()) / 1000.0
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+			"time":    consumedTime,
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"time":    consumedTime,
+	})
+	return
+}

+ 13 - 0
controller/relay.go

@@ -19,6 +19,11 @@ type Message struct {
 	Content string `json:"content"`
 }
 
+type ChatRequest struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+}
+
 type TextRequest struct {
 	Model    string    `json:"model"`
 	Messages []Message `json:"messages"`
@@ -32,8 +37,16 @@ type Usage struct {
 	TotalTokens      int `json:"total_tokens"`
 }
 
+type OpenAIError struct {
+	Message string `json:"message"`
+	Type    string `json:"type"`
+	Param   string `json:"param"`
+	Code    string `json:"code"`
+}
+
 type TextResponse struct {
 	Usage `json:"usage"`
+	Error OpenAIError `json:"error"`
 }
 
 type StreamResponse struct {

+ 1 - 0
router/api-router.go

@@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
 			channelRoute.GET("/", controller.GetAllChannels)
 			channelRoute.GET("/search", controller.SearchChannels)
 			channelRoute.GET("/:id", controller.GetChannel)
+			channelRoute.GET("/test/:id", controller.TestChannel)
 			channelRoute.POST("/", controller.AddChannel)
 			channelRoute.PUT("/", controller.UpdateChannel)
 			channelRoute.DELETE("/:id", controller.DeleteChannel)

+ 20 - 1
web/src/components/ChannelsTable.js

@@ -1,7 +1,7 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
 import { Link } from 'react-router-dom';
-import { API, copy, showError, showSuccess, timestamp2string } from '../helpers';
+import { API, copy, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
 
 import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
 
@@ -139,6 +139,16 @@ const ChannelsTable = () => {
     setSearching(false);
   };
 
+  const testChannel = async (id, name) => {
+    const res = await API.get(`/api/channel/test/${id}/`);
+    const { success, message, time } = res.data;
+    if (success) {
+      showInfo(`通道 ${name} 测试成功,耗时 ${time} 秒。`);
+    } else {
+      showError(message);
+    }
+  }
+
   const handleKeywordChange = async (e, { value }) => {
     setSearchKeyword(value.trim());
   };
@@ -244,6 +254,15 @@ const ChannelsTable = () => {
                   <Table.Cell>{renderTimestamp(channel.accessed_time)}</Table.Cell>
                   <Table.Cell>
                     <div>
+                      <Button
+                        size={'small'}
+                        positive
+                        onClick={() => {
+                          testChannel(channel.id, channel.name);
+                        }}
+                      >
+                        测试
+                      </Button>
                       <Popup
                         trigger={
                           <Button size='small' negative>