Przeglądaj źródła

Merge pull request #3505 from seefs001/fix/claude-media-support

fix: add basic inline file support for Claude relay
Calcium-Ion 1 miesiąc temu
rodzic
commit
41cd051ea9

+ 84 - 10
relay/channel/claude/relay-claude.go

@@ -1,10 +1,12 @@
 package claude
 
 import (
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
+	"path/filepath"
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
@@ -44,6 +46,61 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
 	}
 }
 
+func createClaudeFileSource(file *dto.MessageFile) *types.FileSource {
+	if file == nil || file.FileData == "" {
+		return nil
+	}
+	if strings.HasPrefix(file.FileData, "http://") || strings.HasPrefix(file.FileData, "https://") {
+		return types.NewURLFileSource(file.FileData)
+	}
+	mimeType := ""
+	if ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(file.FileName)), "."); ext != "" {
+		if detected := service.GetMimeTypeByExtension(ext); detected != "application/octet-stream" {
+			mimeType = detected
+		}
+	}
+	return types.NewBase64FileSource(file.FileData, mimeType)
+}
+
+// buildClaudeFileMessage converts an OpenAI-style file part into a Claude
+// content block.
+//
+// The file is resolved through createClaudeFileSource and loaded with
+// service.GetBase64Data. The result depends on the reported media type:
+//   - "application/pdf": a "document" block with a base64 source.
+//   - "text/plain": the payload is base64-decoded and returned as a "text" block.
+//   - anything else: the file is skipped and a log line is written.
+//
+// It returns (nil, nil) when there is nothing to send (no usable source, or
+// an unsupported media type). Errors from loading or decoding the data are
+// wrapped and returned. c may be nil, in which case the skip message goes to
+// common.SysLog instead of the request logger.
+func buildClaudeFileMessage(c *gin.Context, file *dto.MessageFile) (*dto.ClaudeMediaMessage, error) {
+	source := createClaudeFileSource(file)
+	if source == nil {
+		return nil, nil
+	}
+	base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting document for Claude")
+	if err != nil {
+		return nil, fmt.Errorf("get file data failed: %w", err)
+	}
+	// Match on the bare media type. Whether GetBase64Data already strips
+	// parameters or normalizes case is not visible from here, so a value such
+	// as "text/plain; charset=utf-8" or "Application/PDF" must not cause a
+	// supported file to be skipped.
+	mediaType, _, _ := strings.Cut(mimeType, ";")
+	mediaType = strings.ToLower(strings.TrimSpace(mediaType))
+	switch mediaType {
+	case "application/pdf":
+		return &dto.ClaudeMediaMessage{
+			Type: "document",
+			Source: &dto.ClaudeMessageSource{
+				Type: "base64",
+				// Forward the normalized value so the upstream request
+				// carries exactly "application/pdf".
+				MediaType: mediaType,
+				Data:      base64Data,
+			},
+		}, nil
+	case "text/plain":
+		decodedData, err := base64.StdEncoding.DecodeString(base64Data)
+		if err != nil {
+			return nil, fmt.Errorf("decode text file data failed: %w", err)
+		}
+		return &dto.ClaudeMediaMessage{
+			Type: "text",
+			Text: common.GetPointer(string(decodedData)),
+		}, nil
+	default:
+		// Log the raw mimeType, not the normalized one, to aid debugging.
+		msg := fmt.Sprintf("claude: skip unsupported file content, filename=%q, mime=%q", file.FileName, mimeType)
+		if c != nil {
+			logger.LogInfo(c, msg)
+		} else {
+			common.SysLog(msg)
+		}
+		return nil, nil
+	}
+}
+
 func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
 	claudeTools := make([]any, 0, len(textRequest.Tools))
 
@@ -343,16 +400,22 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
 			} else {
 				claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
 				for _, mediaMessage := range message.ParseContent() {
-					claudeMediaMessage := dto.ClaudeMediaMessage{
-						Type: mediaMessage.Type,
-					}
-					if mediaMessage.Type == "text" {
-						claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
-					} else {
+					switch mediaMessage.Type {
+					case "text":
+						claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
+							Type: "text",
+							Text: common.GetPointer[string](mediaMessage.Text),
+						})
+					case dto.ContentTypeImageURL:
+						claudeMediaMessage := dto.ClaudeMediaMessage{
+							Type: "image",
+							Source: &dto.ClaudeMessageSource{
+								Type: "base64",
+							},
+						}
 						imageUrl := mediaMessage.GetImageMedia()
-						claudeMediaMessage.Type = "image"
-						claudeMediaMessage.Source = &dto.ClaudeMessageSource{
-							Type: "base64",
+						if imageUrl == nil {
+							continue
 						}
 						// 使用统一的文件服务获取图片数据
 						var source *types.FileSource
@@ -367,8 +430,19 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
 						}
 						claudeMediaMessage.Source.MediaType = mimeType
 						claudeMediaMessage.Source.Data = base64Data
+						claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
+					// FIXME
+					//case dto.ContentTypeFile:
+					//	claudeFileMessage, err := buildClaudeFileMessage(c, mediaMessage.GetFile())
+					//	if err != nil {
+					//		return nil, err
+					//	}
+					//	if claudeFileMessage != nil {
+					//		claudeMediaMessages = append(claudeMediaMessages, *claudeFileMessage)
+					//	}
+					default:
+						continue
 					}
-					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
 				}
 				if message.ToolCalls != nil {
 					for _, toolCall := range message.ParseToolCalls() {

+ 108 - 0
relay/channel/claude/relay_claude_test.go

@@ -1,10 +1,12 @@
 package claude
 
 import (
+	"encoding/base64"
 	"strings"
 	"testing"
 
 	"github.com/QuantumNous/new-api/dto"
+	"github.com/stretchr/testify/require"
 )
 
 func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
@@ -255,3 +257,109 @@ func TestBuildOpenAIStyleUsageFromClaudeUsagePreservesCacheCreationRemainder(t *
 		})
 	}
 }
