Kaynağa Gözat

fix: try to fix tencent hunyuan #336

CalciumIon 1 yıl önce
ebeveyn
işleme
6c5b3b51b0

+ 1 - 1
controller/task.go

@@ -24,7 +24,7 @@ func UpdateTaskBulk() {
 	//imageModel := "midjourney"
 	for {
 		time.Sleep(time.Duration(15) * time.Second)
-		common.SysLog("任务进度轮询开始")
+		common.SysLog("	任务进度轮询开始")
 		ctx := context.TODO()
 		allTasks := model.GetAllUnFinishSyncTasks(500)
 		platformTask := make(map[constant.TaskPlatform][]*model.Task)

+ 14 - 6
relay/channel/tencent/adaptor.go

@@ -6,18 +6,26 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"strconv"
 	"strings"
 )
 
 type Adaptor struct {
-	Sign string
+	Sign      string
+	Action    string
+	Version   string
+	Timestamp int64
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+	a.Action = "ChatCompletions"
+	a.Version = "2023-09-01"
+	a.Timestamp = common.GetTimestamp()
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -27,7 +35,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("Authorization", a.Sign)
-	req.Header.Set("X-TC-Action", info.UpstreamModelName)
+	req.Header.Set("X-TC-Action", a.Action)
+	req.Header.Set("X-TC-Version", a.Version)
+	req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
 	return nil
 }
 
@@ -37,15 +47,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 	apiKey := c.Request.Header.Get("Authorization")
 	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-	appId, secretId, secretKey, err := parseTencentConfig(apiKey)
+	_, secretId, secretKey, err := parseTencentConfig(apiKey)
 	if err != nil {
 		return nil, err
 	}
 	tencentRequest := requestOpenAI2Tencent(*request)
-	tencentRequest.AppId = appId
-	tencentRequest.SecretId = secretId
 	// we have to calculate the sign here
-	a.Sign = getTencentSign(*tencentRequest, secretKey)
+	a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
 	return tencentRequest, nil
 }
 

+ 52 - 43
relay/channel/tencent/dto.go

@@ -1,62 +1,71 @@
 package tencent
 
-import "one-api/dto"
-
 type TencentMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
+	Role    string `json:"Role"`
+	Content string `json:"Content"`
 }
 
 type TencentChatRequest struct {
-	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID
-	SecretId string `json:"secret_id"` // 官网 SecretId
-	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
-	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
-	Timestamp int64 `json:"timestamp"`
-	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
-	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
-	Expired int64  `json:"expired"`
-	QueryID string `json:"query_id"` //请求 Id,用于问题排查
-	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
-	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
-	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
-	Temperature float64 `json:"temperature"`
-	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
-	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
-	// 建议该参数和 temperature 只设置1个,不要同时更改
-	TopP float64 `json:"top_p"`
-	// Stream 0:同步,1:流式 (默认,协议:SSE)
-	// 同步请求超时:60s,如果内容较长建议使用流式
-	Stream int `json:"stream"`
-	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
-	// 输入 content 总数最大支持 3000 token。
-	Messages []TencentMessage `json:"messages"`
-	Model    string           `json:"model"` // 模型名称
+	// 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。
+	// 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。
+	//
+	// 注意:
+	// 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。
+	Model *string `json:"Model"`
+	// 聊天上下文信息。
+	// 说明:
+	// 1. 长度最多为 40,按对话时间从旧到新在数组中排列。
+	// 2. Message.Role 可选值:system、user、assistant。
+	// 其中,system 角色可选,如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system(可选) user assistant user assistant user ...]。
+	// 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。
+	Messages []*TencentMessage `json:"Messages"`
+	// 流式调用开关。
+	// 说明:
+	// 1. 未传值时默认为非流式调用(false)。
+	// 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。
+	// 3. 非流式调用时:
+	// 调用方式与普通 HTTP 请求无异。
+	// 接口响应耗时较长,**如需更低时延建议设置为 true**。
+	// 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。
+	//
+	// 注意:
+	// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
+	Stream *bool `json:"Stream"`
+	// 说明:
+	// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
+	// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
+	// 3. 非必要不建议使用,不合理的取值会影响效果。
+	TopP *float64 `json:"TopP"`
+	// 说明:
+	// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
+	// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
+	// 3. 非必要不建议使用,不合理的取值会影响效果。
+	Temperature *float64 `json:"Temperature"`
 }
 
 type TencentError struct {
-	Code    int    `json:"code"`
-	Message string `json:"message"`
+	Code    int    `json:"Code"`
+	Message string `json:"Message"`
 }
 
 type TencentUsage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
