// relay-tencent.go (7.0 KB)
  1. package tencent
  2. import (
  3. "bufio"
  4. "crypto/hmac"
  5. "crypto/sha256"
  6. "encoding/hex"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "net/http"
  12. "strconv"
  13. "strings"
  14. "time"
  15. "github.com/QuantumNous/new-api/common"
  16. "github.com/QuantumNous/new-api/constant"
  17. "github.com/QuantumNous/new-api/dto"
  18. relaycommon "github.com/QuantumNous/new-api/relay/common"
  19. "github.com/QuantumNous/new-api/relay/helper"
  20. "github.com/QuantumNous/new-api/service"
  21. "github.com/QuantumNous/new-api/types"
  22. "github.com/gin-gonic/gin"
  23. )
  24. // https://cloud.tencent.com/document/product/1729/97732
  25. func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
  26. messages := make([]*TencentMessage, 0, len(request.Messages))
  27. for i := 0; i < len(request.Messages); i++ {
  28. message := request.Messages[i]
  29. messages = append(messages, &TencentMessage{
  30. Content: message.StringContent(),
  31. Role: message.Role,
  32. })
  33. }
  34. var req = TencentChatRequest{
  35. Stream: request.Stream,
  36. Messages: messages,
  37. Model: &request.Model,
  38. }
  39. if request.TopP != nil {
  40. req.TopP = request.TopP
  41. }
  42. req.Temperature = request.Temperature
  43. return &req
  44. }
  45. func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
  46. fullTextResponse := dto.OpenAITextResponse{
  47. Id: response.Id,
  48. Object: "chat.completion",
  49. Created: common.GetTimestamp(),
  50. Usage: dto.Usage{
  51. PromptTokens: response.Usage.PromptTokens,
  52. CompletionTokens: response.Usage.CompletionTokens,
  53. TotalTokens: response.Usage.TotalTokens,
  54. },
  55. }
  56. if len(response.Choices) > 0 {
  57. choice := dto.OpenAITextResponseChoice{
  58. Index: 0,
  59. Message: dto.Message{
  60. Role: "assistant",
  61. Content: response.Choices[0].Messages.Content,
  62. },
  63. FinishReason: response.Choices[0].FinishReason,
  64. }
  65. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  66. }
  67. return &fullTextResponse
  68. }
  69. func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
  70. response := dto.ChatCompletionsStreamResponse{
  71. Object: "chat.completion.chunk",
  72. Created: common.GetTimestamp(),
  73. Model: "tencent-hunyuan",
  74. }
  75. if len(TencentResponse.Choices) > 0 {
  76. var choice dto.ChatCompletionsStreamResponseChoice
  77. choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
  78. if TencentResponse.Choices[0].FinishReason == "stop" {
  79. choice.FinishReason = &constant.FinishReasonStop
  80. }
  81. response.Choices = append(response.Choices, choice)
  82. }
  83. return &response
  84. }
  85. func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  86. var responseText string
  87. scanner := bufio.NewScanner(resp.Body)
  88. scanner.Split(bufio.ScanLines)
  89. helper.SetEventStreamHeaders(c)
  90. for scanner.Scan() {
  91. data := scanner.Text()
  92. if len(data) < 5 || !strings.HasPrefix(data, "data:") {
  93. continue
  94. }
  95. data = strings.TrimPrefix(data, "data:")
  96. var tencentResponse TencentChatResponse
  97. err := common.Unmarshal([]byte(data), &tencentResponse)
  98. if err != nil {
  99. common.SysLog("error unmarshalling stream response: " + err.Error())
  100. continue
  101. }
  102. response := streamResponseTencent2OpenAI(&tencentResponse)
  103. if len(response.Choices) != 0 {
  104. responseText += response.Choices[0].Delta.GetContentString()
  105. }
  106. err = helper.ObjectData(c, response)
  107. if err != nil {
  108. common.SysLog(err.Error())
  109. }
  110. }
  111. if err := scanner.Err(); err != nil {
  112. common.SysLog("error reading stream: " + err.Error())
  113. }
  114. helper.Done(c)
  115. service.CloseResponseBodyGracefully(resp)
  116. return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()), nil
  117. }
  118. func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  119. var tencentSb TencentChatResponseSB
  120. responseBody, err := io.ReadAll(resp.Body)
  121. if err != nil {
  122. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  123. }
  124. service.CloseResponseBodyGracefully(resp)
  125. err = json.Unmarshal(responseBody, &tencentSb)
  126. if err != nil {
  127. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  128. }
  129. if tencentSb.Response.Error.Code != 0 {
  130. return nil, types.WithOpenAIError(types.OpenAIError{
  131. Message: tencentSb.Response.Error.Message,
  132. Code: tencentSb.Response.Error.Code,
  133. }, resp.StatusCode)
  134. }
  135. fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
  136. jsonResponse, err := common.Marshal(fullTextResponse)
  137. if err != nil {
  138. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  139. }
  140. c.Writer.Header().Set("Content-Type", "application/json")
  141. c.Writer.WriteHeader(resp.StatusCode)
  142. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  143. return &fullTextResponse.Usage, nil
  144. }
  145. func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
  146. parts := strings.Split(config, "|")
  147. if len(parts) != 3 {
  148. err = errors.New("invalid tencent config")
  149. return
  150. }
  151. appId, err = strconv.ParseInt(parts[0], 10, 64)
  152. secretId = parts[1]
  153. secretKey = parts[2]
  154. return
  155. }
  156. func sha256hex(s string) string {
  157. b := sha256.Sum256([]byte(s))
  158. return hex.EncodeToString(b[:])
  159. }
  160. func hmacSha256(s, key string) string {
  161. hashed := hmac.New(sha256.New, []byte(key))
  162. hashed.Write([]byte(s))
  163. return string(hashed.Sum(nil))
  164. }
  165. func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
  166. // build canonical request string
  167. host := "hunyuan.tencentcloudapi.com"
  168. httpRequestMethod := "POST"
  169. canonicalURI := "/"
  170. canonicalQueryString := ""
  171. canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
  172. "application/json", host, strings.ToLower(adaptor.Action))
  173. signedHeaders := "content-type;host;x-tc-action"
  174. payload, _ := json.Marshal(req)
  175. hashedRequestPayload := sha256hex(string(payload))
  176. canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
  177. httpRequestMethod,
  178. canonicalURI,
  179. canonicalQueryString,
  180. canonicalHeaders,
  181. signedHeaders,
  182. hashedRequestPayload)
  183. // build string to sign
  184. algorithm := "TC3-HMAC-SHA256"
  185. requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
  186. timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
  187. t := time.Unix(timestamp, 0).UTC()
  188. // must be the format 2006-01-02, ref to package time for more info
  189. date := t.Format("2006-01-02")
  190. credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
  191. hashedCanonicalRequest := sha256hex(canonicalRequest)
  192. string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
  193. algorithm,
  194. requestTimestamp,
  195. credentialScope,
  196. hashedCanonicalRequest)
  197. // sign string
  198. secretDate := hmacSha256(date, "TC3"+secKey)
  199. secretService := hmacSha256("hunyuan", secretDate)
  200. secretKey := hmacSha256("tc3_request", secretService)
  201. signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
  202. // build authorization
  203. authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
  204. algorithm,
  205. secId,
  206. credentialScope,
  207. signedHeaders,
  208. signature)
  209. return authorization
  210. }