stream.go

package ollama

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

	"one-api/common"
	"one-api/dto"
	"one-api/logger"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
	"one-api/types"
)

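// ollamaChatStreamChunk mirrors one NDJSON line of an Ollama /api/chat or
// /api/generate response. An illustrative (not verbatim) stream:
//
//	{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":false}
//	{"model":"llama3","created_at":"2024-01-01T00:00:01Z","done":true,"done_reason":"stop","prompt_eval_count":12,"eval_count":34}
//
// Chat responses carry a "message" object; generate responses carry a flat
// "response" string. The final chunk has done=true plus timing/token counts.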
type ollamaChatStreamChunk struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	// chat
	Message *struct {
		Role      string          `json:"role"`
		Content   string          `json:"content"`
		Thinking  json.RawMessage `json:"thinking"`
		ToolCalls []struct {
			Function struct {
				Name      string      `json:"name"`
				Arguments interface{} `json:"arguments"`
			} `json:"function"`
		} `json:"tool_calls"`
	} `json:"message"`
	// generate
	Response           string `json:"response"`
	Done               bool   `json:"done"`
	DoneReason         string `json:"done_reason"`
	TotalDuration      int64  `json:"total_duration"`
	LoadDuration       int64  `json:"load_duration"`
	PromptEvalCount    int    `json:"prompt_eval_count"`
	EvalCount          int    `json:"eval_count"`
	PromptEvalDuration int64  `json:"prompt_eval_duration"`
	EvalDuration       int64  `json:"eval_duration"`
}
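
// toUnix converts Ollama's RFC 3339 created_at timestamp to a Unix epoch,
// falling back to the current time when the field is absent or malformed.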
func toUnix(ts string) int64 {
	if ts == "" {
		return time.Now().Unix()
	}
	// Try RFC 3339 with nanoseconds first, then the plain form.
	if t, err := time.Parse(time.RFC3339Nano, ts); err == nil {
		return t.Unix()
	}
	if t, err := time.Parse(time.RFC3339, ts); err == nil {
		return t.Unix()
	}
	return time.Now().Unix()
}
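
// ollamaStreamHandler re-emits Ollama's NDJSON stream as OpenAI-style SSE
// chat.completion.chunk events: an initial empty delta, one delta per content
// chunk (including reasoning and tool calls), then a stop delta, a usage
// frame, and [DONE].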
func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	if resp == nil || resp.Body == nil {
		return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest)
	}
	defer service.CloseResponseBodyGracefully(resp)
	helper.SetEventStreamHeaders(c)

	scanner := bufio.NewScanner(resp.Body)
	// A single NDJSON line can exceed bufio.Scanner's default 64 KB token
	// limit; allow lines up to 1 MB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	usage := &dto.Usage{}
	model := info.UpstreamModelName
	responseId := common.GetUUID()
	created := time.Now().Unix()
	var toolCallIndex int

	// Open the stream with an empty assistant delta.
	start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
	if data, err := common.Marshal(start); err == nil {
		_ = helper.StringData(c, string(data))
	}
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		var chunk ollamaChatStreamChunk
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
			return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
		}
		if chunk.Model != "" {
			model = chunk.Model
		}
		created = toUnix(chunk.CreatedAt)
		if !chunk.Done {
			// Delta content: chat responses use message.content, generate
			// responses use the flat response field.
			var content string
			if chunk.Message != nil {
				content = chunk.Message.Content
			} else {
				content = chunk.Response
			}
			delta := dto.ChatCompletionsStreamResponse{
				Id:      responseId,
				Object:  "chat.completion.chunk",
				Created: created,
				Model:   model,
				Choices: []dto.ChatCompletionsStreamResponseChoice{{
					Index: 0,
					Delta: dto.ChatCompletionsStreamResponseChoiceDelta{Role: "assistant"},
				}},
			}
			if content != "" {
				delta.Choices[0].Delta.SetContentString(content)
			}
			if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
				raw := strings.TrimSpace(string(chunk.Message.Thinking))
				if raw != "" && raw != "null" {
					delta.Choices[0].Delta.SetReasoningContent(raw)
				}
			}
			// Tool calls: re-serialize arguments into a JSON string and assign
			// synthetic incremental call IDs.
			if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
				delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 0, len(chunk.Message.ToolCalls))
				for _, tc := range chunk.Message.ToolCalls {
					argBytes, _ := json.Marshal(tc.Function.Arguments)
					toolId := fmt.Sprintf("call_%d", toolCallIndex)
					tr := dto.ToolCallResponse{ID: toolId, Type: "function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
					tr.SetIndex(toolCallIndex)
					toolCallIndex++
					delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
				}
			}
			if data, err := common.Marshal(delta); err == nil {
				_ = helper.StringData(c, string(data))
			}
			continue
		}
		// Done frame: finalize usage once, emit the stop delta, the usage
		// frame, and [DONE], then stop reading.
		usage.PromptTokens = chunk.PromptEvalCount
		usage.CompletionTokens = chunk.EvalCount
		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
		finishReason := chunk.DoneReason
		if finishReason == "" {
			finishReason = "stop"
		}
		if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
			if data, err := common.Marshal(stop); err == nil {
				_ = helper.StringData(c, string(data))
			}
		}
		if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
			if data, err := common.Marshal(final); err == nil {
				_ = helper.StringData(c, string(data))
			}
		}
		helper.Done(c)
		break
	}
	// Scanner.Err never reports io.EOF, so any non-nil error is a real failure.
	if err := scanner.Err(); err != nil {
		logger.LogError(c, "ollama stream scan error: "+err.Error())
	}
	return usage, nil
}

// ollamaChatHandler converts a non-streaming Ollama chat/generate response
// into a single OpenAI chat.completion object, tolerating NDJSON bodies by
// aggregating content across lines.
func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	// Defer the close so the body is released even when ReadAll fails.
	defer service.CloseResponseBodyGracefully(resp)
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
	}
	raw := string(body)
	if common.DebugEnabled {
		println("ollama non-stream raw resp:", raw)
	}
	lines := strings.Split(raw, "\n")
	var (
		aggContent       strings.Builder
		reasoningBuilder strings.Builder
		lastChunk        ollamaChatStreamChunk
		parsedAny        bool
	)
	for _, ln := range lines {
		ln = strings.TrimSpace(ln)
		if ln == "" {
			continue
		}
		var ck ollamaChatStreamChunk
		if err := json.Unmarshal([]byte(ln), &ck); err != nil {
			// A single unparseable line means the body is not valid JSON at all.
			if len(lines) == 1 {
				return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
			}
			continue
		}
		parsedAny = true
		lastChunk = ck
		if ck.Message != nil && len(ck.Message.Thinking) > 0 {
			raw := strings.TrimSpace(string(ck.Message.Thinking))
			if raw != "" && raw != "null" {
				reasoningBuilder.WriteString(raw)
			}
		}
		if ck.Message != nil && ck.Message.Content != "" {
			aggContent.WriteString(ck.Message.Content)
		} else if ck.Response != "" {
			aggContent.WriteString(ck.Response)
		}
	}
	if !parsedAny {
		// Fall back to decoding the whole body as one JSON object.
		var single ollamaChatStreamChunk
		if err := json.Unmarshal(body, &single); err != nil {
			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
		}
		lastChunk = single
		if single.Message != nil {
			if len(single.Message.Thinking) > 0 {
				raw := strings.TrimSpace(string(single.Message.Thinking))
				if raw != "" && raw != "null" {
					reasoningBuilder.WriteString(raw)
				}
			}
			aggContent.WriteString(single.Message.Content)
		} else {
			aggContent.WriteString(single.Response)
		}
	}
	model := lastChunk.Model
	if model == "" {
		model = info.UpstreamModelName
	}
	created := toUnix(lastChunk.CreatedAt)
	usage := &dto.Usage{
		PromptTokens:     lastChunk.PromptEvalCount,
		CompletionTokens: lastChunk.EvalCount,
		TotalTokens:      lastChunk.PromptEvalCount + lastChunk.EvalCount,
	}
	finishReason := lastChunk.DoneReason
	if finishReason == "" {
		finishReason = "stop"
	}
	msg := dto.Message{Role: "assistant", Content: contentPtr(aggContent.String())}
	if rc := reasoningBuilder.String(); rc != "" {
		msg.ReasoningContent = rc
	}
	full := dto.OpenAITextResponse{
		Id:      common.GetUUID(),
		Model:   model,
		Object:  "chat.completion",
		Created: created,
		Choices: []dto.OpenAITextResponseChoice{{
			Index:        0,
			Message:      msg,
			FinishReason: finishReason,
		}},
		Usage: *usage,
	}
	out, _ := common.Marshal(full)
	service.IOCopyBytesGracefully(c, resp, out)
	return usage, nil
}
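
// contentPtr returns a pointer to s, or nil when s is empty so an absent
// content field can be distinguished from an empty one.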
func contentPtr(s string) *string {
	if s == "" {
		return nil
	}
	return &s
}
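
// A minimal sketch of how these handlers are typically wired into the relay
// adaptor's response path (hypothetical wiring; the actual adaptor interface
// and the IsStream flag may differ in this repo):
//
//	func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
//		if info.IsStream {
//			return ollamaStreamHandler(c, info, resp)
//		}
//		return ollamaChatHandler(c, info, resp)
//	}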