|
|
@@ -1,12 +1,17 @@
|
|
|
package controller
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
+ "encoding/json"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"one-api/model"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
func GetAllChannels(c *gin.Context) {
|
|
|
@@ -153,3 +158,97 @@ func UpdateChannel(c *gin.Context) {
|
|
|
})
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
+func testChannel(channel *model.Channel, request *ChatRequest) error {
|
|
|
+ if request.Model == "" {
|
|
|
+ request.Model = "gpt-3.5-turbo"
|
|
|
+ if channel.Type == common.ChannelTypeAzure {
|
|
|
+ request.Model = "gpt-35-turbo"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ requestURL := common.ChannelBaseURLs[channel.Type]
|
|
|
+ if channel.Type == common.ChannelTypeAzure {
|
|
|
+ requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
|
|
|
+ } else {
|
|
|
+ if channel.Type == common.ChannelTypeCustom {
|
|
|
+ requestURL = channel.BaseURL
|
|
|
+ }
|
|
|
+ requestURL += "/v1/chat/completions"
|
|
|
+ }
|
|
|
+
|
|
|
+ jsonData, err := json.Marshal(request)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if channel.Type == common.ChannelTypeAzure {
|
|
|
+ req.Header.Set("api-key", channel.Key)
|
|
|
+ } else {
|
|
|
+ req.Header.Set("Authorization", "Bearer "+channel.Key)
|
|
|
+ }
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
+ client := &http.Client{}
|
|
|
+ resp, err := client.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+ var response TextResponse
|
|
|
+ err = json.NewDecoder(resp.Body).Decode(&response)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if response.Error.Type != "" {
|
|
|
+ return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func TestChannel(c *gin.Context) {
|
|
|
+ id, err := strconv.Atoi(c.Param("id"))
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+ channel, err := model.GetChannelById(id, true)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+ model_ := c.Query("model")
|
|
|
+ chatRequest := &ChatRequest{
|
|
|
+ Model: model_,
|
|
|
+ }
|
|
|
+ testMessage := Message{
|
|
|
+ Role: "user",
|
|
|
+ Content: "echo hi",
|
|
|
+ }
|
|
|
+ chatRequest.Messages = append(chatRequest.Messages, testMessage)
|
|
|
+ tik := time.Now()
|
|
|
+ err = testChannel(channel, chatRequest)
|
|
|
+ tok := time.Now()
|
|
|
+ consumedTime := float64(tok.Sub(tik).Milliseconds()) / 1000.0
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ "time": consumedTime,
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": true,
|
|
|
+ "message": "",
|
|
|
+ "time": consumedTime,
|
|
|
+ })
|
|
|
+ return
|
|
|
+}
|