relay-zhipu.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "github.com/golang-jwt/jwt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. // https://open.bigmodel.cn/doc/api#chatglm_std
  15. // chatglm_std, chatglm_lite
  16. // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
  17. // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
  18. type ZhipuMessage struct {
  19. Role string `json:"role"`
  20. Content string `json:"content"`
  21. }
  22. type ZhipuRequest struct {
  23. Prompt []ZhipuMessage `json:"prompt"`
  24. Temperature float64 `json:"temperature,omitempty"`
  25. TopP float64 `json:"top_p,omitempty"`
  26. RequestId string `json:"request_id,omitempty"`
  27. Incremental bool `json:"incremental,omitempty"`
  28. }
  29. type ZhipuResponseData struct {
  30. TaskId string `json:"task_id"`
  31. RequestId string `json:"request_id"`
  32. TaskStatus string `json:"task_status"`
  33. Choices []ZhipuMessage `json:"choices"`
  34. Usage `json:"usage"`
  35. }
  36. type ZhipuResponse struct {
  37. Code int `json:"code"`
  38. Msg string `json:"msg"`
  39. Success bool `json:"success"`
  40. Data ZhipuResponseData `json:"data"`
  41. }
  42. type ZhipuStreamMetaResponse struct {
  43. RequestId string `json:"request_id"`
  44. TaskId string `json:"task_id"`
  45. TaskStatus string `json:"task_status"`
  46. Usage `json:"usage"`
  47. }
  48. type zhipuTokenData struct {
  49. Token string
  50. ExpiryTime time.Time
  51. }
  52. var zhipuTokens sync.Map
  53. var expSeconds int64 = 24 * 3600
  54. func getZhipuToken(apikey string) string {
  55. data, ok := zhipuTokens.Load(apikey)
  56. if ok {
  57. tokenData := data.(zhipuTokenData)
  58. if time.Now().Before(tokenData.ExpiryTime) {
  59. return tokenData.Token
  60. }
  61. }
  62. split := strings.Split(apikey, ".")
  63. if len(split) != 2 {
  64. common.SysError("invalid zhipu key: " + apikey)
  65. return ""
  66. }
  67. id := split[0]
  68. secret := split[1]
  69. expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
  70. expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
  71. timestamp := time.Now().UnixNano() / 1e6
  72. payload := jwt.MapClaims{
  73. "api_key": id,
  74. "exp": expMillis,
  75. "timestamp": timestamp,
  76. }
  77. token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
  78. token.Header["alg"] = "HS256"
  79. token.Header["sign_type"] = "SIGN"
  80. tokenString, err := token.SignedString([]byte(secret))
  81. if err != nil {
  82. return ""
  83. }
  84. zhipuTokens.Store(apikey, zhipuTokenData{
  85. Token: tokenString,
  86. ExpiryTime: expiryTime,
  87. })
  88. return tokenString
  89. }
  90. func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
  91. messages := make([]ZhipuMessage, 0, len(request.Messages))
  92. for _, message := range request.Messages {
  93. if message.Role == "system" {
  94. messages = append(messages, ZhipuMessage{
  95. Role: "system",
  96. Content: message.Content,
  97. })
  98. messages = append(messages, ZhipuMessage{
  99. Role: "user",
  100. Content: "Okay",
  101. })
  102. } else {
  103. messages = append(messages, ZhipuMessage{
  104. Role: message.Role,
  105. Content: message.Content,
  106. })
  107. }
  108. }
  109. return &ZhipuRequest{
  110. Prompt: messages,
  111. Temperature: request.Temperature,
  112. TopP: request.TopP,
  113. Incremental: false,
  114. }
  115. }
  116. func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
  117. fullTextResponse := OpenAITextResponse{
  118. Id: response.Data.TaskId,
  119. Object: "chat.completion",
  120. Created: common.GetTimestamp(),
  121. Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
  122. Usage: response.Data.Usage,
  123. }
  124. for i, choice := range response.Data.Choices {
  125. openaiChoice := OpenAITextResponseChoice{
  126. Index: i,
  127. Message: Message{
  128. Role: choice.Role,
  129. Content: strings.Trim(choice.Content, "\""),
  130. },
  131. FinishReason: "",
  132. }
  133. if i == len(response.Data.Choices)-1 {
  134. openaiChoice.FinishReason = "stop"
  135. }
  136. fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
  137. }
  138. return &fullTextResponse
  139. }
  140. func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
  141. var choice ChatCompletionsStreamResponseChoice
  142. choice.Delta.Content = zhipuResponse
  143. response := ChatCompletionsStreamResponse{
  144. Object: "chat.completion.chunk",
  145. Created: common.GetTimestamp(),
  146. Model: "chatglm",
  147. Choices: []ChatCompletionsStreamResponseChoice{choice},
  148. }
  149. return &response
  150. }
  151. func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
  152. var choice ChatCompletionsStreamResponseChoice
  153. choice.Delta.Content = ""
  154. choice.FinishReason = &stopFinishReason
  155. response := ChatCompletionsStreamResponse{
  156. Id: zhipuResponse.RequestId,
  157. Object: "chat.completion.chunk",
  158. Created: common.GetTimestamp(),
  159. Model: "chatglm",
  160. Choices: []ChatCompletionsStreamResponseChoice{choice},
  161. }
  162. return &response, &zhipuResponse.Usage
  163. }
  164. func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  165. var usage *Usage
  166. scanner := bufio.NewScanner(resp.Body)
  167. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  168. if atEOF && len(data) == 0 {
  169. return 0, nil, nil
  170. }
  171. if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
  172. return i + 2, data[0:i], nil
  173. }
  174. if atEOF {
  175. return len(data), data, nil
  176. }
  177. return 0, nil, nil
  178. })
  179. dataChan := make(chan string)
  180. metaChan := make(chan string)
  181. stopChan := make(chan bool)
  182. go func() {
  183. for scanner.Scan() {
  184. data := scanner.Text()
  185. lines := strings.Split(data, "\n")
  186. for i, line := range lines {
  187. if len(line) < 5 {
  188. continue
  189. }
  190. if line[:5] == "data:" {
  191. dataChan <- line[5:]
  192. if i != len(lines)-1 {
  193. dataChan <- "\n"
  194. }
  195. } else if line[:5] == "meta:" {
  196. metaChan <- line[5:]
  197. }
  198. }
  199. }
  200. stopChan <- true
  201. }()
  202. setEventStreamHeaders(c)
  203. c.Stream(func(w io.Writer) bool {
  204. select {
  205. case data := <-dataChan:
  206. response := streamResponseZhipu2OpenAI(data)
  207. jsonResponse, err := json.Marshal(response)
  208. if err != nil {
  209. common.SysError("error marshalling stream response: " + err.Error())
  210. return true
  211. }
  212. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  213. return true
  214. case data := <-metaChan:
  215. var zhipuResponse ZhipuStreamMetaResponse
  216. err := json.Unmarshal([]byte(data), &zhipuResponse)
  217. if err != nil {
  218. common.SysError("error unmarshalling stream response: " + err.Error())
  219. return true
  220. }
  221. response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
  222. jsonResponse, err := json.Marshal(response)
  223. if err != nil {
  224. common.SysError("error marshalling stream response: " + err.Error())
  225. return true
  226. }
  227. usage = zhipuUsage
  228. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  229. return true
  230. case <-stopChan:
  231. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  232. return false
  233. }
  234. })
  235. err := resp.Body.Close()
  236. if err != nil {
  237. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  238. }
  239. return nil, usage
  240. }
  241. func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  242. var zhipuResponse ZhipuResponse
  243. responseBody, err := io.ReadAll(resp.Body)
  244. if err != nil {
  245. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  246. }
  247. err = resp.Body.Close()
  248. if err != nil {
  249. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  250. }
  251. err = json.Unmarshal(responseBody, &zhipuResponse)
  252. if err != nil {
  253. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  254. }
  255. if !zhipuResponse.Success {
  256. return &OpenAIErrorWithStatusCode{
  257. OpenAIError: OpenAIError{
  258. Message: zhipuResponse.Msg,
  259. Type: "zhipu_error",
  260. Param: "",
  261. Code: zhipuResponse.Code,
  262. },
  263. StatusCode: resp.StatusCode,
  264. }, nil
  265. }
  266. fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
  267. jsonResponse, err := json.Marshal(fullTextResponse)
  268. if err != nil {
  269. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  270. }
  271. c.Writer.Header().Set("Content-Type", "application/json")
  272. c.Writer.WriteHeader(resp.StatusCode)
  273. _, err = c.Writer.Write(jsonResponse)
  274. return nil, &fullTextResponse.Usage
  275. }