|
|
@@ -3,8 +3,8 @@ package tencent
|
|
|
import (
|
|
|
"bufio"
|
|
|
"crypto/hmac"
|
|
|
- "crypto/sha1"
|
|
|
- "encoding/base64"
|
|
|
+ "crypto/sha256"
|
|
|
+ "encoding/hex"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
@@ -15,46 +15,28 @@ import (
|
|
|
"one-api/dto"
|
|
|
relaycommon "one-api/relay/common"
|
|
|
"one-api/service"
|
|
|
- "sort"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
// https://cloud.tencent.com/document/product/1729/97732
|
|
|
|
|
|
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
|
|
- messages := make([]TencentMessage, 0, len(request.Messages))
|
|
|
+ messages := make([]*TencentMessage, 0, len(request.Messages))
|
|
|
for i := 0; i < len(request.Messages); i++ {
|
|
|
message := request.Messages[i]
|
|
|
- if message.Role == "system" {
|
|
|
- messages = append(messages, TencentMessage{
|
|
|
- Role: "user",
|
|
|
- Content: message.StringContent(),
|
|
|
- })
|
|
|
- messages = append(messages, TencentMessage{
|
|
|
- Role: "assistant",
|
|
|
- Content: "Okay",
|
|
|
- })
|
|
|
- continue
|
|
|
- }
|
|
|
- messages = append(messages, TencentMessage{
|
|
|
+ messages = append(messages, &TencentMessage{
|
|
|
Content: message.StringContent(),
|
|
|
Role: message.Role,
|
|
|
})
|
|
|
}
|
|
|
- stream := 0
|
|
|
- if request.Stream {
|
|
|
- stream = 1
|
|
|
- }
|
|
|
return &TencentChatRequest{
|
|
|
- Timestamp: common.GetTimestamp(),
|
|
|
- Expired: common.GetTimestamp() + 24*60*60,
|
|
|
- QueryID: common.GetUUID(),
|
|
|
- Temperature: request.Temperature,
|
|
|
- TopP: request.TopP,
|
|
|
- Stream: stream,
|
|
|
+ Temperature: &request.Temperature,
|
|
|
+ TopP: &request.TopP,
|
|
|
+ Stream: &request.Stream,
|
|
|
Messages: messages,
|
|
|
- Model: request.Model,
|
|
|
+ Model: &request.Model,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -62,7 +44,11 @@ func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextRespon
|
|
|
fullTextResponse := dto.OpenAITextResponse{
|
|
|
Object: "chat.completion",
|
|
|
Created: common.GetTimestamp(),
|
|
|
- Usage: response.Usage,
|
|
|
+ Usage: dto.Usage{
|
|
|
+ PromptTokens: response.Usage.PromptTokens,
|
|
|
+ CompletionTokens: response.Usage.CompletionTokens,
|
|
|
+ TotalTokens: response.Usage.TotalTokens,
|
|
|
+ },
|
|
|
}
|
|
|
if len(response.Choices) > 0 {
|
|
|
content, _ := json.Marshal(response.Choices[0].Messages.Content)
|
|
|
@@ -99,64 +85,46 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
|
|
|
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
|
|
var responseText string
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
|
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
- if atEOF && len(data) == 0 {
|
|
|
- return 0, nil, nil
|
|
|
- }
|
|
|
- if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
|
- return i + 1, data[0:i], nil
|
|
|
+ scanner.Split(bufio.ScanLines)
|
|
|
+
|
|
|
+ service.SetEventStreamHeaders(c)
|
|
|
+
|
|
|
+ for scanner.Scan() {
|
|
|
+ data := scanner.Text()
|
|
|
+ if len(data) < 5 || !strings.HasPrefix(data, "data:") {
|
|
|
+ continue
|
|
|
}
|
|
|
- if atEOF {
|
|
|
- return len(data), data, nil
|
|
|
+ data = strings.TrimPrefix(data, "data:")
|
|
|
+
|
|
|
+ var tencentResponse TencentChatResponse
|
|
|
+ err := json.Unmarshal([]byte(data), &tencentResponse)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
+ continue
|
|
|
}
|
|
|
- return 0, nil, nil
|
|
|
- })
|
|
|
- dataChan := make(chan string)
|
|
|
- stopChan := make(chan bool)
|
|
|
- go func() {
|
|
|
- for scanner.Scan() {
|
|
|
- data := scanner.Text()
|
|
|
- if len(data) < 5 { // ignore blank line or wrong format
|
|
|
- continue
|
|
|
- }
|
|
|
- if data[:5] != "data:" {
|
|
|
- continue
|
|
|
- }
|
|
|
- data = data[5:]
|
|
|
- dataChan <- data
|
|
|
+
|
|
|
+ response := streamResponseTencent2OpenAI(&tencentResponse)
|
|
|
+ if len(response.Choices) != 0 {
|
|
|
+ responseText += response.Choices[0].Delta.GetContentString()
|
|
|
}
|
|
|
- stopChan <- true
|
|
|
- }()
|
|
|
- service.SetEventStreamHeaders(c)
|
|
|
- c.Stream(func(w io.Writer) bool {
|
|
|
- select {
|
|
|
- case data := <-dataChan:
|
|
|
- var TencentResponse TencentChatResponse
|
|
|
- err := json.Unmarshal([]byte(data), &TencentResponse)
|
|
|
- if err != nil {
|
|
|
- common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
- return true
|
|
|
- }
|
|
|
- response := streamResponseTencent2OpenAI(&TencentResponse)
|
|
|
- if len(response.Choices) != 0 {
|
|
|
- responseText += response.Choices[0].Delta.GetContentString()
|
|
|
- }
|
|
|
- jsonResponse, err := json.Marshal(response)
|
|
|
- if err != nil {
|
|
|
- common.SysError("error marshalling stream response: " + err.Error())
|
|
|
- return true
|
|
|
- }
|
|
|
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
|
- return true
|
|
|
- case <-stopChan:
|
|
|
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
|
- return false
|
|
|
+
|
|
|
+ err = service.ObjectData(c, response)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError(err.Error())
|
|
|
}
|
|
|
- })
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := scanner.Err(); err != nil {
|
|
|
+ common.SysError("error reading stream: " + err.Error())
|
|
|
+ }
|
|
|
+
|
|
|
+ service.Done(c)
|
|
|
+
|
|
|
err := resp.Body.Close()
|
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
|
}
|
|
|
+
|
|
|
return nil, responseText
|
|
|
}
|
|
|
|
|
|
@@ -206,29 +174,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func getTencentSign(req TencentChatRequest, secretKey string) string {
|
|
|
- params := make([]string, 0)
|
|
|
- params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
|
|
- params = append(params, "secret_id="+req.SecretId)
|
|
|
- params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
|
|
- params = append(params, "query_id="+req.QueryID)
|
|
|
- params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
|
|
- params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
|
|
- params = append(params, "stream="+strconv.Itoa(req.Stream))
|
|
|
- params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
|
|
-
|
|
|
- var messageStr string
|
|
|
- for _, msg := range req.Messages {
|
|
|
- messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
|
|
- }
|
|
|
- messageStr = strings.TrimSuffix(messageStr, ",")
|
|
|
- params = append(params, "messages=["+messageStr+"]")
|
|
|
-
|
|
|
- sort.Sort(sort.StringSlice(params))
|
|
|
- url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
|
|
- mac := hmac.New(sha1.New, []byte(secretKey))
|
|
|
- signURL := url
|
|
|
- mac.Write([]byte(signURL))
|
|
|
- sign := mac.Sum([]byte(nil))
|
|
|
- return base64.StdEncoding.EncodeToString(sign)
|
|
|
+func sha256hex(s string) string {
|
|
|
+ b := sha256.Sum256([]byte(s))
|
|
|
+ return hex.EncodeToString(b[:])
|
|
|
+}
|
|
|
+
|
|
|
+func hmacSha256(s, key string) string {
|
|
|
+ hashed := hmac.New(sha256.New, []byte(key))
|
|
|
+ hashed.Write([]byte(s))
|
|
|
+ return string(hashed.Sum(nil))
|
|
|
+}
|
|
|
+
|
|
|
+func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
|
|
|
+ // build canonical request string
|
|
|
+ host := "hunyuan.tencentcloudapi.com"
|
|
|
+ httpRequestMethod := "POST"
|
|
|
+ canonicalURI := "/"
|
|
|
+ canonicalQueryString := ""
|
|
|
+ canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
|
|
|
+ "application/json", host, strings.ToLower(adaptor.Action))
|
|
|
+ signedHeaders := "content-type;host;x-tc-action"
|
|
|
+ payload, _ := json.Marshal(req)
|
|
|
+ hashedRequestPayload := sha256hex(string(payload))
|
|
|
+ canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
|
|
+ httpRequestMethod,
|
|
|
+ canonicalURI,
|
|
|
+ canonicalQueryString,
|
|
|
+ canonicalHeaders,
|
|
|
+ signedHeaders,
|
|
|
+ hashedRequestPayload)
|
|
|
+ // build string to sign
|
|
|
+ algorithm := "TC3-HMAC-SHA256"
|
|
|
+ requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
|
|
|
+ timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
|
|
|
+ t := time.Unix(timestamp, 0).UTC()
|
|
|
+ // must be the format 2006-01-02, ref to package time for more info
|
|
|
+ date := t.Format("2006-01-02")
|
|
|
+ credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
|
|
|
+ hashedCanonicalRequest := sha256hex(canonicalRequest)
|
|
|
+ string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
|
|
|
+ algorithm,
|
|
|
+ requestTimestamp,
|
|
|
+ credentialScope,
|
|
|
+ hashedCanonicalRequest)
|
|
|
+
|
|
|
+ // sign string
|
|
|
+ secretDate := hmacSha256(date, "TC3"+secKey)
|
|
|
+ secretService := hmacSha256("hunyuan", secretDate)
|
|
|
+ secretKey := hmacSha256("tc3_request", secretService)
|
|
|
+ signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
|
|
|
+
|
|
|
+ // build authorization
|
|
|
+ authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
|
|
+ algorithm,
|
|
|
+ secId,
|
|
|
+ credentialScope,
|
|
|
+ signedHeaders,
|
|
|
+ signature)
|
|
|
+ return authorization
|
|
|
}
|