relay-zhipu.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "github.com/golang-jwt/jwt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. // https://open.bigmodel.cn/doc/api#chatglm_std
  15. // chatglm_std, chatglm_lite
  16. // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
  17. // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
  18. type ZhipuMessage struct {
  19. Role string `json:"role"`
  20. Content string `json:"content"`
  21. }
  22. type ZhipuRequest struct {
  23. Prompt []ZhipuMessage `json:"prompt"`
  24. Temperature float64 `json:"temperature,omitempty"`
  25. TopP float64 `json:"top_p,omitempty"`
  26. RequestId string `json:"request_id,omitempty"`
  27. Incremental bool `json:"incremental,omitempty"`
  28. }
  29. type ZhipuResponseData struct {
  30. TaskId string `json:"task_id"`
  31. RequestId string `json:"request_id"`
  32. TaskStatus string `json:"task_status"`
  33. Choices []ZhipuMessage `json:"choices"`
  34. Usage `json:"usage"`
  35. }
  36. type ZhipuResponse struct {
  37. Code int `json:"code"`
  38. Msg string `json:"msg"`
  39. Success bool `json:"success"`
  40. Data ZhipuResponseData `json:"data"`
  41. }
  42. type ZhipuStreamMetaResponse struct {
  43. RequestId string `json:"request_id"`
  44. TaskId string `json:"task_id"`
  45. TaskStatus string `json:"task_status"`
  46. Usage `json:"usage"`
  47. }
  48. type zhipuTokenData struct {
  49. Token string
  50. ExpiryTime time.Time
  51. }
  52. var zhipuTokens sync.Map
  53. var expSeconds int64 = 24 * 3600
  54. func getZhipuToken(apikey string) string {
  55. data, ok := zhipuTokens.Load(apikey)
  56. if ok {
  57. tokenData := data.(zhipuTokenData)
  58. if time.Now().Before(tokenData.ExpiryTime) {
  59. return tokenData.Token
  60. }
  61. }
  62. split := strings.Split(apikey, ".")
  63. if len(split) != 2 {
  64. common.SysError("invalid zhipu key: " + apikey)
  65. return ""
  66. }
  67. id := split[0]
  68. secret := split[1]
  69. expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
  70. expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
  71. timestamp := time.Now().UnixNano() / 1e6
  72. payload := jwt.MapClaims{
  73. "api_key": id,
  74. "exp": expMillis,
  75. "timestamp": timestamp,
  76. }
  77. token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
  78. token.Header["alg"] = "HS256"
  79. token.Header["sign_type"] = "SIGN"
  80. tokenString, err := token.SignedString([]byte(secret))
  81. if err != nil {
  82. return ""
  83. }
  84. zhipuTokens.Store(apikey, zhipuTokenData{
  85. Token: tokenString,
  86. ExpiryTime: expiryTime,
  87. })
  88. return tokenString
  89. }
  90. func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
  91. messages := make([]ZhipuMessage, 0, len(request.Messages))
  92. for _, message := range request.Messages {
  93. if message.Role == "system" {
  94. messages = append(messages, ZhipuMessage{
  95. Role: "system",
  96. Content: string(message.Content),
  97. })
  98. messages = append(messages, ZhipuMessage{
  99. Role: "user",
  100. Content: "Okay",
  101. })
  102. } else {
  103. messages = append(messages, ZhipuMessage{
  104. Role: message.Role,
  105. Content: string(message.Content),
  106. })
  107. }
  108. }
  109. return &ZhipuRequest{
  110. Prompt: messages,
  111. Temperature: request.Temperature,
  112. TopP: request.TopP,
  113. Incremental: false,
  114. }
  115. }
  116. func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
  117. fullTextResponse := OpenAITextResponse{
  118. Id: response.Data.TaskId,
  119. Object: "chat.completion",
  120. Created: common.GetTimestamp(),
  121. Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
  122. Usage: response.Data.Usage,
  123. }
  124. for i, choice := range response.Data.Choices {
  125. content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
  126. openaiChoice := OpenAITextResponseChoice{
  127. Index: i,
  128. Message: Message{
  129. Role: choice.Role,
  130. Content: content,
  131. },
  132. FinishReason: "",
  133. }
  134. if i == len(response.Data.Choices)-1 {
  135. openaiChoice.FinishReason = "stop"
  136. }
  137. fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
  138. }
  139. return &fullTextResponse
  140. }
  141. func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
  142. var choice ChatCompletionsStreamResponseChoice
  143. choice.Delta.Content = zhipuResponse
  144. response := ChatCompletionsStreamResponse{
  145. Object: "chat.completion.chunk",
  146. Created: common.GetTimestamp(),
  147. Model: "chatglm",
  148. Choices: []ChatCompletionsStreamResponseChoice{choice},
  149. }
  150. return &response
  151. }
  152. func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
  153. var choice ChatCompletionsStreamResponseChoice
  154. choice.Delta.Content = ""
  155. choice.FinishReason = &stopFinishReason
  156. response := ChatCompletionsStreamResponse{
  157. Id: zhipuResponse.RequestId,
  158. Object: "chat.completion.chunk",
  159. Created: common.GetTimestamp(),
  160. Model: "chatglm",
  161. Choices: []ChatCompletionsStreamResponseChoice{choice},
  162. }
  163. return &response, &zhipuResponse.Usage
  164. }
  165. func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  166. var usage *Usage
  167. scanner := bufio.NewScanner(resp.Body)
  168. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  169. if atEOF && len(data) == 0 {
  170. return 0, nil, nil
  171. }
  172. if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
  173. return i + 2, data[0:i], nil
  174. }
  175. if atEOF {
  176. return len(data), data, nil
  177. }
  178. return 0, nil, nil
  179. })
  180. dataChan := make(chan string)
  181. metaChan := make(chan string)
  182. stopChan := make(chan bool)
  183. go func() {
  184. for scanner.Scan() {
  185. data := scanner.Text()
  186. lines := strings.Split(data, "\n")
  187. for i, line := range lines {
  188. if len(line) < 5 {
  189. continue
  190. }
  191. if line[:5] == "data:" {
  192. dataChan <- line[5:]
  193. if i != len(lines)-1 {
  194. dataChan <- "\n"
  195. }
  196. } else if line[:5] == "meta:" {
  197. metaChan <- line[5:]
  198. }
  199. }
  200. }
  201. stopChan <- true
  202. }()
  203. setEventStreamHeaders(c)
  204. c.Stream(func(w io.Writer) bool {
  205. select {
  206. case data := <-dataChan:
  207. response := streamResponseZhipu2OpenAI(data)
  208. jsonResponse, err := json.Marshal(response)
  209. if err != nil {
  210. common.SysError("error marshalling stream response: " + err.Error())
  211. return true
  212. }
  213. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  214. return true
  215. case data := <-metaChan:
  216. var zhipuResponse ZhipuStreamMetaResponse
  217. err := json.Unmarshal([]byte(data), &zhipuResponse)
  218. if err != nil {
  219. common.SysError("error unmarshalling stream response: " + err.Error())
  220. return true
  221. }
  222. response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
  223. jsonResponse, err := json.Marshal(response)
  224. if err != nil {
  225. common.SysError("error marshalling stream response: " + err.Error())
  226. return true
  227. }
  228. usage = zhipuUsage
  229. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  230. return true
  231. case <-stopChan:
  232. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  233. return false
  234. }
  235. })
  236. err := resp.Body.Close()
  237. if err != nil {
  238. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  239. }
  240. return nil, usage
  241. }
  242. func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  243. var zhipuResponse ZhipuResponse
  244. responseBody, err := io.ReadAll(resp.Body)
  245. if err != nil {
  246. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  247. }
  248. err = resp.Body.Close()
  249. if err != nil {
  250. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  251. }
  252. err = json.Unmarshal(responseBody, &zhipuResponse)
  253. if err != nil {
  254. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  255. }
  256. if !zhipuResponse.Success {
  257. return &OpenAIErrorWithStatusCode{
  258. OpenAIError: OpenAIError{
  259. Message: zhipuResponse.Msg,
  260. Type: "zhipu_error",
  261. Param: "",
  262. Code: zhipuResponse.Code,
  263. },
  264. StatusCode: resp.StatusCode,
  265. }, nil
  266. }
  267. fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
  268. jsonResponse, err := json.Marshal(fullTextResponse)
  269. if err != nil {
  270. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  271. }
  272. c.Writer.Header().Set("Content-Type", "application/json")
  273. c.Writer.WriteHeader(resp.StatusCode)
  274. _, err = c.Writer.Write(jsonResponse)
  275. return nil, &fullTextResponse.Usage
  276. }