+	PromptTokens     int `json:"PromptTokens"`
+	CompletionTokens int `json:"CompletionTokens"`
+	TotalTokens      int `json:"TotalTokens"`
 }
 
 type TencentResponseChoices struct {
-	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
-	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
-	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
+	FinishReason string         `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
+	Messages     TencentMessage `json:"Message,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
+	Delta        TencentMessage `json:"Delta,omitempty"`        // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
 }
 
 type TencentChatResponse struct {
-	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
-	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串
-	Id      string                   `json:"id,omitempty"`      // 会话 id
-	Usage   dto.Usage                `json:"usage,omitempty"`   // token 数量
-	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值
-	Note    string                   `json:"note,omitempty"`    // 注释
-	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
+	Choices []TencentResponseChoices `json:"Choices,omitempty"` // 结果
+	Created int64                    `json:"Created,omitempty"` // unix 时间戳的字符串
+	Id      string                   `json:"Id,omitempty"`      // 会话 id
+	Usage   TencentUsage             `json:"Usage,omitempty"`   // token 数量
+	Error   TencentError             `json:"Error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值
+	Note    string                   `json:"Note,omitempty"`    // 注释
+	ReqID   string                   `json:"Req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
 }

+ 104 - 103
relay/channel/tencent/relay-tencent.go

@@ -3,8 +3,8 @@ package tencent
 import (
 	"bufio"
 	"crypto/hmac"
-	"crypto/sha1"
-	"encoding/base64"
+	"crypto/sha256"
+	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -15,46 +15,28 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
-	"sort"
 	"strconv"
 	"strings"
+	"time"
 )
 
 // https://cloud.tencent.com/document/product/1729/97732
 
 func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
-	messages := make([]TencentMessage, 0, len(request.Messages))
+	messages := make([]*TencentMessage, 0, len(request.Messages))
 	for i := 0; i < len(request.Messages); i++ {
 		message := request.Messages[i]
-		if message.Role == "system" {
-			messages = append(messages, TencentMessage{
-				Role:    "user",
-				Content: message.StringContent(),
-			})
-			messages = append(messages, TencentMessage{
-				Role:    "assistant",
-				Content: "Okay",
-			})
-			continue
-		}
-		messages = append(messages, TencentMessage{
+		messages = append(messages, &TencentMessage{
 			Content: message.StringContent(),
 			Role:    message.Role,
 		})
 	}
-	stream := 0
-	if request.Stream {
-		stream = 1
-	}
 	return &TencentChatRequest{
-		Timestamp:   common.GetTimestamp(),
-		Expired:     common.GetTimestamp() + 24*60*60,
-		QueryID:     common.GetUUID(),
-		Temperature: request.Temperature,
-		TopP:        request.TopP,
-		Stream:      stream,
+		Temperature: &request.Temperature,
+		TopP:        &request.TopP,
+		Stream:      &request.Stream,
 		Messages:    messages,
-		Model:       request.Model,
+		Model:       &request.Model,
 	}
 }
 
@@ -62,7 +44,11 @@ func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextRespon
 	fullTextResponse := dto.OpenAITextResponse{
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Usage:   response.Usage,
+		Usage: dto.Usage{
+			PromptTokens:     response.Usage.PromptTokens,
+			CompletionTokens: response.Usage.CompletionTokens,
+			TotalTokens:      response.Usage.TotalTokens,
+		},
 	}
 	if len(response.Choices) > 0 {
 		content, _ := json.Marshal(response.Choices[0].Messages.Content)
@@ -99,64 +85,46 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
 func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	var responseText string
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
+	scanner.Split(bufio.ScanLines)
+
+	service.SetEventStreamHeaders(c)
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+			continue
 		}
-		if atEOF {
-			return len(data), data, nil
+		data = strings.TrimPrefix(data, "data:")
+
+		var tencentResponse TencentChatResponse
+		err := json.Unmarshal([]byte(data), &tencentResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			continue
 		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if len(data) < 5 { // ignore blank line or wrong format
-				continue
-			}
-			if data[:5] != "data:" {
-				continue
-			}
-			data = data[5:]
-			dataChan <- data
+
+		response := streamResponseTencent2OpenAI(&tencentResponse)
+		if len(response.Choices) != 0 {
+			responseText += response.Choices[0].Delta.GetContentString()
 		}
-		stopChan <- true
-	}()
-	service.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			var TencentResponse TencentChatResponse
-			err := json.Unmarshal([]byte(data), &TencentResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
-			response := streamResponseTencent2OpenAI(&TencentResponse)
-			if len(response.Choices) != 0 {
-				responseText += response.Choices[0].Delta.GetContentString()
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
+
+		err = service.ObjectData(c, response)
+		if err != nil {
+			common.SysError(err.Error())
 		}
-	})
+	}
+
+	if err := scanner.Err(); err != nil {
+		common.SysError("error reading stream: " + err.Error())
+	}
+
+	service.Done(c)
+
 	err := resp.Body.Close()
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
+
 	return nil, responseText
 }
 
