relay_info.go 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. package common
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "time"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/dto"
  11. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  12. "github.com/QuantumNous/new-api/setting/model_setting"
  13. "github.com/QuantumNous/new-api/types"
  14. "github.com/gin-gonic/gin"
  15. "github.com/gorilla/websocket"
  16. )
  17. type ThinkingContentInfo struct {
  18. IsFirstThinkingContent bool
  19. SendLastThinkingContent bool
  20. HasSentThinkingContent bool
  21. }
  22. const (
  23. LastMessageTypeNone = "none"
  24. LastMessageTypeText = "text"
  25. LastMessageTypeTools = "tools"
  26. LastMessageTypeThinking = "thinking"
  27. )
  28. type ClaudeConvertInfo struct {
  29. LastMessagesType string
  30. Index int
  31. Usage *dto.Usage
  32. FinishReason string
  33. Done bool
  34. ToolCallBaseIndex int
  35. ToolCallMaxIndexOffset int
  36. }
  37. type RerankerInfo struct {
  38. Documents []any
  39. ReturnDocuments bool
  40. }
  41. type BuildInToolInfo struct {
  42. ToolName string
  43. CallCount int
  44. SearchContextSize string
  45. }
  46. type ResponsesUsageInfo struct {
  47. BuiltInTools map[string]*BuildInToolInfo
  48. }
  49. type ChannelMeta struct {
  50. ChannelType int
  51. ChannelId int
  52. ChannelIsMultiKey bool
  53. ChannelMultiKeyIndex int
  54. ChannelBaseUrl string
  55. ApiType int
  56. ApiVersion string
  57. ApiKey string
  58. Organization string
  59. ChannelCreateTime int64
  60. ParamOverride map[string]interface{}
  61. HeadersOverride map[string]interface{}
  62. ChannelSetting dto.ChannelSettings
  63. ChannelOtherSettings dto.ChannelOtherSettings
  64. UpstreamModelName string
  65. IsModelMapped bool
  66. SupportStreamOptions bool // 是否支持流式选项
  67. }
  68. type TokenCountMeta struct {
  69. //promptTokens int
  70. estimatePromptTokens int
  71. }
  72. type RelayInfo struct {
  73. TokenId int
  74. TokenKey string
  75. TokenGroup string
  76. UserId int
  77. UsingGroup string // 使用的分组,当auto跨分组重试时,会变动
  78. UserGroup string // 用户所在分组
  79. TokenUnlimited bool
  80. StartTime time.Time
  81. FirstResponseTime time.Time
  82. isFirstResponse bool
  83. //SendLastReasoningResponse bool
  84. IsStream bool
  85. IsGeminiBatchEmbedding bool
  86. IsPlayground bool
  87. UsePrice bool
  88. RelayMode int
  89. OriginModelName string
  90. RequestURLPath string
  91. RequestHeaders map[string]string
  92. ShouldIncludeUsage bool
  93. DisablePing bool // 是否禁止向下游发送自定义 Ping
  94. ClientWs *websocket.Conn
  95. TargetWs *websocket.Conn
  96. InputAudioFormat string
  97. OutputAudioFormat string
  98. RealtimeTools []dto.RealTimeTool
  99. IsFirstRequest bool
  100. AudioUsage bool
  101. ReasoningEffort string
  102. UserSetting dto.UserSetting
  103. UserEmail string
  104. UserQuota int
  105. RelayFormat types.RelayFormat
  106. SendResponseCount int
  107. ReceivedResponseCount int
  108. FinalPreConsumedQuota int // 最终预消耗的配额
  109. // ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
  110. // 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
  111. // 必须在提交前锁定全额。
  112. ForcePreConsume bool
  113. // Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
  114. // 免费模型时为 nil。
  115. Billing BillingSettler
  116. // BillingSource indicates whether this request is billed from wallet quota or subscription.
  117. // "" or "wallet" => wallet; "subscription" => subscription
  118. BillingSource string
  119. // SubscriptionId is the user_subscriptions.id used when BillingSource == "subscription"
  120. SubscriptionId int
  121. // SubscriptionPreConsumed is the amount pre-consumed on subscription item (quota units or 1)
  122. SubscriptionPreConsumed int64
  123. // SubscriptionPostDelta is the post-consume delta applied to amount_used (quota units; can be negative).
  124. SubscriptionPostDelta int64
  125. // SubscriptionPlanId / SubscriptionPlanTitle are used for logging/UI display.
  126. SubscriptionPlanId int
  127. SubscriptionPlanTitle string
  128. // RequestId is used for idempotent pre-consume/refund
  129. RequestId string
  130. // SubscriptionAmountTotal / SubscriptionAmountUsedAfterPreConsume are used to compute remaining in logs.
  131. SubscriptionAmountTotal int64
  132. SubscriptionAmountUsedAfterPreConsume int64
  133. IsClaudeBetaQuery bool // /v1/messages?beta=true
  134. IsChannelTest bool // channel test request
  135. RetryIndex int
  136. LastError *types.NewAPIError
  137. RuntimeHeadersOverride map[string]interface{}
  138. RuntimeHeadersDeletedNormalized map[string]bool
  139. UseRuntimeHeadersOverride bool
  140. PriceData types.PriceData
  141. Request dto.Request
  142. // RequestConversionChain records request format conversions in order, e.g.
  143. // ["openai", "openai_responses"] or ["openai", "claude"].
  144. RequestConversionChain []types.RelayFormat
  145. // 最终请求到上游的格式。可由 adaptor 显式设置;
  146. // 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
  147. FinalRequestRelayFormat types.RelayFormat
  148. ThinkingContentInfo
  149. TokenCountMeta
  150. *ClaudeConvertInfo
  151. *RerankerInfo
  152. *ResponsesUsageInfo
  153. *ChannelMeta
  154. *TaskRelayInfo
  155. }
  156. func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
  157. channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  158. paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  159. headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
  160. apiType, _ := common.ChannelType2APIType(channelType)
  161. channelMeta := &ChannelMeta{
  162. ChannelType: channelType,
  163. ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
  164. ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
  165. ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
  166. ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
  167. ApiType: apiType,
  168. ApiVersion: c.GetString("api_version"),
  169. ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
  170. Organization: c.GetString("channel_organization"),
  171. ChannelCreateTime: c.GetInt64("channel_create_time"),
  172. ParamOverride: paramOverride,
  173. HeadersOverride: headerOverride,
  174. UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  175. IsModelMapped: false,
  176. SupportStreamOptions: false,
  177. }
  178. if channelType == constant.ChannelTypeAzure {
  179. channelMeta.ApiVersion = GetAPIVersion(c)
  180. }
  181. if channelType == constant.ChannelTypeVertexAi {
  182. channelMeta.ApiVersion = c.GetString("region")
  183. }
  184. channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
  185. if ok {
  186. channelMeta.ChannelSetting = channelSetting
  187. }
  188. channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
  189. if ok {
  190. channelMeta.ChannelOtherSettings = channelOtherSettings
  191. }
  192. if streamSupportedChannels[channelMeta.ChannelType] {
  193. channelMeta.SupportStreamOptions = true
  194. }
  195. info.ChannelMeta = channelMeta
  196. // reset some fields based on channel meta
  197. // 重置某些字段,例如模型名称等
  198. if info.Request != nil {
  199. info.Request.SetModelName(info.OriginModelName)
  200. }
  201. }
  202. func (info *RelayInfo) ToString() string {
  203. if info == nil {
  204. return "RelayInfo<nil>"
  205. }
  206. // Basic info
  207. b := &strings.Builder{}
  208. fmt.Fprintf(b, "RelayInfo{ ")
  209. fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat)
  210. fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode)
  211. fmt.Fprintf(b, "IsStream: %t, ", info.IsStream)
  212. fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
  213. fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
  214. fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
  215. fmt.Fprintf(b, "EstimatePromptTokens: %d, ", info.estimatePromptTokens)
  216. fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
  217. fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
  218. fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
  219. fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota)
  220. // User & token info (mask secrets)
  221. fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ",
  222. info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota)
  223. fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited)
  224. // Time info
  225. latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
  226. fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ",
  227. info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs)
  228. // Audio / realtime
  229. if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage {
  230. fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ",
  231. info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools))
  232. }
  233. // Reasoning
  234. if info.ReasoningEffort != "" {
  235. fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort)
  236. }
  237. // Price data (non-sensitive)
  238. if info.PriceData.UsePrice {
  239. fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting())
  240. }
  241. // Channel metadata (mask ApiKey)
  242. if info.ChannelMeta != nil {
  243. cm := info.ChannelMeta
  244. fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ",
  245. cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions)
  246. }
  247. // Responses usage info (non-sensitive)
  248. if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 {
  249. fmt.Fprintf(b, "ResponsesTools{ ")
  250. first := true
  251. for name, tool := range info.ResponsesUsageInfo.BuiltInTools {
  252. if !first {
  253. fmt.Fprintf(b, ", ")
  254. }
  255. first = false
  256. if tool != nil {
  257. fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount)
  258. } else {
  259. fmt.Fprintf(b, "%s: calls=0", name)
  260. }
  261. }
  262. fmt.Fprintf(b, " }, ")
  263. }
  264. fmt.Fprintf(b, "}")
  265. return b.String()
  266. }
  267. // 定义支持流式选项的通道类型
  268. var streamSupportedChannels = map[int]bool{
  269. constant.ChannelTypeOpenAI: true,
  270. constant.ChannelTypeAnthropic: true,
  271. constant.ChannelTypeAws: true,
  272. constant.ChannelTypeGemini: true,
  273. constant.ChannelCloudflare: true,
  274. constant.ChannelTypeAzure: true,
  275. constant.ChannelTypeVolcEngine: true,
  276. constant.ChannelTypeOllama: true,
  277. constant.ChannelTypeXai: true,
  278. constant.ChannelTypeDeepSeek: true,
  279. constant.ChannelTypeBaiduV2: true,
  280. constant.ChannelTypeZhipu_v4: true,
  281. constant.ChannelTypeAli: true,
  282. constant.ChannelTypeSubmodel: true,
  283. constant.ChannelTypeCodex: true,
  284. constant.ChannelTypeMoonshot: true,
  285. constant.ChannelTypeMiniMax: true,
  286. constant.ChannelTypeSiliconFlow: true,
  287. }
  288. func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
  289. info := genBaseRelayInfo(c, nil)
  290. info.RelayFormat = types.RelayFormatOpenAIRealtime
  291. info.ClientWs = ws
  292. info.InputAudioFormat = "pcm16"
  293. info.OutputAudioFormat = "pcm16"
  294. info.IsFirstRequest = true
  295. return info
  296. }
  297. func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
  298. info := genBaseRelayInfo(c, request)
  299. info.RelayFormat = types.RelayFormatClaude
  300. info.ShouldIncludeUsage = false
  301. info.ClaudeConvertInfo = &ClaudeConvertInfo{
  302. LastMessagesType: LastMessageTypeNone,
  303. }
  304. info.IsClaudeBetaQuery = c.Query("beta") == "true" || isClaudeBetaForced(c)
  305. return info
  306. }
  307. func isClaudeBetaForced(c *gin.Context) bool {
  308. channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
  309. return ok && channelOtherSettings.ClaudeBetaQuery
  310. }
  311. func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
  312. info := genBaseRelayInfo(c, request)
  313. info.RelayMode = relayconstant.RelayModeRerank
  314. info.RelayFormat = types.RelayFormatRerank
  315. info.RerankerInfo = &RerankerInfo{
  316. Documents: request.Documents,
  317. ReturnDocuments: request.GetReturnDocuments(),
  318. }
  319. return info
  320. }
  321. func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
  322. info := genBaseRelayInfo(c, request)
  323. info.RelayFormat = types.RelayFormatOpenAIAudio
  324. return info
  325. }
  326. func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
  327. info := genBaseRelayInfo(c, request)
  328. info.RelayFormat = types.RelayFormatEmbedding
  329. return info
  330. }
  331. func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
  332. info := genBaseRelayInfo(c, request)
  333. info.RelayMode = relayconstant.RelayModeResponses
  334. info.RelayFormat = types.RelayFormatOpenAIResponses
  335. info.ResponsesUsageInfo = &ResponsesUsageInfo{
  336. BuiltInTools: make(map[string]*BuildInToolInfo),
  337. }
  338. if len(request.Tools) > 0 {
  339. for _, tool := range request.GetToolsMap() {
  340. toolType := common.Interface2String(tool["type"])
  341. info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
  342. ToolName: toolType,
  343. CallCount: 0,
  344. }
  345. switch toolType {
  346. case dto.BuildInToolWebSearchPreview:
  347. searchContextSize := common.Interface2String(tool["search_context_size"])
  348. if searchContextSize == "" {
  349. searchContextSize = "medium"
  350. }
  351. info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
  352. }
  353. }
  354. }
  355. return info
  356. }
  357. func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
  358. info := genBaseRelayInfo(c, request)
  359. info.RelayFormat = types.RelayFormatGemini
  360. info.ShouldIncludeUsage = false
  361. return info
  362. }
  363. func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
  364. info := genBaseRelayInfo(c, request)
  365. info.RelayFormat = types.RelayFormatOpenAIImage
  366. return info
  367. }
  368. func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
  369. info := genBaseRelayInfo(c, request)
  370. info.RelayFormat = types.RelayFormatOpenAI
  371. return info
  372. }
  373. func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
  374. //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  375. //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
  376. //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  377. tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
  378. // 当令牌分组为空时,表示使用用户分组
  379. if tokenGroup == "" {
  380. tokenGroup = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
  381. }
  382. startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
  383. if startTime.IsZero() {
  384. startTime = time.Now()
  385. }
  386. isStream := false
  387. if request != nil {
  388. isStream = request.IsStream(c)
  389. }
  390. // firstResponseTime = time.Now() - 1 second
  391. reqId := common.GetContextKeyString(c, common.RequestIdKey)
  392. if reqId == "" {
  393. reqId = common.GetTimeString() + common.GetRandomString(8)
  394. }
  395. info := &RelayInfo{
  396. Request: request,
  397. RequestId: reqId,
  398. UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
  399. UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
  400. UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
  401. UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
  402. UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
  403. OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  404. TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
  405. TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
  406. TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
  407. TokenGroup: tokenGroup,
  408. isFirstResponse: true,
  409. RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
  410. RequestURLPath: c.Request.URL.String(),
  411. RequestHeaders: cloneRequestHeaders(c),
  412. IsStream: isStream,
  413. StartTime: startTime,
  414. FirstResponseTime: startTime.Add(-time.Second),
  415. ThinkingContentInfo: ThinkingContentInfo{
  416. IsFirstThinkingContent: true,
  417. SendLastThinkingContent: false,
  418. },
  419. TokenCountMeta: TokenCountMeta{
  420. //promptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
  421. estimatePromptTokens: common.GetContextKeyInt(c, constant.ContextKeyEstimatedTokens),
  422. },
  423. }
  424. if info.RelayMode == relayconstant.RelayModeUnknown {
  425. info.RelayMode = c.GetInt("relay_mode")
  426. }
  427. if strings.HasPrefix(c.Request.URL.Path, "/pg") {
  428. info.IsPlayground = true
  429. info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
  430. info.RequestURLPath = "/v1" + info.RequestURLPath
  431. }
  432. userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
  433. if ok {
  434. info.UserSetting = userSetting
  435. }
  436. return info
  437. }
  438. func cloneRequestHeaders(c *gin.Context) map[string]string {
  439. if c == nil || c.Request == nil {
  440. return nil
  441. }
  442. if len(c.Request.Header) == 0 {
  443. return nil
  444. }
  445. headers := make(map[string]string, len(c.Request.Header))
  446. for key := range c.Request.Header {
  447. value := strings.TrimSpace(c.Request.Header.Get(key))
  448. if value == "" {
  449. continue
  450. }
  451. headers[key] = value
  452. }
  453. if len(headers) == 0 {
  454. return nil
  455. }
  456. return headers
  457. }
  458. func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
  459. var info *RelayInfo
  460. var err error
  461. switch relayFormat {
  462. case types.RelayFormatOpenAI:
  463. info = GenRelayInfoOpenAI(c, request)
  464. case types.RelayFormatOpenAIAudio:
  465. info = GenRelayInfoOpenAIAudio(c, request)
  466. case types.RelayFormatOpenAIImage:
  467. info = GenRelayInfoImage(c, request)
  468. case types.RelayFormatOpenAIRealtime:
  469. info = GenRelayInfoWs(c, ws)
  470. case types.RelayFormatClaude:
  471. info = GenRelayInfoClaude(c, request)
  472. case types.RelayFormatRerank:
  473. if request, ok := request.(*dto.RerankRequest); ok {
  474. info = GenRelayInfoRerank(c, request)
  475. break
  476. }
  477. err = errors.New("request is not a RerankRequest")
  478. case types.RelayFormatGemini:
  479. info = GenRelayInfoGemini(c, request)
  480. case types.RelayFormatEmbedding:
  481. info = GenRelayInfoEmbedding(c, request)
  482. case types.RelayFormatOpenAIResponses:
  483. if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
  484. info = GenRelayInfoResponses(c, request)
  485. break
  486. }
  487. err = errors.New("request is not a OpenAIResponsesRequest")
  488. case types.RelayFormatOpenAIResponsesCompaction:
  489. if request, ok := request.(*dto.OpenAIResponsesCompactionRequest); ok {
  490. return GenRelayInfoResponsesCompaction(c, request), nil
  491. }
  492. return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
  493. case types.RelayFormatTask:
  494. info = genBaseRelayInfo(c, nil)
  495. info.TaskRelayInfo = &TaskRelayInfo{}
  496. case types.RelayFormatMjProxy:
  497. info = genBaseRelayInfo(c, nil)
  498. info.TaskRelayInfo = &TaskRelayInfo{}
  499. default:
  500. err = errors.New("invalid relay format")
  501. }
  502. if err != nil {
  503. return nil, err
  504. }
  505. if info == nil {
  506. return nil, errors.New("failed to build relay info")
  507. }
  508. info.InitRequestConversionChain()
  509. return info, nil
  510. }
  511. func (info *RelayInfo) InitRequestConversionChain() {
  512. if info == nil {
  513. return
  514. }
  515. if len(info.RequestConversionChain) > 0 {
  516. return
  517. }
  518. if info.RelayFormat == "" {
  519. return
  520. }
  521. info.RequestConversionChain = []types.RelayFormat{info.RelayFormat}
  522. }
  523. func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
  524. if info == nil {
  525. return
  526. }
  527. if format == "" {
  528. return
  529. }
  530. if len(info.RequestConversionChain) == 0 {
  531. info.RequestConversionChain = []types.RelayFormat{format}
  532. return
  533. }
  534. last := info.RequestConversionChain[len(info.RequestConversionChain)-1]
  535. if last == format {
  536. return
  537. }
  538. info.RequestConversionChain = append(info.RequestConversionChain, format)
  539. }
  540. func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
  541. if info == nil {
  542. return ""
  543. }
  544. if info.FinalRequestRelayFormat != "" {
  545. return info.FinalRequestRelayFormat
  546. }
  547. if n := len(info.RequestConversionChain); n > 0 {
  548. return info.RequestConversionChain[n-1]
  549. }
  550. return info.RelayFormat
  551. }
  552. func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
  553. info := genBaseRelayInfo(c, request)
  554. if info.RelayMode == relayconstant.RelayModeUnknown {
  555. info.RelayMode = relayconstant.RelayModeResponsesCompact
  556. }
  557. info.RelayFormat = types.RelayFormatOpenAIResponsesCompaction
  558. return info
  559. }
  560. //func (info *RelayInfo) SetPromptTokens(promptTokens int) {
  561. // info.promptTokens = promptTokens
  562. //}
  563. func (info *RelayInfo) SetEstimatePromptTokens(promptTokens int) {
  564. info.estimatePromptTokens = promptTokens
  565. }
  566. func (info *RelayInfo) GetEstimatePromptTokens() int {
  567. return info.estimatePromptTokens
  568. }
  569. func (info *RelayInfo) SetFirstResponseTime() {
  570. if info.isFirstResponse {
  571. info.FirstResponseTime = time.Now()
  572. info.isFirstResponse = false
  573. }
  574. }
  575. func (info *RelayInfo) HasSendResponse() bool {
  576. return info.FirstResponseTime.After(info.StartTime)
  577. }
  578. type TaskRelayInfo struct {
  579. Action string
  580. OriginTaskID string
  581. // PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
  582. // 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
  583. PublicTaskID string
  584. ConsumeQuota bool
  585. // LockedChannel holds the full channel object when the request is bound to
  586. // a specific channel (e.g., remix on origin task's channel). Stored as any
  587. // to avoid an import cycle with model; callers type-assert to *model.Channel.
  588. LockedChannel any
  589. }
  590. type TaskSubmitReq struct {
  591. Prompt string `json:"prompt"`
  592. Model string `json:"model,omitempty"`
  593. Mode string `json:"mode,omitempty"`
  594. Image string `json:"image,omitempty"`
  595. Images []string `json:"images,omitempty"`
  596. Size string `json:"size,omitempty"`
  597. Duration int `json:"duration,omitempty"`
  598. Seconds string `json:"seconds,omitempty"`
  599. InputReference string `json:"input_reference,omitempty"`
  600. Metadata map[string]interface{} `json:"metadata,omitempty"`
  601. }
  602. func (t *TaskSubmitReq) GetPrompt() string {
  603. return t.Prompt
  604. }
  605. func (t *TaskSubmitReq) HasImage() bool {
  606. return len(t.Images) > 0
  607. }
  608. func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
  609. type Alias TaskSubmitReq
  610. aux := &struct {
  611. Metadata json.RawMessage `json:"metadata,omitempty"`
  612. *Alias
  613. }{
  614. Alias: (*Alias)(t),
  615. }
  616. if err := common.Unmarshal(data, &aux); err != nil {
  617. return err
  618. }
  619. if len(aux.Metadata) > 0 {
  620. var metadataStr string
  621. if err := common.Unmarshal(aux.Metadata, &metadataStr); err == nil && metadataStr != "" {
  622. var metadataObj map[string]interface{}
  623. if err := common.Unmarshal([]byte(metadataStr), &metadataObj); err == nil {
  624. t.Metadata = metadataObj
  625. return nil
  626. }
  627. }
  628. var metadataObj map[string]interface{}
  629. if err := common.Unmarshal(aux.Metadata, &metadataObj); err == nil {
  630. t.Metadata = metadataObj
  631. }
  632. }
  633. return nil
  634. }
  635. func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
  636. metadata := t.Metadata
  637. if metadata != nil {
  638. metadataBytes, err := common.Marshal(metadata)
  639. if err != nil {
  640. return fmt.Errorf("marshal metadata failed: %w", err)
  641. }
  642. err = common.Unmarshal(metadataBytes, v)
  643. if err != nil {
  644. return fmt.Errorf("unmarshal metadata to target failed: %w", err)
  645. }
  646. }
  647. return nil
  648. }
  649. type TaskInfo struct {
  650. Code int `json:"code"`
  651. TaskID string `json:"task_id"`
  652. Status string `json:"status"`
  653. Reason string `json:"reason,omitempty"`
  654. Url string `json:"url,omitempty"`
  655. RemoteUrl string `json:"remote_url,omitempty"`
  656. Progress string `json:"progress,omitempty"`
  657. CompletionTokens int `json:"completion_tokens,omitempty"` // 用于按倍率计费
  658. TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费
  659. }
  660. func FailTaskInfo(reason string) *TaskInfo {
  661. return &TaskInfo{
  662. Status: "FAILURE",
  663. Reason: reason,
  664. }
  665. }
  666. // RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
  667. // service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
  668. // inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
  669. // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
  670. // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
  671. // stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
  672. func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
  673. if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
  674. return jsonData, nil
  675. }
  676. var data map[string]interface{}
  677. if err := common.Unmarshal(jsonData, &data); err != nil {
  678. common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
  679. return jsonData, nil
  680. }
  681. // 默认移除 service_tier,除非明确允许(避免额外计费风险)
  682. if !channelOtherSettings.AllowServiceTier {
  683. if _, exists := data["service_tier"]; exists {
  684. delete(data, "service_tier")
  685. }
  686. }
  687. // 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域)
  688. if !channelOtherSettings.AllowInferenceGeo {
  689. if _, exists := data["inference_geo"]; exists {
  690. delete(data, "inference_geo")
  691. }
  692. }
  693. // 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
  694. if channelOtherSettings.DisableStore {
  695. if _, exists := data["store"]; exists {
  696. delete(data, "store")
  697. }
  698. }
  699. // 默认移除 safety_identifier,除非明确允许(保护用户隐私,避免向 OpenAI 报告用户信息)
  700. if !channelOtherSettings.AllowSafetyIdentifier {
  701. if _, exists := data["safety_identifier"]; exists {
  702. delete(data, "safety_identifier")
  703. }
  704. }
  705. // 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护)
  706. if !channelOtherSettings.AllowIncludeObfuscation {
  707. if streamOptionsAny, exists := data["stream_options"]; exists {
  708. if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
  709. if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
  710. delete(streamOptions, "include_obfuscation")
  711. }
  712. if len(streamOptions) == 0 {
  713. delete(data, "stream_options")
  714. } else {
  715. data["stream_options"] = streamOptions
  716. }
  717. }
  718. }
  719. }
  720. jsonDataAfter, err := common.Marshal(data)
  721. if err != nil {
  722. common.SysError("RemoveDisabledFields Marshal error :" + err.Error())
  723. return jsonData, nil
  724. }
  725. return jsonDataAfter, nil
  726. }
  727. // RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
  728. // Currently supports removing functionResponse.id field which Vertex AI does not support
  729. func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
  730. if !model_setting.GetGeminiSettings().RemoveFunctionResponseIdEnabled {
  731. return jsonData, nil
  732. }
  733. var data map[string]interface{}
  734. if err := common.Unmarshal(jsonData, &data); err != nil {
  735. common.SysError("RemoveGeminiDisabledFields Unmarshal error: " + err.Error())
  736. return jsonData, nil
  737. }
  738. // Process contents array
  739. // Handle both camelCase (functionResponse) and snake_case (function_response)
  740. if contents, ok := data["contents"].([]interface{}); ok {
  741. for _, content := range contents {
  742. if contentMap, ok := content.(map[string]interface{}); ok {
  743. if parts, ok := contentMap["parts"].([]interface{}); ok {
  744. for _, part := range parts {
  745. if partMap, ok := part.(map[string]interface{}); ok {
  746. // Check functionResponse (camelCase)
  747. if funcResp, ok := partMap["functionResponse"].(map[string]interface{}); ok {
  748. delete(funcResp, "id")
  749. }
  750. // Check function_response (snake_case)
  751. if funcResp, ok := partMap["function_response"].(map[string]interface{}); ok {
  752. delete(funcResp, "id")
  753. }
  754. }
  755. }
  756. }
  757. }
  758. }
  759. }
  760. jsonDataAfter, err := common.Marshal(data)
  761. if err != nil {
  762. common.SysError("RemoveGeminiDisabledFields Marshal error: " + err.Error())
  763. return jsonData, nil
  764. }
  765. return jsonDataAfter, nil
  766. }