text.go 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package mistral
  2. import (
  3. "regexp"
  4. "github.com/QuantumNous/new-api/common"
  5. "github.com/QuantumNous/new-api/dto"
  6. )
  7. var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$")
  8. func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
  9. messages := make([]dto.Message, 0, len(request.Messages))
  10. idMap := make(map[string]string)
  11. for _, message := range request.Messages {
  12. // 1. tool_calls.id
  13. toolCalls := message.ParseToolCalls()
  14. if toolCalls != nil {
  15. for i := range toolCalls {
  16. if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) {
  17. if newId, ok := idMap[toolCalls[i].ID]; ok {
  18. toolCalls[i].ID = newId
  19. } else {
  20. newId, err := common.GenerateRandomCharsKey(9)
  21. if err == nil {
  22. idMap[toolCalls[i].ID] = newId
  23. toolCalls[i].ID = newId
  24. }
  25. }
  26. }
  27. }
  28. message.SetToolCalls(toolCalls)
  29. }
  30. // 2. tool_call_id
  31. if message.ToolCallId != "" {
  32. if newId, ok := idMap[message.ToolCallId]; ok {
  33. message.ToolCallId = newId
  34. } else {
  35. if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) {
  36. newId, err := common.GenerateRandomCharsKey(9)
  37. if err == nil {
  38. idMap[message.ToolCallId] = newId
  39. message.ToolCallId = newId
  40. }
  41. }
  42. }
  43. }
  44. mediaMessages := message.ParseContent()
  45. if message.Role == "assistant" && message.ToolCalls != nil && message.Content == "" {
  46. mediaMessages = []dto.MediaContent{}
  47. }
  48. for j, mediaMessage := range mediaMessages {
  49. if mediaMessage.Type == dto.ContentTypeImageURL {
  50. imageUrl := mediaMessage.GetImageMedia()
  51. mediaMessage.ImageUrl = imageUrl.Url
  52. mediaMessages[j] = mediaMessage
  53. }
  54. }
  55. message.SetMediaContent(mediaMessages)
  56. messages = append(messages, dto.Message{
  57. Role: message.Role,
  58. Content: message.Content,
  59. ToolCalls: message.ToolCalls,
  60. ToolCallId: message.ToolCallId,
  61. })
  62. }
  63. out := &dto.GeneralOpenAIRequest{
  64. Model: request.Model,
  65. Stream: request.Stream,
  66. Messages: messages,
  67. Temperature: request.Temperature,
  68. TopP: request.TopP,
  69. Tools: request.Tools,
  70. ToolChoice: request.ToolChoice,
  71. }
  72. if request.MaxTokens != nil || request.MaxCompletionTokens != nil {
  73. maxTokens := request.GetMaxTokens()
  74. out.MaxTokens = &maxTokens
  75. }
  76. return out
  77. }