relay-ollama.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. package ollama
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/dto"
  12. relaycommon "github.com/QuantumNous/new-api/relay/common"
  13. "github.com/QuantumNous/new-api/service"
  14. "github.com/QuantumNous/new-api/types"
  15. "github.com/gin-gonic/gin"
  16. "github.com/samber/lo"
  17. )
  18. func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
  19. chatReq := &OllamaChatRequest{
  20. Model: r.Model,
  21. Stream: lo.FromPtrOr(r.Stream, false),
  22. Options: map[string]any{},
  23. Think: r.Think,
  24. }
  25. if r.ResponseFormat != nil {
  26. if r.ResponseFormat.Type == "json" {
  27. chatReq.Format = "json"
  28. } else if r.ResponseFormat.Type == "json_schema" {
  29. if len(r.ResponseFormat.JsonSchema) > 0 {
  30. var schema any
  31. _ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
  32. chatReq.Format = schema
  33. }
  34. }
  35. }
  36. // options mapping
  37. if r.Temperature != nil {
  38. chatReq.Options["temperature"] = r.Temperature
  39. }
  40. if r.TopP != nil {
  41. chatReq.Options["top_p"] = lo.FromPtr(r.TopP)
  42. }
  43. if r.TopK != nil {
  44. chatReq.Options["top_k"] = lo.FromPtr(r.TopK)
  45. }
  46. if r.FrequencyPenalty != nil {
  47. chatReq.Options["frequency_penalty"] = lo.FromPtr(r.FrequencyPenalty)
  48. }
  49. if r.PresencePenalty != nil {
  50. chatReq.Options["presence_penalty"] = lo.FromPtr(r.PresencePenalty)
  51. }
  52. if r.Seed != nil {
  53. chatReq.Options["seed"] = int(lo.FromPtr(r.Seed))
  54. }
  55. if mt := r.GetMaxTokens(); mt != 0 {
  56. chatReq.Options["num_predict"] = int(mt)
  57. }
  58. if r.Stop != nil {
  59. switch v := r.Stop.(type) {
  60. case string:
  61. chatReq.Options["stop"] = []string{v}
  62. case []string:
  63. chatReq.Options["stop"] = v
  64. case []any:
  65. arr := make([]string, 0, len(v))
  66. for _, i := range v {
  67. if s, ok := i.(string); ok {
  68. arr = append(arr, s)
  69. }
  70. }
  71. if len(arr) > 0 {
  72. chatReq.Options["stop"] = arr
  73. }
  74. }
  75. }
  76. if len(r.Tools) > 0 {
  77. tools := make([]OllamaTool, 0, len(r.Tools))
  78. for _, t := range r.Tools {
  79. tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
  80. }
  81. chatReq.Tools = tools
  82. }
  83. chatReq.Messages = make([]OllamaChatMessage, 0, len(r.Messages))
  84. for _, m := range r.Messages {
  85. var textBuilder strings.Builder
  86. var images []string
  87. if m.IsStringContent() {
  88. textBuilder.WriteString(m.StringContent())
  89. } else {
  90. parts := m.ParseContent()
  91. for _, part := range parts {
  92. if part.Type == dto.ContentTypeImageURL {
  93. source := part.ToFileSource()
  94. if source != nil {
  95. base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat")
  96. if err != nil {
  97. return nil, err
  98. }
  99. if base64Data != "" {
  100. images = append(images, base64Data)
  101. }
  102. }
  103. } else if part.Type == dto.ContentTypeText {
  104. textBuilder.WriteString(part.Text)
  105. }
  106. }
  107. }
  108. cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
  109. if len(images) > 0 {
  110. cm.Images = images
  111. }
  112. if m.Role == "tool" && m.Name != nil {
  113. cm.ToolName = *m.Name
  114. }
  115. if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
  116. parsed := m.ParseToolCalls()
  117. if len(parsed) > 0 {
  118. calls := make([]OllamaToolCall, 0, len(parsed))
  119. for _, tc := range parsed {
  120. var args interface{}
  121. if tc.Function.Arguments != "" {
  122. _ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
  123. }
  124. if args == nil {
  125. args = map[string]any{}
  126. }
  127. oc := OllamaToolCall{}
  128. oc.Function.Name = tc.Function.Name
  129. oc.Function.Arguments = args
  130. calls = append(calls, oc)
  131. }
  132. cm.ToolCalls = calls
  133. }
  134. }
  135. chatReq.Messages = append(chatReq.Messages, cm)
  136. }
  137. return chatReq, nil
  138. }
  139. // openAIToGenerate converts OpenAI completions request to Ollama generate
  140. func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
  141. gen := &OllamaGenerateRequest{
  142. Model: r.Model,
  143. Stream: lo.FromPtrOr(r.Stream, false),
  144. Options: map[string]any{},
  145. Think: r.Think,
  146. }
  147. // Prompt may be in r.Prompt (string or []any)
  148. if r.Prompt != nil {
  149. switch v := r.Prompt.(type) {
  150. case string:
  151. gen.Prompt = v
  152. case []any:
  153. var sb strings.Builder
  154. for _, it := range v {
  155. if s, ok := it.(string); ok {
  156. sb.WriteString(s)
  157. }
  158. }
  159. gen.Prompt = sb.String()
  160. default:
  161. gen.Prompt = fmt.Sprintf("%v", r.Prompt)
  162. }
  163. }
  164. if r.Suffix != nil {
  165. if s, ok := r.Suffix.(string); ok {
  166. gen.Suffix = s
  167. }
  168. }
  169. if r.ResponseFormat != nil {
  170. if r.ResponseFormat.Type == "json" {
  171. gen.Format = "json"
  172. } else if r.ResponseFormat.Type == "json_schema" {
  173. var schema any
  174. _ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
  175. gen.Format = schema
  176. }
  177. }
  178. if r.Temperature != nil {
  179. gen.Options["temperature"] = r.Temperature
  180. }
  181. if r.TopP != nil {
  182. gen.Options["top_p"] = lo.FromPtr(r.TopP)
  183. }
  184. if r.TopK != nil {
  185. gen.Options["top_k"] = lo.FromPtr(r.TopK)
  186. }
  187. if r.FrequencyPenalty != nil {
  188. gen.Options["frequency_penalty"] = lo.FromPtr(r.FrequencyPenalty)
  189. }
  190. if r.PresencePenalty != nil {
  191. gen.Options["presence_penalty"] = lo.FromPtr(r.PresencePenalty)
  192. }
  193. if r.Seed != nil {
  194. gen.Options["seed"] = int(lo.FromPtr(r.Seed))
  195. }
  196. if mt := r.GetMaxTokens(); mt != 0 {
  197. gen.Options["num_predict"] = int(mt)
  198. }
  199. if r.Stop != nil {
  200. switch v := r.Stop.(type) {
  201. case string:
  202. gen.Options["stop"] = []string{v}
  203. case []string:
  204. gen.Options["stop"] = v
  205. case []any:
  206. arr := make([]string, 0, len(v))
  207. for _, i := range v {
  208. if s, ok := i.(string); ok {
  209. arr = append(arr, s)
  210. }
  211. }
  212. if len(arr) > 0 {
  213. gen.Options["stop"] = arr
  214. }
  215. }
  216. }
  217. return gen, nil
  218. }
  219. func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
  220. opts := map[string]any{}
  221. if r.Temperature != nil {
  222. opts["temperature"] = r.Temperature
  223. }
  224. if r.TopP != nil {
  225. opts["top_p"] = lo.FromPtr(r.TopP)
  226. }
  227. if r.FrequencyPenalty != nil {
  228. opts["frequency_penalty"] = lo.FromPtr(r.FrequencyPenalty)
  229. }
  230. if r.PresencePenalty != nil {
  231. opts["presence_penalty"] = lo.FromPtr(r.PresencePenalty)
  232. }
  233. if r.Seed != nil {
  234. opts["seed"] = int(lo.FromPtr(r.Seed))
  235. }
  236. dimensions := lo.FromPtrOr(r.Dimensions, 0)
  237. if r.Dimensions != nil {
  238. opts["dimensions"] = dimensions
  239. }
  240. input := r.ParseInput()
  241. if len(input) == 1 {
  242. return &OllamaEmbeddingRequest{Model: r.Model, Input: input[0], Options: opts, Dimensions: dimensions}
  243. }
  244. return &OllamaEmbeddingRequest{Model: r.Model, Input: input, Options: opts, Dimensions: dimensions}
  245. }
  246. func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  247. var oResp OllamaEmbeddingResponse
  248. body, err := io.ReadAll(resp.Body)
  249. if err != nil {
  250. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  251. }
  252. service.CloseResponseBodyGracefully(resp)
  253. if err = common.Unmarshal(body, &oResp); err != nil {
  254. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  255. }
  256. if oResp.Error != "" {
  257. return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  258. }
  259. data := make([]dto.OpenAIEmbeddingResponseItem, 0, len(oResp.Embeddings))
  260. for i, emb := range oResp.Embeddings {
  261. data = append(data, dto.OpenAIEmbeddingResponseItem{Index: i, Object: "embedding", Embedding: emb})
  262. }
  263. usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens: 0, TotalTokens: oResp.PromptEvalCount}
  264. embResp := &dto.OpenAIEmbeddingResponse{Object: "list", Data: data, Model: info.UpstreamModelName, Usage: *usage}
  265. out, _ := common.Marshal(embResp)
  266. service.IOCopyBytesGracefully(c, resp, out)
  267. return usage, nil
  268. }
  269. func FetchOllamaModels(baseURL, apiKey string) ([]OllamaModel, error) {
  270. url := fmt.Sprintf("%s/api/tags", baseURL)
  271. client := &http.Client{}
  272. request, err := http.NewRequest("GET", url, nil)
  273. if err != nil {
  274. return nil, fmt.Errorf("创建请求失败: %v", err)
  275. }
  276. // Ollama 通常不需要 Bearer token,但为了兼容性保留
  277. if apiKey != "" {
  278. request.Header.Set("Authorization", "Bearer "+apiKey)
  279. }
  280. response, err := client.Do(request)
  281. if err != nil {
  282. return nil, fmt.Errorf("请求失败: %v", err)
  283. }
  284. defer response.Body.Close()
  285. if response.StatusCode != http.StatusOK {
  286. body, _ := io.ReadAll(response.Body)
  287. return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
  288. }
  289. var tagsResponse OllamaTagsResponse
  290. body, err := io.ReadAll(response.Body)
  291. if err != nil {
  292. return nil, fmt.Errorf("读取响应失败: %v", err)
  293. }
  294. err = common.Unmarshal(body, &tagsResponse)
  295. if err != nil {
  296. return nil, fmt.Errorf("解析响应失败: %v", err)
  297. }
  298. return tagsResponse.Models, nil
  299. }
  300. // 拉取 Ollama 模型 (非流式)
  301. func PullOllamaModel(baseURL, apiKey, modelName string) error {
  302. url := fmt.Sprintf("%s/api/pull", baseURL)
  303. pullRequest := OllamaPullRequest{
  304. Name: modelName,
  305. Stream: false, // 非流式,简化处理
  306. }
  307. requestBody, err := common.Marshal(pullRequest)
  308. if err != nil {
  309. return fmt.Errorf("序列化请求失败: %v", err)
  310. }
  311. client := &http.Client{
  312. Timeout: 30 * 60 * 1000 * time.Millisecond, // 30分钟超时,支持大模型
  313. }
  314. request, err := http.NewRequest("POST", url, strings.NewReader(string(requestBody)))
  315. if err != nil {
  316. return fmt.Errorf("创建请求失败: %v", err)
  317. }
  318. request.Header.Set("Content-Type", "application/json")
  319. if apiKey != "" {
  320. request.Header.Set("Authorization", "Bearer "+apiKey)
  321. }
  322. response, err := client.Do(request)
  323. if err != nil {
  324. return fmt.Errorf("请求失败: %v", err)
  325. }
  326. defer response.Body.Close()
  327. if response.StatusCode != http.StatusOK {
  328. body, _ := io.ReadAll(response.Body)
  329. return fmt.Errorf("拉取模型失败 %d: %s", response.StatusCode, string(body))
  330. }
  331. return nil
  332. }
  333. // 流式拉取 Ollama 模型 (支持进度回调)
  334. func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback func(OllamaPullResponse)) error {
  335. url := fmt.Sprintf("%s/api/pull", baseURL)
  336. pullRequest := OllamaPullRequest{
  337. Name: modelName,
  338. Stream: true, // 启用流式
  339. }
  340. requestBody, err := common.Marshal(pullRequest)
  341. if err != nil {
  342. return fmt.Errorf("序列化请求失败: %v", err)
  343. }
  344. client := &http.Client{
  345. Timeout: 60 * 60 * 1000 * time.Millisecond, // 1小时超时,支持超大模型
  346. }
  347. request, err := http.NewRequest("POST", url, strings.NewReader(string(requestBody)))
  348. if err != nil {
  349. return fmt.Errorf("创建请求失败: %v", err)
  350. }
  351. request.Header.Set("Content-Type", "application/json")
  352. if apiKey != "" {
  353. request.Header.Set("Authorization", "Bearer "+apiKey)
  354. }
  355. response, err := client.Do(request)
  356. if err != nil {
  357. return fmt.Errorf("请求失败: %v", err)
  358. }
  359. defer response.Body.Close()
  360. if response.StatusCode != http.StatusOK {
  361. body, _ := io.ReadAll(response.Body)
  362. return fmt.Errorf("拉取模型失败 %d: %s", response.StatusCode, string(body))
  363. }
  364. // 读取流式响应
  365. scanner := bufio.NewScanner(response.Body)
  366. successful := false
  367. for scanner.Scan() {
  368. line := scanner.Text()
  369. if strings.TrimSpace(line) == "" {
  370. continue
  371. }
  372. var pullResponse OllamaPullResponse
  373. if err := common.Unmarshal([]byte(line), &pullResponse); err != nil {
  374. continue // 忽略解析失败的行
  375. }
  376. if progressCallback != nil {
  377. progressCallback(pullResponse)
  378. }
  379. // 检查是否出现错误或完成
  380. if strings.EqualFold(pullResponse.Status, "error") {
  381. return fmt.Errorf("拉取模型失败: %s", strings.TrimSpace(line))
  382. }
  383. if strings.EqualFold(pullResponse.Status, "success") {
  384. successful = true
  385. break
  386. }
  387. }
  388. if err := scanner.Err(); err != nil {
  389. return fmt.Errorf("读取流式响应失败: %v", err)
  390. }
  391. if !successful {
  392. return fmt.Errorf("拉取模型未完成: 未收到成功状态")
  393. }
  394. return nil
  395. }
  396. // 删除 Ollama 模型
  397. func DeleteOllamaModel(baseURL, apiKey, modelName string) error {
  398. url := fmt.Sprintf("%s/api/delete", baseURL)
  399. deleteRequest := OllamaDeleteRequest{
  400. Name: modelName,
  401. }
  402. requestBody, err := common.Marshal(deleteRequest)
  403. if err != nil {
  404. return fmt.Errorf("序列化请求失败: %v", err)
  405. }
  406. client := &http.Client{}
  407. request, err := http.NewRequest("DELETE", url, strings.NewReader(string(requestBody)))
  408. if err != nil {
  409. return fmt.Errorf("创建请求失败: %v", err)
  410. }
  411. request.Header.Set("Content-Type", "application/json")
  412. if apiKey != "" {
  413. request.Header.Set("Authorization", "Bearer "+apiKey)
  414. }
  415. response, err := client.Do(request)
  416. if err != nil {
  417. return fmt.Errorf("请求失败: %v", err)
  418. }
  419. defer response.Body.Close()
  420. if response.StatusCode != http.StatusOK {
  421. body, _ := io.ReadAll(response.Body)
  422. return fmt.Errorf("删除模型失败 %d: %s", response.StatusCode, string(body))
  423. }
  424. return nil
  425. }
  // FetchOllamaVersion queries GET {baseURL}/api/version and returns the version
// string the server reports.
//
// Trailing slashes on baseURL are stripped, and an empty result is rejected. When
// apiKey is non-empty it is sent as a Bearer token. The request times out after
// 10 seconds. An error is returned for a non-200 status (including the response
// body), an undecodable body, or a response without a "version" value.
func FetchOllamaVersion(baseURL, apiKey string) (string, error) {
	base := strings.TrimRight(baseURL, "/")
	if len(base) == 0 {
		return "", fmt.Errorf("baseURL 为空")
	}

	req, err := http.NewRequest(http.MethodGet, base+"/api/version", nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %v", err)
	}
	if len(apiKey) > 0 {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}

	httpClient := &http.Client{Timeout: 10 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("请求失败: %v", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so error responses can be echoed.
	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return "", fmt.Errorf("读取响应失败: %v", readErr)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("查询版本失败 %d: %s", resp.StatusCode, string(payload))
	}

	var parsed struct {
		Version string `json:"version"`
	}
	if decodeErr := json.Unmarshal(payload, &parsed); decodeErr != nil {
		return "", fmt.Errorf("解析响应失败: %v", decodeErr)
	}
	if parsed.Version == "" {
		return "", fmt.Errorf("未返回版本信息")
	}
	return parsed.Version, nil
}