Переглянути джерело

feat: support ollama claude format

CaIon 7 місяців тому
батько
коміт
ae0461692c

+ 1 - 1
controller/relay.go

@@ -56,7 +56,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
 		userGroup := c.GetString("group")
 		channelId := c.GetInt("channel_id")
 		other := make(map[string]interface{})
-		other["error_type"] = err.ErrorType
+		other["error_type"] = err.GetErrorType()
 		other["error_code"] = err.GetErrorCode()
 		other["status_code"] = err.StatusCode
 		other["channel_id"] = channelId

+ 21 - 5
relay/channel/ollama/adaptor.go

@@ -17,10 +17,13 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+	openaiAdaptor := openai.Adaptor{}
+	openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
+	if err != nil {
+		return nil, err
+	}
+	return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest))
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -37,6 +40,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	if info.RelayFormat == relaycommon.RelayFormatClaude {
+		return info.BaseUrl + "/v1/chat/completions", nil
+	}
 	switch info.RelayMode {
 	case relayconstant.RelayModeEmbeddings:
 		return info.BaseUrl + "/api/embed", nil
@@ -55,7 +61,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return requestOpenAI2Ollama(*request)
+	return requestOpenAI2Ollama(request)
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -85,6 +91,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 			usage, err = openai.OpenaiHandler(c, info, resp)
 		}
 	}
+	switch info.RelayMode {
+	case relayconstant.RelayModeEmbeddings:
+		usage, err = ollamaEmbeddingHandler(c, info, resp)
+	default:
+		if info.IsStream {
+			usage, err = openai.OaiStreamHandler(c, info, resp)
+		} else {
+			usage, err = openai.OpenaiHandler(c, info, resp)
+		}
+	}
 	return
 }
 

+ 5 - 5
relay/channel/ollama/relay-ollama.go

@@ -14,7 +14,7 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
+func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
 	messages := make([]dto.Message, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if !message.IsStringContent() {
@@ -92,15 +92,15 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	var ollamaEmbeddingResponse OllamaEmbeddingResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	common.CloseResponseBodyGracefully(resp)
 	err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
 	if err != nil {
-		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	if ollamaEmbeddingResponse.Error != "" {
-		return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
 	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
@@ -121,7 +121,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	}
 	doResponseBody, err := common.Marshal(embeddingResponse)
 	if err != nil {
-		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	common.IOCopyBytesGracefully(c, resp, doResponseBody)
 	return usage, nil

+ 2 - 6
service/error.go

@@ -80,10 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
 }
 
 func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
-	newApiErr = &types.NewAPIError{
-		StatusCode: resp.StatusCode,
-		ErrorType:  types.ErrorTypeOpenAIError,
-	}
+	newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
 
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
@@ -105,8 +102,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
 		// General format error (OpenAI, Anthropic, Gemini, etc.)
 		newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
 	} else {
-		newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
-		newApiErr.ErrorType = types.ErrorTypeOpenAIError
+		newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
 	}
 	return
 }

+ 24 - 10
types/error.go

@@ -75,7 +75,7 @@ const (
 type NewAPIError struct {
 	Err        error
 	RelayError any
-	ErrorType  ErrorType
+	errorType  ErrorType
 	errorCode  ErrorCode
 	StatusCode int
 }
@@ -87,6 +87,13 @@ func (e *NewAPIError) GetErrorCode() ErrorCode {
 	return e.errorCode
 }
 
+func (e *NewAPIError) GetErrorType() ErrorType {
+	if e == nil {
+		return ""
+	}
+	return e.errorType
+}
+
 func (e *NewAPIError) Error() string {
 	if e == nil {
 		return ""
@@ -103,7 +110,7 @@ func (e *NewAPIError) SetMessage(message string) {
 }
 
 func (e *NewAPIError) ToOpenAIError() OpenAIError {
-	switch e.ErrorType {
+	switch e.errorType {
 	case ErrorTypeOpenAIError:
 		if openAIError, ok := e.RelayError.(OpenAIError); ok {
 			return openAIError
@@ -120,14 +127,14 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
 	}
 	return OpenAIError{
 		Message: e.Error(),
-		Type:    string(e.ErrorType),
+		Type:    string(e.errorType),
 		Param:   "",
 		Code:    e.errorCode,
 	}
 }
 
 func (e *NewAPIError) ToClaudeError() ClaudeError {
-	switch e.ErrorType {
+	switch e.errorType {
 	case ErrorTypeOpenAIError:
 		openAIError := e.RelayError.(OpenAIError)
 		return ClaudeError{
@@ -139,7 +146,7 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
 	default:
 		return ClaudeError{
 			Message: e.Error(),
-			Type:    string(e.ErrorType),
+			Type:    string(e.errorType),
 		}
 	}
 }
@@ -148,7 +155,7 @@ func NewError(err error, errorCode ErrorCode) *NewAPIError {
 	return &NewAPIError{
 		Err:        err,
 		RelayError: nil,
-		ErrorType:  ErrorTypeNewAPIError,
+		errorType:  ErrorTypeNewAPIError,
 		StatusCode: http.StatusInternalServerError,
 		errorCode:  errorCode,
 	}
@@ -162,6 +169,13 @@ func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError
 	return WithOpenAIError(openaiError, statusCode)
 }
 
+func InitOpenAIError(errorCode ErrorCode, statusCode int) *NewAPIError {
+	openaiError := OpenAIError{
+		Type: string(errorCode),
+	}
+	return WithOpenAIError(openaiError, statusCode)
+}
+
 func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
 	return &NewAPIError{
 		Err: err,
@@ -169,7 +183,7 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *New
 			Message: err.Error(),
 			Type:    string(errorCode),
 		},
-		ErrorType:  ErrorTypeNewAPIError,
+		errorType:  ErrorTypeNewAPIError,
 		StatusCode: statusCode,
 		errorCode:  errorCode,
 	}
@@ -182,7 +196,7 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
 	}
 	return &NewAPIError{
 		RelayError: openAIError,
-		ErrorType:  ErrorTypeOpenAIError,
+		errorType:  ErrorTypeOpenAIError,
 		StatusCode: statusCode,
 		Err:        errors.New(openAIError.Message),
 		errorCode:  ErrorCode(code),
@@ -192,7 +206,7 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
 func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError {
 	return &NewAPIError{
 		RelayError: claudeError,
-		ErrorType:  ErrorTypeClaudeError,
+		errorType:  ErrorTypeClaudeError,
 		StatusCode: statusCode,
 		Err:        errors.New(claudeError.Message),
 		errorCode:  ErrorCode(claudeError.Type),
@@ -211,5 +225,5 @@ func IsLocalError(err *NewAPIError) bool {
 		return false
 	}
 
-	return err.ErrorType == ErrorTypeNewAPIError
+	return err.errorType == ErrorTypeNewAPIError
 }