relay-zhipu.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "github.com/golang-jwt/jwt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
// Relay for Zhipu (ChatGLM) models such as chatglm_std and chatglm_lite.
// API documentation: https://open.bigmodel.cn/doc/api#chatglm_std
// Non-streaming endpoint: https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// Streaming (SSE) endpoint: https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
// ZhipuMessage is a single role/content pair. It is used both for the
// prompt sent to Zhipu (ZhipuRequest.Prompt) and for the choices Zhipu
// returns (ZhipuResponseData.Choices).
type ZhipuMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// ZhipuRequest is the JSON request body sent to Zhipu; requestOpenAI2Zhipu
// builds it from an OpenAI-style request.
//
// NOTE(review): because of omitempty, an explicit Temperature or TopP of 0
// is left out of the JSON instead of being sent — confirm this is intended.
type ZhipuRequest struct {
	Prompt      []ZhipuMessage `json:"prompt"`
	Temperature float64        `json:"temperature,omitempty"`
	TopP        float64        `json:"top_p,omitempty"`
	RequestId   string         `json:"request_id,omitempty"`
	// Incremental is always false as built by requestOpenAI2Zhipu, so
	// omitempty keeps it out of the encoded request.
	Incremental bool `json:"incremental,omitempty"`
}
// ZhipuResponseData is the "data" payload of a non-streaming ZhipuResponse.
type ZhipuResponseData struct {
	// TaskId becomes the Id of the converted OpenAI response.
	TaskId     string         `json:"task_id"`
	RequestId  string         `json:"request_id"`
	TaskStatus string         `json:"task_status"`
	Choices    []ZhipuMessage `json:"choices"`
	// Usage is embedded, but because it carries a JSON tag encoding/json
	// treats it as a named field and decodes it from the "usage" key.
	Usage `json:"usage"`
}
// ZhipuResponse is the top-level envelope of a non-streaming Zhipu reply.
// When Success is false, zhipuHandler reports Msg and Code to the client as
// a "zhipu_error" instead of converting Data.
type ZhipuResponse struct {
	Code    int               `json:"code"`
	Msg     string            `json:"msg"`
	Success bool              `json:"success"`
	Data    ZhipuResponseData `json:"data"`
}
// ZhipuStreamMetaResponse is decoded from the payload of a "meta:" line in
// the streaming response. The stream handler turns it into the final
// chunk (finish_reason "stop") and reports its Usage to the caller.
type ZhipuStreamMetaResponse struct {
	RequestId  string `json:"request_id"`
	TaskId     string `json:"task_id"`
	TaskStatus string `json:"task_status"`
	Usage      `json:"usage"`
}
// zhipuTokenData is a cached signed token together with the time after
// which getZhipuToken stops reusing it and signs a new one.
type zhipuTokenData struct {
	Token      string
	ExpiryTime time.Time
}
  52. var zhipuTokens sync.Map
  53. var expSeconds int64 = 24 * 3600
  54. func getZhipuToken(apikey string) string {
  55. data, ok := zhipuTokens.Load(apikey)
  56. if ok {
  57. tokenData := data.(zhipuTokenData)
  58. if time.Now().Before(tokenData.ExpiryTime) {
  59. return tokenData.Token
  60. }
  61. }
  62. split := strings.Split(apikey, ".")
  63. if len(split) != 2 {
  64. common.SysError("invalid zhipu key: " + apikey)
  65. return ""
  66. }
  67. id := split[0]
  68. secret := split[1]
  69. expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
  70. expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
  71. timestamp := time.Now().UnixNano() / 1e6
  72. payload := jwt.MapClaims{
  73. "api_key": id,
  74. "exp": expMillis,
  75. "timestamp": timestamp,
  76. }
  77. token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
  78. token.Header["alg"] = "HS256"
  79. token.Header["sign_type"] = "SIGN"
  80. tokenString, err := token.SignedString([]byte(secret))
  81. if err != nil {
  82. return ""
  83. }
  84. zhipuTokens.Store(apikey, zhipuTokenData{
  85. Token: tokenString,
  86. ExpiryTime: expiryTime,
  87. })
  88. return tokenString
  89. }
  90. func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
  91. messages := make([]ZhipuMessage, 0, len(request.Messages))
  92. for _, message := range request.Messages {
  93. messages = append(messages, ZhipuMessage{
  94. Role: message.Role,
  95. Content: message.Content,
  96. })
  97. }
  98. return &ZhipuRequest{
  99. Prompt: messages,
  100. Temperature: request.Temperature,
  101. TopP: request.TopP,
  102. Incremental: false,
  103. }
  104. }
  105. func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
  106. fullTextResponse := OpenAITextResponse{
  107. Id: response.Data.TaskId,
  108. Object: "chat.completion",
  109. Created: common.GetTimestamp(),
  110. Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
  111. Usage: response.Data.Usage,
  112. }
  113. for i, choice := range response.Data.Choices {
  114. openaiChoice := OpenAITextResponseChoice{
  115. Index: i,
  116. Message: Message{
  117. Role: choice.Role,
  118. Content: strings.Trim(choice.Content, "\""),
  119. },
  120. FinishReason: "",
  121. }
  122. if i == len(response.Data.Choices)-1 {
  123. openaiChoice.FinishReason = "stop"
  124. }
  125. fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
  126. }
  127. return &fullTextResponse
  128. }
  129. func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
  130. var choice ChatCompletionsStreamResponseChoice
  131. choice.Delta.Content = zhipuResponse
  132. choice.FinishReason = ""
  133. response := ChatCompletionsStreamResponse{
  134. Object: "chat.completion.chunk",
  135. Created: common.GetTimestamp(),
  136. Model: "chatglm",
  137. Choices: []ChatCompletionsStreamResponseChoice{choice},
  138. }
  139. return &response
  140. }
  141. func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
  142. var choice ChatCompletionsStreamResponseChoice
  143. choice.Delta.Content = ""
  144. choice.FinishReason = "stop"
  145. response := ChatCompletionsStreamResponse{
  146. Id: zhipuResponse.RequestId,
  147. Object: "chat.completion.chunk",
  148. Created: common.GetTimestamp(),
  149. Model: "chatglm",
  150. Choices: []ChatCompletionsStreamResponseChoice{choice},
  151. }
  152. return &response, &zhipuResponse.Usage
  153. }
  154. func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  155. var usage *Usage
  156. scanner := bufio.NewScanner(resp.Body)
  157. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  158. if atEOF && len(data) == 0 {
  159. return 0, nil, nil
  160. }
  161. if i := strings.Index(string(data), "\n"); i >= 0 {
  162. return i + 1, data[0:i], nil
  163. }
  164. if atEOF {
  165. return len(data), data, nil
  166. }
  167. return 0, nil, nil
  168. })
  169. dataChan := make(chan string)
  170. metaChan := make(chan string)
  171. stopChan := make(chan bool)
  172. go func() {
  173. for scanner.Scan() {
  174. data := scanner.Text()
  175. data = strings.Trim(data, "\"")
  176. if len(data) < 5 { // ignore blank line or wrong format
  177. continue
  178. }
  179. if data[:5] == "data:" {
  180. dataChan <- data[5:]
  181. } else if data[:5] == "meta:" {
  182. metaChan <- data[5:]
  183. }
  184. }
  185. stopChan <- true
  186. }()
  187. c.Writer.Header().Set("Content-Type", "text/event-stream")
  188. c.Writer.Header().Set("Cache-Control", "no-cache")
  189. c.Writer.Header().Set("Connection", "keep-alive")
  190. c.Writer.Header().Set("Transfer-Encoding", "chunked")
  191. c.Writer.Header().Set("X-Accel-Buffering", "no")
  192. c.Stream(func(w io.Writer) bool {
  193. select {
  194. case data := <-dataChan:
  195. response := streamResponseZhipu2OpenAI(data)
  196. jsonResponse, err := json.Marshal(response)
  197. if err != nil {
  198. common.SysError("error marshalling stream response: " + err.Error())
  199. return true
  200. }
  201. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  202. return true
  203. case data := <-metaChan:
  204. var zhipuResponse ZhipuStreamMetaResponse
  205. err := json.Unmarshal([]byte(data), &zhipuResponse)
  206. if err != nil {
  207. common.SysError("error unmarshalling stream response: " + err.Error())
  208. return true
  209. }
  210. response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
  211. jsonResponse, err := json.Marshal(response)
  212. if err != nil {
  213. common.SysError("error marshalling stream response: " + err.Error())
  214. return true
  215. }
  216. usage = zhipuUsage
  217. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  218. return true
  219. case <-stopChan:
  220. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  221. return false
  222. }
  223. })
  224. err := resp.Body.Close()
  225. if err != nil {
  226. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  227. }
  228. return nil, usage
  229. }
  230. func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  231. var zhipuResponse ZhipuResponse
  232. responseBody, err := io.ReadAll(resp.Body)
  233. if err != nil {
  234. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  235. }
  236. err = resp.Body.Close()
  237. if err != nil {
  238. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  239. }
  240. err = json.Unmarshal(responseBody, &zhipuResponse)
  241. if err != nil {
  242. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  243. }
  244. if !zhipuResponse.Success {
  245. return &OpenAIErrorWithStatusCode{
  246. OpenAIError: OpenAIError{
  247. Message: zhipuResponse.Msg,
  248. Type: "zhipu_error",
  249. Param: "",
  250. Code: zhipuResponse.Code,
  251. },
  252. StatusCode: resp.StatusCode,
  253. }, nil
  254. }
  255. fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
  256. jsonResponse, err := json.Marshal(fullTextResponse)
  257. if err != nil {
  258. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  259. }
  260. c.Writer.Header().Set("Content-Type", "application/json")
  261. c.Writer.WriteHeader(resp.StatusCode)
  262. _, err = c.Writer.Write(jsonResponse)
  263. return nil, &fullTextResponse.Usage
  264. }