Просмотр исходного кода

feat: Improve image handling for Ollama channels

1808837298@qq.com 1 год назад
Родитель
Сommit
971aea09ee

+ 17 - 2
dto/openai_request.go

@@ -88,8 +88,10 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
 }
 
 type Message struct {
-	Role             string          `json:"role"`
-	Content          json.RawMessage `json:"content"`
+	Role    string          `json:"role"`
+	Content json.RawMessage `json:"content"`
+	// parsedContent not json field
+	parsedContent    []MediaContent
 	Name             *string         `json:"name,omitempty"`
 	Prefix           *bool           `json:"prefix,omitempty"`
 	ReasoningContent string          `json:"reasoning_content,omitempty"`
@@ -160,6 +162,11 @@ func (m *Message) SetStringContent(content string) {
 	m.Content = jsonContent
 }
 
+func (m *Message) SetMediaContent(content []MediaContent) {
+	jsonContent, _ := json.Marshal(content)
+	m.Content = jsonContent
+}
+
 func (m *Message) IsStringContent() bool {
 	var stringContent string
 	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
@@ -169,7 +176,15 @@ func (m *Message) IsStringContent() bool {
 }
 
 func (m *Message) ParseContent() []MediaContent {
+	if m.parsedContent != nil {
+		return m.parsedContent
+	}
 	var contentList []MediaContent
+	defer func() {
+		if len(contentList) > 0 {
+			m.parsedContent = contentList
+		}
+	}()
 	var stringContent string
 	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
 		contentList = append(contentList, MediaContent{

+ 1 - 1
relay/channel/ollama/adaptor.go

@@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return requestOpenAI2Ollama(*request), nil
+	return requestOpenAI2Ollama(*request)
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

+ 26 - 4
relay/channel/ollama/relay-ollama.go

@@ -9,14 +9,36 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/service"
+	"strings"
 )
 
-func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
+func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
 	messages := make([]dto.Message, 0, len(request.Messages))
 	for _, message := range request.Messages {
+		if !message.IsStringContent() {
+			mediaMessages := message.ParseContent()
+			for j, mediaMessage := range mediaMessages {
+				if mediaMessage.Type == dto.ContentTypeImageURL {
+					imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
+					// check if not base64
+					if strings.HasPrefix(imageUrl.Url, "http") {
+						fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+						if err != nil {
+							return nil, err
+						}
+						imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
+					}
+					mediaMessage.ImageUrl = imageUrl
+					mediaMessages[j] = mediaMessage
+				}
+			}
+			message.SetMediaContent(mediaMessages)
+		}
 		messages = append(messages, dto.Message{
-			Role:    message.Role,
-			Content: message.Content,
+			Role:       message.Role,
+			Content:    message.Content,
+			ToolCalls:  message.ToolCalls,
+			ToolCallId: message.ToolCallId,
 		})
 	}
 	str, ok := request.Stop.(string)
@@ -42,7 +64,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 		Prompt:           request.Prompt,
 		StreamOptions:    request.StreamOptions,
 		Suffix:           request.Suffix,
-	}
+	}, nil
 }
 
 func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {

+ 1 - 2
relay/channel/zhipu_4v/relay-zhipu_v4.go

@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
 					mediaMessages[j] = mediaMessage
 				}
 			}
-			messageRaw, _ := json.Marshal(mediaMessages)
-			message.Content = messageRaw
+			message.SetMediaContent(mediaMessages)
 		}
 		messages = append(messages, dto.Message{
 			Role:       message.Role,