websocket.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package relay
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/gorilla/websocket"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. relaycommon "one-api/relay/common"
  11. "one-api/service"
  12. "one-api/setting"
  13. "one-api/setting/operation_setting"
  14. )
  15. func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
  16. relayInfo := relaycommon.GenRelayInfoWs(c, ws)
  17. // get & validate textRequest 获取并验证文本请求
  18. //realtimeEvent, err := getAndValidateWssRequest(c, ws)
  19. //if err != nil {
  20. // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
  21. // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
  22. //}
  23. // map model name
  24. modelMapping := c.GetString("model_mapping")
  25. //isModelMapped := false
  26. if modelMapping != "" && modelMapping != "{}" {
  27. modelMap := make(map[string]string)
  28. err := json.Unmarshal([]byte(modelMapping), &modelMap)
  29. if err != nil {
  30. return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
  31. }
  32. if modelMap[relayInfo.OriginModelName] != "" {
  33. relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
  34. // set upstream model name
  35. //isModelMapped = true
  36. }
  37. }
  38. //relayInfo.UpstreamModelName = textRequest.Model
  39. modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
  40. groupRatio := setting.GetGroupRatio(relayInfo.Group)
  41. var preConsumedQuota int
  42. var ratio float64
  43. var modelRatio float64
  44. //err := service.SensitiveWordsCheck(textRequest)
  45. //if constant.ShouldCheckPromptSensitive() {
  46. // err = checkRequestSensitive(textRequest, relayInfo)
  47. // if err != nil {
  48. // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
  49. // }
  50. //}
  51. //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
  52. //// count messages token error 计算promptTokens错误
  53. //if err != nil {
  54. // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
  55. //}
  56. //
  57. if !getModelPriceSuccess {
  58. preConsumedTokens := common.PreConsumedQuota
  59. //if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
  60. // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
  61. //}
  62. modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
  63. ratio = modelRatio * groupRatio
  64. preConsumedQuota = int(float64(preConsumedTokens) * ratio)
  65. } else {
  66. preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
  67. relayInfo.UsePrice = true
  68. }
  69. // pre-consume quota 预消耗配额
  70. preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
  71. if openaiErr != nil {
  72. return openaiErr
  73. }
  74. defer func() {
  75. if openaiErr != nil {
  76. returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
  77. }
  78. }()
  79. adaptor := GetAdaptor(relayInfo.ApiType)
  80. if adaptor == nil {
  81. return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
  82. }
  83. adaptor.Init(relayInfo)
  84. //var requestBody io.Reader
  85. //firstWssRequest, _ := c.Get("first_wss_request")
  86. //requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
  87. statusCodeMappingStr := c.GetString("status_code_mapping")
  88. resp, err := adaptor.DoRequest(c, relayInfo, nil)
  89. if err != nil {
  90. return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  91. }
  92. if resp != nil {
  93. relayInfo.TargetWs = resp.(*websocket.Conn)
  94. defer relayInfo.TargetWs.Close()
  95. }
  96. usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
  97. if openaiErr != nil {
  98. // reset status code 重置状态码
  99. service.ResetStatusCode(openaiErr, statusCodeMappingStr)
  100. return openaiErr
  101. }
  102. service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
  103. userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
  104. return nil
  105. }