
Merge branch 'Calcium-Ion:main' into main

GuoRuqiang committed 6bbf1d4843, 1 year ago

+ 21 - 12
common/model-ratio.go

@@ -42,6 +42,10 @@ var defaultModelRatio = map[string]float64{
 	"gpt-4o":                    2.5,  // $0.01 / 1K tokens
 	"gpt-4o-2024-05-13":         2.5,  // $0.01 / 1K tokens
 	"gpt-4o-2024-08-06":         1.25, // $0.01 / 1K tokens
+	"o1-preview":                7.5,
+	"o1-preview-2024-09-12":     7.5,
+	"o1-mini":                   1.5,
+	"o1-mini-2024-09-12":        1.5,
 	"gpt-4o-mini":               0.075,
 	"gpt-4o-mini-2024-07-18":    0.075,
 	"gpt-4-turbo":               5,    // $0.01 / 1K tokens
@@ -106,8 +110,10 @@ var defaultModelRatio = map[string]float64{
 	"gemini-pro-vision":              1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"gemini-1.0-pro-vision-001":      1,
 	"gemini-1.0-pro-001":             1,
-	"gemini-1.5-pro-latest":          1,
+	"gemini-1.5-pro-latest":          1.75, // $3.5 / 1M tokens
+	"gemini-1.5-pro-exp-0827":        1.75, // $3.5 / 1M tokens
 	"gemini-1.5-flash-latest":        1,
+	"gemini-1.5-flash-exp-0827":      1,
 	"gemini-1.0-pro-latest":          1,
 	"gemini-1.0-pro-vision-latest":   1,
 	"gemini-ultra":                   1,
@@ -329,17 +335,6 @@ func GetCompletionRatio(name string) float64 {
 	if strings.HasPrefix(name, "gpt-4o-gizmo") {
 		name = "gpt-4o-gizmo-*"
 	}
-	if strings.HasPrefix(name, "gpt-3.5") {
-		if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
-			// https://openai.com/blog/new-embedding-models-and-api-updates
-			// Updated GPT-3.5 Turbo model and lower pricing
-			return 3
-		}
-		if strings.HasSuffix(name, "1106") {
-			return 2
-		}
-		return 4.0 / 3.0
-	}
 	if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
 		if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
 			return 3
@@ -352,6 +347,9 @@ func GetCompletionRatio(name string) float64 {
 		}
 		return 2
 	}
+	if strings.HasPrefix(name, "o1-") {
+		return 4
+	}
 	if name == "chatgpt-4o-latest" {
 		return 3
 	}
@@ -362,6 +360,17 @@ func GetCompletionRatio(name string) float64 {
 	} else if strings.Contains(name, "claude-3") {
 		return 5
 	}
+	if strings.HasPrefix(name, "gpt-3.5") {
+		if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
+			// https://openai.com/blog/new-embedding-models-and-api-updates
+			// Updated GPT-3.5 Turbo model and lower pricing
+			return 3
+		}
+		if strings.HasSuffix(name, "1106") {
+			return 2
+		}
+		return 4.0 / 3.0
+	}
 	if strings.HasPrefix(name, "mistral-") {
 		return 3
 	}
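
A minimal standalone sketch (assuming nothing beyond the hunk above) of the reordered logic in GetCompletionRatio: the prefix checks run top to bottom, so the new o1- branch returns 4 before the relocated gpt-3.5 block is ever consulted.

package main

import (
	"fmt"
	"strings"
)

// completionRatio approximates the reordered branches above;
// it is not the full function from common/model-ratio.go.
func completionRatio(name string) float64 {
	if strings.HasPrefix(name, "o1-") {
		return 4 // new: o1 models bill completions at 4x the prompt ratio
	}
	if strings.HasPrefix(name, "gpt-3.5") {
		// moved below the gpt-4/o1/claude checks; behavior unchanged
		if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
			return 3
		}
		if strings.HasSuffix(name, "1106") {
			return 2
		}
		return 4.0 / 3.0
	}
	return 1 // assumed default for everything not shown here
}

func main() {
	fmt.Println(completionRatio("o1-mini"))       // 4
	fmt.Println(completionRatio("gpt-3.5-turbo")) // 3
}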

+ 10 - 8
constant/env.go

@@ -20,14 +20,16 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
 var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
 
 var GeminiModelMap = map[string]string{
-	"gemini-1.5-pro-latest":   "v1beta",
-	"gemini-1.5-pro-001":      "v1beta",
-	"gemini-1.5-pro":          "v1beta",
-	"gemini-1.5-pro-exp-0801": "v1beta",
-	"gemini-1.5-flash-latest": "v1beta",
-	"gemini-1.5-flash-001":    "v1beta",
-	"gemini-1.5-flash":        "v1beta",
-	"gemini-ultra":            "v1beta",
+	"gemini-1.5-pro-latest":     "v1beta",
+	"gemini-1.5-pro-001":        "v1beta",
+	"gemini-1.5-pro":            "v1beta",
+	"gemini-1.5-pro-exp-0801":   "v1beta",
+	"gemini-1.5-pro-exp-0827":   "v1beta",
+	"gemini-1.5-flash-latest":   "v1beta",
+	"gemini-1.5-flash-exp-0827": "v1beta",
+	"gemini-1.5-flash-001":      "v1beta",
+	"gemini-1.5-flash":          "v1beta",
+	"gemini-ultra":              "v1beta",
 }
 
 func InitEnv() {
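
For illustration only, a sketch of how a map like GeminiModelMap is typically consulted when building the upstream URL; the "v1" fallback is an assumption, not code from this repository.

package main

import "fmt"

var geminiModelMap = map[string]string{
	"gemini-1.5-pro-exp-0827":   "v1beta",
	"gemini-1.5-flash-exp-0827": "v1beta",
}

// apiVersion returns the mapped API version, falling back to a
// hypothetical default when the model is not listed.
func apiVersion(model string) string {
	if v, ok := geminiModelMap[model]; ok {
		return v
	}
	return "v1" // assumed default
}

func main() {
	fmt.Println(apiVersion("gemini-1.5-pro-exp-0827")) // v1beta
	fmt.Println(apiVersion("gemini-1.0-pro"))          // v1
}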

+ 17 - 16
controller/channel-test.go

@@ -20,6 +20,7 @@ import (
 	"one-api/relay/constant"
 	"one-api/service"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 	}
 
-	request := buildTestRequest()
-	request.Model = testModel
+	request := buildTestRequest(testModel)
 	meta.UpstreamModelName = testModel
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
 
@@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	return nil, nil
 }
 
-func buildTestRequest() *dto.GeneralOpenAIRequest {
+func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 	testRequest := &dto.GeneralOpenAIRequest{
-		Model:     "", // this will be set later
-		MaxTokens: 1,
-		Stream:    false,
+		Model:  "", // this will be set later
+		Stream: false,
+	}
+	if strings.HasPrefix(model, "o1-") {
+		testRequest.MaxCompletionTokens = 1
+	} else {
+		testRequest.MaxTokens = 1
 	}
 	content, _ := json.Marshal("hi")
 	testMessage := dto.Message{
 		Role:    "user",
 		Content: content,
 	}
+	testRequest.Model = model
 	testRequest.Messages = append(testRequest.Messages, testMessage)
 	return testRequest
 }
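
A self-contained sketch of the new branch in buildTestRequest, using a trimmed stand-in for dto.GeneralOpenAIRequest: the o1 model family accepts max_completion_tokens in place of max_tokens, and the omitempty tags keep the unused field off the wire.

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// trimmed stand-in for dto.GeneralOpenAIRequest
type testRequest struct {
	Model               string `json:"model,omitempty"`
	MaxTokens           uint   `json:"max_tokens,omitempty"`
	MaxCompletionTokens uint   `json:"max_completion_tokens,omitempty"`
}

func build(model string) testRequest {
	r := testRequest{Model: model}
	if strings.HasPrefix(model, "o1-") {
		r.MaxCompletionTokens = 1 // o1 endpoints reject max_tokens
	} else {
		r.MaxTokens = 1
	}
	return r
}

func main() {
	a, _ := json.Marshal(build("o1-mini"))
	b, _ := json.Marshal(build("gpt-4o"))
	fmt.Println(string(a)) // {"model":"o1-mini","max_completion_tokens":1}
	fmt.Println(string(b)) // {"model":"gpt-4o","max_tokens":1}
}
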
@@ -226,26 +231,22 @@ func testAllChannels(notify bool) error {
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
 
-			ban := false
-			if milliseconds > disableThreshold {
-				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
-				ban = true
-			}
+			shouldBanChannel := false
 
 			// request error disables the channel
 			if openaiWithStatusErr != nil {
 				oaiErr := openaiWithStatusErr.Error
 				err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
-				ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
+				shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
 			}
 
-			// parse *int to bool
-			if !channel.GetAutoBan() {
-				ban = false
+			if milliseconds > disableThreshold {
+				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+				shouldBanChannel = true
 			}
 
 			// disable channel
-			if ban && isChannelEnabled {
+			if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
 				service.DisableChannel(channel.Id, channel.Name, err.Error())
 			}
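
A compact sketch of the reworked decision above: auto-ban is now part of the final predicate instead of a flag that silently resets the ban verdict.

package main

import "fmt"

// shouldDisable mirrors the condition in testAllChannels: disable only
// when the channel is enabled, a ban condition fired, and auto-ban is on.
func shouldDisable(isChannelEnabled, shouldBan, autoBan bool) bool {
	return isChannelEnabled && shouldBan && autoBan
}

func main() {
	fmt.Println(shouldDisable(true, true, false)) // false: auto-ban off
	fmt.Println(shouldDisable(true, true, true))  // true: channel disabled
}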
 

+ 26 - 25
dto/text_request.go

@@ -7,31 +7,32 @@ type ResponseFormat struct {
 }
 
 type GeneralOpenAIRequest struct {
-	Model            string         `json:"model,omitempty"`
-	Messages         []Message      `json:"messages,omitempty"`
-	Prompt           any            `json:"prompt,omitempty"`
-	Stream           bool           `json:"stream,omitempty"`
-	StreamOptions    *StreamOptions `json:"stream_options,omitempty"`
-	MaxTokens        uint           `json:"max_tokens,omitempty"`
-	Temperature      float64        `json:"temperature,omitempty"`
-	TopP             float64        `json:"top_p,omitempty"`
-	TopK             int            `json:"top_k,omitempty"`
-	Stop             any            `json:"stop,omitempty"`
-	N                int            `json:"n,omitempty"`
-	Input            any            `json:"input,omitempty"`
-	Instruction      string         `json:"instruction,omitempty"`
-	Size             string         `json:"size,omitempty"`
-	Functions        any            `json:"functions,omitempty"`
-	FrequencyPenalty float64        `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64        `json:"presence_penalty,omitempty"`
-	ResponseFormat   any            `json:"response_format,omitempty"`
-	Seed             float64        `json:"seed,omitempty"`
-	Tools            []ToolCall     `json:"tools,omitempty"`
-	ToolChoice       any            `json:"tool_choice,omitempty"`
-	User             string         `json:"user,omitempty"`
-	LogProbs         bool           `json:"logprobs,omitempty"`
-	TopLogProbs      int            `json:"top_logprobs,omitempty"`
-	Dimensions       int            `json:"dimensions,omitempty"`
+	Model               string         `json:"model,omitempty"`
+	Messages            []Message      `json:"messages,omitempty"`
+	Prompt              any            `json:"prompt,omitempty"`
+	Stream              bool           `json:"stream,omitempty"`
+	StreamOptions       *StreamOptions `json:"stream_options,omitempty"`
+	MaxTokens           uint           `json:"max_tokens,omitempty"`
+	MaxCompletionTokens uint           `json:"max_completion_tokens,omitempty"`
+	Temperature         float64        `json:"temperature,omitempty"`
+	TopP                float64        `json:"top_p,omitempty"`
+	TopK                int            `json:"top_k,omitempty"`
+	Stop                any            `json:"stop,omitempty"`
+	N                   int            `json:"n,omitempty"`
+	Input               any            `json:"input,omitempty"`
+	Instruction         string         `json:"instruction,omitempty"`
+	Size                string         `json:"size,omitempty"`
+	Functions           any            `json:"functions,omitempty"`
+	FrequencyPenalty    float64        `json:"frequency_penalty,omitempty"`
+	PresencePenalty     float64        `json:"presence_penalty,omitempty"`
+	ResponseFormat      any            `json:"response_format,omitempty"`
+	Seed                float64        `json:"seed,omitempty"`
+	Tools               []ToolCall     `json:"tools,omitempty"`
+	ToolChoice          any            `json:"tool_choice,omitempty"`
+	User                string         `json:"user,omitempty"`
+	LogProbs            bool           `json:"logprobs,omitempty"`
+	TopLogProbs         int            `json:"top_logprobs,omitempty"`
+	Dimensions          int            `json:"dimensions,omitempty"`
 }
 
 type OpenAITools struct {
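
A brief sketch showing that the new MaxCompletionTokens field lets an o1-style request round-trip through JSON instead of being dropped; the struct is a minimal slice of GeneralOpenAIRequest.

package main

import (
	"encoding/json"
	"fmt"
)

type generalRequest struct {
	Model               string `json:"model,omitempty"`
	MaxTokens           uint   `json:"max_tokens,omitempty"`
	MaxCompletionTokens uint   `json:"max_completion_tokens,omitempty"`
}

func main() {
	in := []byte(`{"model":"o1-preview","max_completion_tokens":256}`)
	var req generalRequest
	_ = json.Unmarshal(in, &req)
	out, _ := json.Marshal(req)
	fmt.Println(string(out)) // {"model":"o1-preview","max_completion_tokens":256}
}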

+ 1 - 0
dto/text_response.go

@@ -34,6 +34,7 @@ type OpenAITextResponseChoice struct {
 
 type OpenAITextResponse struct {
 	Id      string                     `json:"id"`
+	Model   string                     `json:"model"`
 	Object  string                     `json:"object"`
 	Created int64                      `json:"created"`
 	Choices []OpenAITextResponseChoice `json:"choices"`

+ 65 - 10
relay/channel/claude/relay-claude.go

@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"encoding/json"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -12,6 +11,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 func stopReasonClaude2OpenAI(reason string) string {
@@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		}
 	}
 	formatMessages := make([]dto.Message, 0)
-	var lastMessage *dto.Message
+	lastMessage := dto.Message{
+		Role: "tool",
+	}
 	for i, message := range textRequest.Messages {
-		//if message.Role == "system" {
-		//	if i != 0 {
-		//		message.Role = "user"
-		//	}
-		//}
 		if message.Role == "" {
 			textRequest.Messages[i].Role = "user"
 		}
@@ -122,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			Role:    message.Role,
 			Content: message.Content,
 		}
-		if lastMessage != nil && lastMessage.Role == message.Role {
+		if message.Role == "tool" {
+			fmtMessage.ToolCallId = message.ToolCallId
+		}
+		if message.Role == "assistant" && message.ToolCalls != nil {
+			fmtMessage.ToolCalls = message.ToolCalls
+		}
+		if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
 			if lastMessage.IsStringContent() && message.IsStringContent() {
 				content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
 				fmtMessage.Content = content
@@ -135,7 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			fmtMessage.Content = content
 		}
 		formatMessages = append(formatMessages, fmtMessage)
-		lastMessage = &textRequest.Messages[i]
+		lastMessage = fmtMessage
 	}
 
 	claudeMessages := make([]ClaudeMessage, 0)
@@ -174,7 +178,35 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			claudeMessage := ClaudeMessage{
 				Role: message.Role,
 			}
-			if message.IsStringContent() {
+			if message.Role == "tool" {
+				if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
+					lastMessage := claudeMessages[len(claudeMessages)-1]
+					if content, ok := lastMessage.Content.(string); ok {
+						lastMessage.Content = []ClaudeMediaMessage{
+							{
+								Type: "text",
+								Text: content,
+							},
+						}
+					}
+					lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
+						Type:      "tool_result",
+						ToolUseId: message.ToolCallId,
+						Content:   message.StringContent(),
+					})
+					claudeMessages[len(claudeMessages)-1] = lastMessage
+					continue
+				} else {
+					claudeMessage.Role = "user"
+					claudeMessage.Content = []ClaudeMediaMessage{
+						{
+							Type:      "tool_result",
+							ToolUseId: message.ToolCallId,
+							Content:   message.StringContent(),
+						},
+					}
+				}
+			} else if message.IsStringContent() && message.ToolCalls == nil {
 				claudeMessage.Content = message.StringContent()
 			} else {
 				claudeMediaMessages := make([]ClaudeMediaMessage, 0)
@@ -207,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 					}
 					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
 				}
+				if message.ToolCalls != nil {
+					for _, tc := range message.ToolCalls.([]interface{}) {
+						toolCallJSON, _ := json.Marshal(tc)
+						var toolCall dto.ToolCall
+						err := json.Unmarshal(toolCallJSON, &toolCall)
+						if err != nil {
+							common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
+							continue
+						}
+						inputObj := make(map[string]any)
+						if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
+							common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
+							continue
+						}
+						claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
+							Type:  "tool_use",
+							Id:    toolCall.ID,
+							Name:  toolCall.Function.Name,
+							Input: inputObj,
+						})
+					}
+				}
 				claudeMessage.Content = claudeMediaMessages
 			}
 			claudeMessages = append(claudeMessages, claudeMessage)
@@ -341,6 +395,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 	if len(tools) > 0 {
 		choice.Message.ToolCalls = tools
 	}
+	fullTextResponse.Model = claudeResponse.Model
 	choices = append(choices, choice)
 	fullTextResponse.Choices = choices
 	return &fullTextResponse
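
For illustration, a simplified view of the payload produced when an OpenAI tool-role message is folded into a Claude tool_result block, as in the conversion above; the pared-down DTOs and the call_abc123 id are hypothetical.

package main

import (
	"encoding/json"
	"fmt"
)

// pared-down stand-ins for the Claude DTOs
type claudeMediaMessage struct {
	Type      string `json:"type"`
	ToolUseId string `json:"tool_use_id,omitempty"`
	Content   string `json:"content,omitempty"`
}

type claudeMessage struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}

func main() {
	// a tool-role message becomes a user message carrying a tool_result block
	msg := claudeMessage{
		Role: "user",
		Content: []claudeMediaMessage{{
			Type:      "tool_result",
			ToolUseId: "call_abc123",
			Content:   "42",
		}},
	}
	out, _ := json.Marshal(msg)
	fmt.Println(string(out))
}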

+ 1 - 1
relay/channel/cohere/dto.go

@@ -8,7 +8,7 @@ type CohereRequest struct {
 	Message     string        `json:"message"`
 	Stream      bool          `json:"stream"`
 	MaxTokens   int           `json:"max_tokens"`
-	SafetyMode  string        `json:"safety_mode"`
+	SafetyMode  string        `json:"safety_mode,omitempty"`
 }
 
 type ChatHistory struct {

+ 3 - 1
relay/channel/cohere/relay-cohere.go

@@ -22,7 +22,9 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
 		Message:     "",
 		Stream:      textRequest.Stream,
 		MaxTokens:   textRequest.GetMaxTokens(),
-		SafetyMode:  common.CohereSafetySetting,
+	}
+	if common.CohereSafetySetting != "NONE" {
+		cohereReq.SafetyMode = common.CohereSafetySetting
 	}
 	if cohereReq.MaxTokens == 0 {
 		cohereReq.MaxTokens = 4000
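
A minimal sketch of why the omitempty tag and the "NONE" guard work together: when SafetyMode is never assigned, the zero value is dropped from the serialized payload entirely.

package main

import (
	"encoding/json"
	"fmt"
)

type cohereRequest struct {
	Message    string `json:"message"`
	SafetyMode string `json:"safety_mode,omitempty"`
}

func main() {
	// SafetyMode left unset: omitempty keeps it off the wire
	out, _ := json.Marshal(cohereRequest{Message: "hi"})
	fmt.Println(string(out)) // {"message":"hi"}
}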

+ 1 - 1
relay/channel/gemini/constant.go

@@ -6,7 +6,7 @@ const (
 
 var ModelList = []string{
 	"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
-	"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
+	"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
 }
 
 var ChannelName = "google gemini"

+ 16 - 2
relay/channel/ollama/dto.go

@@ -17,11 +17,25 @@ type OllamaRequest struct {
 	PresencePenalty  float64        `json:"presence_penalty,omitempty"`
 }
 
+type Options struct {
+	Seed             int     `json:"seed,omitempty"`
+	Temperature      float64 `json:"temperature,omitempty"`
+	TopK             int     `json:"top_k,omitempty"`
+	TopP             float64 `json:"top_p,omitempty"`
+	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
+	NumPredict       int     `json:"num_predict,omitempty"`
+	NumCtx           int     `json:"num_ctx,omitempty"`
+}
+
 type OllamaEmbeddingRequest struct {
-	Model  string `json:"model,omitempty"`
-	Prompt any    `json:"prompt,omitempty"`
+	Model   string   `json:"model,omitempty"`
+	Input   []string `json:"input"`
+	Options *Options `json:"options,omitempty"`
 }
 
 type OllamaEmbeddingResponse struct {
+	Error     string    `json:"error,omitempty"`
+	Model     string    `json:"model"`
 	Embedding []float64 `json:"embedding,omitempty"`
 }
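
A short sketch of the new embedding request shape: inputs are now sent as a slice instead of one space-joined prompt string. The model name here is illustrative only.

package main

import (
	"encoding/json"
	"fmt"
)

type options struct {
	Temperature float64 `json:"temperature,omitempty"`
}

type ollamaEmbeddingRequest struct {
	Model   string   `json:"model,omitempty"`
	Input   []string `json:"input"`
	Options *options `json:"options,omitempty"`
}

func main() {
	// each element of Input is embedded separately upstream
	req := ollamaEmbeddingRequest{
		Model: "nomic-embed-text", // hypothetical model name
		Input: []string{"hello", "world"},
	}
	out, _ := json.Marshal(req)
	fmt.Println(string(out))
}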

+ 12 - 3
relay/channel/ollama/relay-ollama.go

@@ -9,7 +9,6 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/service"
-	"strings"
 )
 
 func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 
 func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
 	return &OllamaEmbeddingRequest{
-		Model:  request.Model,
-		Prompt: strings.Join(request.ParseInput(), " "),
+		Model: request.Model,
+		Input: request.ParseInput(),
+		Options: &Options{
+			Seed:             int(request.Seed),
+			Temperature:      request.Temperature,
+			TopP:             request.TopP,
+			FrequencyPenalty: request.FrequencyPenalty,
+			PresencePenalty:  request.PresencePenalty,
+		},
 	}
 }
 
@@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
+	if ollamaEmbeddingResponse.Error != "" {
+		return service.OpenAIErrorWrapper(errors.New(ollamaEmbeddingResponse.Error), "ollama_error", resp.StatusCode), nil
+	}
 	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
 	data = append(data, dto.OpenAIEmbeddingResponseItem{
 		Embedding: ollamaEmbeddingResponse.Embedding,

+ 6 - 0
relay/channel/openai/adaptor.go

@@ -78,6 +78,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	if info.ChannelType != common.ChannelTypeOpenAI {
 		request.StreamOptions = nil
 	}
+	if strings.HasPrefix(request.Model, "o1-") {
+		if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
+			request.MaxCompletionTokens = request.MaxTokens
+			request.MaxTokens = 0
+		}
+	}
 	return request, nil
 }
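
A standalone sketch of the adaptor remap above, assuming a pared-down request type: for o1-* models a caller-supplied max_tokens is carried over to max_completion_tokens and then cleared, so the upstream API accepts the request.

package main

import (
	"fmt"
	"strings"
)

type request struct {
	Model               string
	MaxTokens           uint
	MaxCompletionTokens uint
}

// remapO1Tokens mirrors ConvertRequest: only o1-* models are touched,
// and an explicitly supplied max_completion_tokens always wins.
func remapO1Tokens(r *request) {
	if strings.HasPrefix(r.Model, "o1-") {
		if r.MaxCompletionTokens == 0 && r.MaxTokens != 0 {
			r.MaxCompletionTokens = r.MaxTokens
			r.MaxTokens = 0
		}
	}
}

func main() {
	r := request{Model: "o1-preview", MaxTokens: 512}
	remapO1Tokens(&r)
	fmt.Println(r.MaxTokens, r.MaxCompletionTokens) // 0 512
}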
 

+ 2 - 0
relay/channel/openai/constant.go

@@ -11,6 +11,8 @@ var ModelList = []string{
 	"chatgpt-4o-latest",
 	"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
+	"o1-preview", "o1-preview-2024-09-12",
+	"o1-mini", "o1-mini-2024-09-12",
 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 	"text-curie-001", "text-babbage-001", "text-ada-001",
 	"text-moderation-latest", "text-moderation-stable",

+ 4 - 6
web/src/components/Footer.js

@@ -59,12 +59,10 @@ const Footer = () => {
     <Layout>
       <Layout.Content style={{ textAlign: 'center' }}>
         {footer ? (
-          <Tooltip content={defaultFooter}>
-            <div
-              className='custom-footer'
-              dangerouslySetInnerHTML={{ __html: footer }}
-            ></div>
-          </Tooltip>
+          <div
+            className='custom-footer'
+            dangerouslySetInnerHTML={{ __html: footer }}
+          ></div>
         ) : (
           defaultFooter
         )}