websocket.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. package relay
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/gorilla/websocket"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. relaycommon "one-api/relay/common"
  11. "one-api/service"
  12. "one-api/setting"
  13. )
  14. func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
  15. relayInfo := relaycommon.GenRelayInfoWs(c, ws)
  16. // get & validate textRequest 获取并验证文本请求
  17. //realtimeEvent, err := getAndValidateWssRequest(c, ws)
  18. //if err != nil {
  19. // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
  20. // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
  21. //}
  22. // map model name
  23. modelMapping := c.GetString("model_mapping")
  24. //isModelMapped := false
  25. if modelMapping != "" && modelMapping != "{}" {
  26. modelMap := make(map[string]string)
  27. err := json.Unmarshal([]byte(modelMapping), &modelMap)
  28. if err != nil {
  29. return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
  30. }
  31. if modelMap[relayInfo.OriginModelName] != "" {
  32. relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
  33. // set upstream model name
  34. //isModelMapped = true
  35. }
  36. }
  37. //relayInfo.UpstreamModelName = textRequest.Model
  38. modelPrice, getModelPriceSuccess := setting.GetModelPrice(relayInfo.UpstreamModelName, false)
  39. groupRatio := setting.GetGroupRatio(relayInfo.Group)
  40. var preConsumedQuota int
  41. var ratio float64
  42. var modelRatio float64
  43. //err := service.SensitiveWordsCheck(textRequest)
  44. //if constant.ShouldCheckPromptSensitive() {
  45. // err = checkRequestSensitive(textRequest, relayInfo)
  46. // if err != nil {
  47. // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
  48. // }
  49. //}
  50. //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
  51. //// count messages token error 计算promptTokens错误
  52. //if err != nil {
  53. // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
  54. //}
  55. //
  56. if !getModelPriceSuccess {
  57. preConsumedTokens := common.PreConsumedQuota
  58. //if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
  59. // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
  60. //}
  61. modelRatio, _ = setting.GetModelRatio(relayInfo.UpstreamModelName)
  62. ratio = modelRatio * groupRatio
  63. preConsumedQuota = int(float64(preConsumedTokens) * ratio)
  64. } else {
  65. preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
  66. relayInfo.UsePrice = true
  67. }
  68. // pre-consume quota 预消耗配额
  69. preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
  70. if openaiErr != nil {
  71. return openaiErr
  72. }
  73. defer func() {
  74. if openaiErr != nil {
  75. returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
  76. }
  77. }()
  78. adaptor := GetAdaptor(relayInfo.ApiType)
  79. if adaptor == nil {
  80. return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
  81. }
  82. adaptor.Init(relayInfo)
  83. //var requestBody io.Reader
  84. //firstWssRequest, _ := c.Get("first_wss_request")
  85. //requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
  86. statusCodeMappingStr := c.GetString("status_code_mapping")
  87. resp, err := adaptor.DoRequest(c, relayInfo, nil)
  88. if err != nil {
  89. return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  90. }
  91. if resp != nil {
  92. relayInfo.TargetWs = resp.(*websocket.Conn)
  93. defer relayInfo.TargetWs.Close()
  94. }
  95. usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
  96. if openaiErr != nil {
  97. // reset status code 重置状态码
  98. service.ResetStatusCode(openaiErr, statusCodeMappingStr)
  99. return openaiErr
  100. }
  101. service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
  102. userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
  103. return nil
  104. }