+
+// TestRequestOpenAI2ClaudeMessage_IgnoresUnsupportedFileContent sends a user
+// message with a text part followed by a file part named "blob.bin" and
+// asserts that the converted Claude message contains only the text part.
+func TestRequestOpenAI2ClaudeMessage_IgnoresUnsupportedFileContent(t *testing.T) {
+	textPart := dto.MediaContent{
+		Type: dto.ContentTypeText,
+		Text: "see attachment",
+	}
+	// The payload is base64 for "%PDF-1.4\n", but the name ends in ".bin".
+	binaryPart := dto.MediaContent{
+		Type: dto.ContentTypeFile,
+		File: &dto.MessageFile{
+			FileName: "blob.bin",
+			FileData: "JVBERi0xLjQK",
+		},
+	}
+	request := dto.GeneralOpenAIRequest{
+		Model: "claude-3-5-sonnet",
+		Messages: []dto.Message{{
+			Role:    "user",
+			Content: []any{textPart, binaryPart},
+		}},
+	}
+
+	claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
+	require.NoError(t, err)
+	require.Len(t, claudeRequest.Messages, 1)
+
+	parts, isMediaSlice := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
+	require.True(t, isMediaSlice)
+	require.Len(t, parts, 1)
+
+	// Only the text part survives, unchanged.
+	only := parts[0]
+	require.Equal(t, "text", only.Type)
+	require.NotNil(t, only.Text)
+	require.Equal(t, "see attachment", *only.Text)
+}
+
+// TestRequestOpenAI2ClaudeMessage_SupportsPDFFileContent sends a user message
+// with a "spec.pdf" file part followed by a text part and asserts that the
+// converted Claude message holds a base64 "document" block followed by the
+// text block, in that order.
+//
+// NOTE(review): this relies on RequestOpenAI2ClaudeMessage dispatching
+// dto.ContentTypeFile parts to buildClaudeFileMessage. Confirm that case is
+// enabled; if it is commented out, the file part is dropped and the length
+// assertion below fails.
+func TestRequestOpenAI2ClaudeMessage_SupportsPDFFileContent(t *testing.T) {
+	pdfPart := dto.MediaContent{
+		Type: dto.ContentTypeFile,
+		File: &dto.MessageFile{
+			FileName: "spec.pdf",
+			FileData: "JVBERi0xLjQK",
+		},
+	}
+	promptPart := dto.MediaContent{
+		Type: dto.ContentTypeText,
+		Text: "summarize it",
+	}
+	request := dto.GeneralOpenAIRequest{
+		Model: "claude-3-5-sonnet",
+		Messages: []dto.Message{{
+			Role:    "user",
+			Content: []any{pdfPart, promptPart},
+		}},
+	}
+
+	claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
+	require.NoError(t, err)
+	require.Len(t, claudeRequest.Messages, 1)
+
+	parts, isMediaSlice := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
+	require.True(t, isMediaSlice)
+	require.Len(t, parts, 2)
+
+	// First block: the PDF, passed through as base64 without re-encoding.
+	document := parts[0]
+	require.Equal(t, "document", document.Type)
+	require.NotNil(t, document.Source)
+	require.Equal(t, "base64", document.Source.Type)
+	require.Equal(t, "application/pdf", document.Source.MediaType)
+	require.Equal(t, "JVBERi0xLjQK", document.Source.Data)
+
+	// Second block: the original prompt text.
+	prompt := parts[1]
+	require.Equal(t, "text", prompt.Type)
+	require.NotNil(t, prompt.Text)
+	require.Equal(t, "summarize it", *prompt.Text)
+}
+
+func TestRequestOpenAI2ClaudeMessage_ConvertsTextFileContentToText(t *testing.T) {
+	request := dto.GeneralOpenAIRequest{
+		Model: "claude-3-5-sonnet",
+		Messages: []dto.Message{
+			{
+				Role: "user",
+				Content: []any{
+					dto.MediaContent{
+						Type: dto.ContentTypeFile,
+						File: &dto.MessageFile{
+							FileName: "notes.txt",
+							FileData: base64.StdEncoding.EncodeToString([]byte("alpha\nbeta")),
+						},
+					},
+				},
+			},
+		},
+	}
+
+	claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
+	require.NoError(t, err)
+	require.Len(t, claudeRequest.Messages, 1)
+
+	content, ok := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
+	require.True(t, ok)
+	require.Len(t, content, 1)
+	require.Equal(t, "text", content[0].Type)
+	require.NotNil(t, content[0].Text)
+	require.Equal(t, "alpha\nbeta", *content[0].Text)
+}