adaptor.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package claude
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/dto"
  9. "one-api/relay/channel"
  10. relaycommon "one-api/relay/common"
  11. "strings"
  12. )
  // Request modes select which Anthropic API flavor the adaptor talks to.
  // The values start at 1 so that the zero value of an int means "unset".
  const (
  	// RequestModeCompletion selects the legacy Text Completions endpoint (/v1/complete).
  	RequestModeCompletion = iota + 1
  	// RequestModeMessage selects the Messages endpoint (/v1/messages).
  	RequestModeMessage
  )
  // Adaptor relays OpenAI-style requests to Anthropic's Claude API.
  type Adaptor struct {
  	// RequestMode is RequestModeCompletion or RequestModeMessage. Init sets it,
  	// and it then drives URL selection, request conversion and response handling.
  	RequestMode int
  }
  20. func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
  21. if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
  22. a.RequestMode = RequestModeMessage
  23. } else {
  24. a.RequestMode = RequestModeCompletion
  25. }
  26. }
  27. func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  28. if a.RequestMode == RequestModeMessage {
  29. return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
  30. } else {
  31. return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
  32. }
  33. }
  34. func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
  35. channel.SetupApiRequestHeader(info, c, req)
  36. req.Header.Set("x-api-key", info.ApiKey)
  37. anthropicVersion := c.Request.Header.Get("anthropic-version")
  38. if anthropicVersion == "" {
  39. anthropicVersion = "2023-06-01"
  40. }
  41. req.Header.Set("anthropic-version", anthropicVersion)
  42. return nil
  43. }
  44. func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
  45. if request == nil {
  46. return nil, errors.New("request is nil")
  47. }
  48. if a.RequestMode == RequestModeCompletion {
  49. return requestOpenAI2ClaudeComplete(*request), nil
  50. } else {
  51. return requestOpenAI2ClaudeMessage(*request)
  52. }
  53. }
  54. func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
  55. return channel.DoApiRequest(a, c, info, requestBody)
  56. }
  57. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
  58. if info.IsStream {
  59. err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
  60. } else {
  61. err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
  62. }
  63. return
  64. }
  65. func (a *Adaptor) GetModelList() []string {
  66. return ModelList
  67. }
  68. func (a *Adaptor) GetChannelName() string {
  69. return ChannelName
  70. }