@@ -206,29 +174,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
 	return
 }
 
-func getTencentSign(req TencentChatRequest, secretKey string) string {
-	params := make([]string, 0)
-	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
-	params = append(params, "secret_id="+req.SecretId)
-	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
-	params = append(params, "query_id="+req.QueryID)
-	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
-	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
-	params = append(params, "stream="+strconv.Itoa(req.Stream))
-	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
-
-	var messageStr string
-	for _, msg := range req.Messages {
-		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
-	}
-	messageStr = strings.TrimSuffix(messageStr, ",")
-	params = append(params, "messages=["+messageStr+"]")
-
-	sort.Sort(sort.StringSlice(params))
-	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
-	mac := hmac.New(sha1.New, []byte(secretKey))
-	signURL := url
-	mac.Write([]byte(signURL))
-	sign := mac.Sum([]byte(nil))
-	return base64.StdEncoding.EncodeToString(sign)
+func sha256hex(s string) string {
+	b := sha256.Sum256([]byte(s))
+	return hex.EncodeToString(b[:])
+}
+
+func hmacSha256(s, key string) string {
+	hashed := hmac.New(sha256.New, []byte(key))
+	hashed.Write([]byte(s))
+	return string(hashed.Sum(nil))
+}
+
+func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
+	// build canonical request string
+	host := "hunyuan.tencentcloudapi.com"
+	httpRequestMethod := "POST"
+	canonicalURI := "/"
+	canonicalQueryString := ""
+	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
+		"application/json", host, strings.ToLower(adaptor.Action))
+	signedHeaders := "content-type;host;x-tc-action"
+	payload, _ := json.Marshal(req)
+	hashedRequestPayload := sha256hex(string(payload))
+	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+		httpRequestMethod,
+		canonicalURI,
+		canonicalQueryString,
+		canonicalHeaders,
+		signedHeaders,
+		hashedRequestPayload)
+	// build string to sign
+	algorithm := "TC3-HMAC-SHA256"
+	requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
+	timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
+	t := time.Unix(timestamp, 0).UTC()
+	// must be the format 2006-01-02, ref to package time for more info
+	date := t.Format("2006-01-02")
+	credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
+	hashedCanonicalRequest := sha256hex(canonicalRequest)
+	string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
+		algorithm,
+		requestTimestamp,
+		credentialScope,
+		hashedCanonicalRequest)
+
+	// sign string
+	secretDate := hmacSha256(date, "TC3"+secKey)
+	secretService := hmacSha256("hunyuan", secretDate)
+	secretKey := hmacSha256("tc3_request", secretService)
+	signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
+
+	// build authorization
+	authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+		algorithm,
+		secId,
+		credentialScope,
+		signedHeaders,
+		signature)
+	return authorization
 }

+ 27 - 1
service/sse.go

@@ -1,6 +1,12 @@
 package service
 
-import "github.com/gin-gonic/gin"
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+	"strings"
+)
 
 func SetEventStreamHeaders(c *gin.Context) {
 	c.Writer.Header().Set("Content-Type", "text/event-stream")
@@ -9,3 +15,23 @@ func SetEventStreamHeaders(c *gin.Context) {
 	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 	c.Writer.Header().Set("X-Accel-Buffering", "no")
 }
+
+func StringData(c *gin.Context, str string) {
+	str = strings.TrimPrefix(str, "data: ")
+	str = strings.TrimSuffix(str, "\r")
+	c.Render(-1, common.CustomEvent{Data: "data: " + str})
+	c.Writer.Flush()
+}
+
+func ObjectData(c *gin.Context, object interface{}) error {
+	jsonData, err := json.Marshal(object)
+	if err != nil {
+		return fmt.Errorf("error marshalling object: %w", err)
+	}
+	StringData(c, string(jsonData))
+	return nil
+}
+
+func Done(c *gin.Context) {
+	StringData(c, "[DONE]")
+}