stream.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package ollama
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/dto"
  12. "github.com/QuantumNous/new-api/logger"
  13. relaycommon "github.com/QuantumNous/new-api/relay/common"
  14. "github.com/QuantumNous/new-api/relay/helper"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/QuantumNous/new-api/types"
  17. "github.com/gin-gonic/gin"
  18. )
// ollamaChatStreamChunk mirrors one NDJSON object emitted by the Ollama
// /api/chat or /api/generate endpoints. Chat responses populate Message;
// generate responses populate Response. The duration/count fields are read
// from the final frame (Done == true) by the handlers below.
type ollamaChatStreamChunk struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	// chat
	Message *struct {
		Role    string `json:"role"`
		Content string `json:"content"`
		// Thinking is kept raw: handlers decode it themselves, treating it
		// as a JSON string with a fallback to the raw bytes.
		Thinking  json.RawMessage `json:"thinking"`
		ToolCalls []struct {
			Function struct {
				Name string `json:"name"`
				// Arguments arrive as structured JSON (not a string); the
				// handlers re-marshal it into the OpenAI string form.
				Arguments interface{} `json:"arguments"`
			} `json:"function"`
		} `json:"tool_calls"`
	} `json:"message"`
	// generate
	Response           string `json:"response"`
	Done               bool   `json:"done"`
	DoneReason         string `json:"done_reason"`
	TotalDuration      int64  `json:"total_duration"`
	LoadDuration       int64  `json:"load_duration"`
	PromptEvalCount    int    `json:"prompt_eval_count"`
	EvalCount          int    `json:"eval_count"`
	PromptEvalDuration int64  `json:"prompt_eval_duration"`
	EvalDuration       int64  `json:"eval_duration"`
}
  45. func toUnix(ts string) int64 {
  46. if ts == "" {
  47. return time.Now().Unix()
  48. }
  49. // try time.RFC3339 or with nanoseconds
  50. t, err := time.Parse(time.RFC3339Nano, ts)
  51. if err != nil {
  52. t2, err2 := time.Parse(time.RFC3339, ts)
  53. if err2 == nil {
  54. return t2.Unix()
  55. }
  56. return time.Now().Unix()
  57. }
  58. return t.Unix()
  59. }
  60. func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  61. if resp == nil || resp.Body == nil {
  62. return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest)
  63. }
  64. defer service.CloseResponseBodyGracefully(resp)
  65. helper.SetEventStreamHeaders(c)
  66. scanner := bufio.NewScanner(resp.Body)
  67. usage := &dto.Usage{}
  68. var model = info.UpstreamModelName
  69. var responseId = common.GetUUID()
  70. var created = time.Now().Unix()
  71. var toolCallIndex int
  72. start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
  73. if data, err := common.Marshal(start); err == nil {
  74. _ = helper.StringData(c, string(data))
  75. }
  76. for scanner.Scan() {
  77. line := scanner.Text()
  78. line = strings.TrimSpace(line)
  79. if line == "" {
  80. continue
  81. }
  82. var chunk ollamaChatStreamChunk
  83. if err := json.Unmarshal([]byte(line), &chunk); err != nil {
  84. logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
  85. return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  86. }
  87. if chunk.Model != "" {
  88. model = chunk.Model
  89. }
  90. created = toUnix(chunk.CreatedAt)
  91. if !chunk.Done {
  92. // delta content
  93. var content string
  94. if chunk.Message != nil {
  95. content = chunk.Message.Content
  96. } else {
  97. content = chunk.Response
  98. }
  99. delta := dto.ChatCompletionsStreamResponse{
  100. Id: responseId,
  101. Object: "chat.completion.chunk",
  102. Created: created,
  103. Model: model,
  104. Choices: []dto.ChatCompletionsStreamResponseChoice{{
  105. Index: 0,
  106. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{Role: "assistant"},
  107. }},
  108. }
  109. if content != "" {
  110. delta.Choices[0].Delta.SetContentString(content)
  111. }
  112. if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
  113. raw := strings.TrimSpace(string(chunk.Message.Thinking))
  114. if raw != "" && raw != "null" {
  115. // Unmarshal the JSON string to get the actual content without quotes
  116. var thinkingContent string
  117. if err := json.Unmarshal(chunk.Message.Thinking, &thinkingContent); err == nil {
  118. delta.Choices[0].Delta.SetReasoningContent(thinkingContent)
  119. } else {
  120. // Fallback to raw string if it's not a JSON string
  121. delta.Choices[0].Delta.SetReasoningContent(raw)
  122. }
  123. }
  124. }
  125. // tool calls
  126. if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
  127. delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 0, len(chunk.Message.ToolCalls))
  128. for _, tc := range chunk.Message.ToolCalls {
  129. // arguments -> string
  130. argBytes, _ := json.Marshal(tc.Function.Arguments)
  131. toolId := fmt.Sprintf("call_%d", toolCallIndex)
  132. tr := dto.ToolCallResponse{ID: toolId, Type: "function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
  133. tr.SetIndex(toolCallIndex)
  134. toolCallIndex++
  135. delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
  136. }
  137. }
  138. if data, err := common.Marshal(delta); err == nil {
  139. _ = helper.StringData(c, string(data))
  140. }
  141. continue
  142. }
  143. // done frame
  144. // finalize once and break loop
  145. usage.PromptTokens = chunk.PromptEvalCount
  146. usage.CompletionTokens = chunk.EvalCount
  147. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  148. finishReason := chunk.DoneReason
  149. if finishReason == "" {
  150. finishReason = "stop"
  151. }
  152. // emit stop delta
  153. if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
  154. if data, err := common.Marshal(stop); err == nil {
  155. _ = helper.StringData(c, string(data))
  156. }
  157. }
  158. // emit usage frame
  159. if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
  160. if data, err := common.Marshal(final); err == nil {
  161. _ = helper.StringData(c, string(data))
  162. }
  163. }
  164. // send [DONE]
  165. helper.Done(c)
  166. break
  167. }
  168. if err := scanner.Err(); err != nil && err != io.EOF {
  169. logger.LogError(c, "ollama stream scan error: "+err.Error())
  170. }
  171. return usage, nil
  172. }
  173. // non-stream handler for chat/generate
  174. func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  175. body, err := io.ReadAll(resp.Body)
  176. if err != nil {
  177. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  178. }
  179. service.CloseResponseBodyGracefully(resp)
  180. raw := string(body)
  181. if common.DebugEnabled {
  182. println("ollama non-stream raw resp:", raw)
  183. }
  184. lines := strings.Split(raw, "\n")
  185. var (
  186. aggContent strings.Builder
  187. reasoningBuilder strings.Builder
  188. lastChunk ollamaChatStreamChunk
  189. parsedAny bool
  190. )
  191. for _, ln := range lines {
  192. ln = strings.TrimSpace(ln)
  193. if ln == "" {
  194. continue
  195. }
  196. var ck ollamaChatStreamChunk
  197. if err := json.Unmarshal([]byte(ln), &ck); err != nil {
  198. if len(lines) == 1 {
  199. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  200. }
  201. continue
  202. }
  203. parsedAny = true
  204. lastChunk = ck
  205. if ck.Message != nil && len(ck.Message.Thinking) > 0 {
  206. raw := strings.TrimSpace(string(ck.Message.Thinking))
  207. if raw != "" && raw != "null" {
  208. // Unmarshal the JSON string to get the actual content without quotes
  209. var thinkingContent string
  210. if err := json.Unmarshal(ck.Message.Thinking, &thinkingContent); err == nil {
  211. reasoningBuilder.WriteString(thinkingContent)
  212. } else {
  213. // Fallback to raw string if it's not a JSON string
  214. reasoningBuilder.WriteString(raw)
  215. }
  216. }
  217. }
  218. if ck.Message != nil && ck.Message.Content != "" {
  219. aggContent.WriteString(ck.Message.Content)
  220. } else if ck.Response != "" {
  221. aggContent.WriteString(ck.Response)
  222. }
  223. }
  224. if !parsedAny {
  225. var single ollamaChatStreamChunk
  226. if err := json.Unmarshal(body, &single); err != nil {
  227. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  228. }
  229. lastChunk = single
  230. if single.Message != nil {
  231. if len(single.Message.Thinking) > 0 {
  232. raw := strings.TrimSpace(string(single.Message.Thinking))
  233. if raw != "" && raw != "null" {
  234. // Unmarshal the JSON string to get the actual content without quotes
  235. var thinkingContent string
  236. if err := json.Unmarshal(single.Message.Thinking, &thinkingContent); err == nil {
  237. reasoningBuilder.WriteString(thinkingContent)
  238. } else {
  239. // Fallback to raw string if it's not a JSON string
  240. reasoningBuilder.WriteString(raw)
  241. }
  242. }
  243. }
  244. aggContent.WriteString(single.Message.Content)
  245. } else {
  246. aggContent.WriteString(single.Response)
  247. }
  248. }
  249. model := lastChunk.Model
  250. if model == "" {
  251. model = info.UpstreamModelName
  252. }
  253. created := toUnix(lastChunk.CreatedAt)
  254. usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
  255. content := aggContent.String()
  256. finishReason := lastChunk.DoneReason
  257. if finishReason == "" {
  258. finishReason = "stop"
  259. }
  260. msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
  261. if rc := reasoningBuilder.String(); rc != "" {
  262. msg.ReasoningContent = rc
  263. }
  264. full := dto.OpenAITextResponse{
  265. Id: common.GetUUID(),
  266. Model: model,
  267. Object: "chat.completion",
  268. Created: created,
  269. Choices: []dto.OpenAITextResponseChoice{{
  270. Index: 0,
  271. Message: msg,
  272. FinishReason: finishReason,
  273. }},
  274. Usage: *usage,
  275. }
  276. out, _ := common.Marshal(full)
  277. service.IOCopyBytesGracefully(c, resp, out)
  278. return usage, nil
  279. }
  280. func contentPtr(s string) *string {
  281. if s == "" {
  282. return nil
  283. }
  284. return &s
  285. }