relay_info.go 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871
  1. package common
import (
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/pkg/billingexpr"
	relayconstant "github.com/QuantumNous/new-api/relay/constant"
	"github.com/QuantumNous/new-api/setting/model_setting"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)
  18. type ThinkingContentInfo struct {
  19. IsFirstThinkingContent bool
  20. SendLastThinkingContent bool
  21. HasSentThinkingContent bool
  22. }
  23. const (
  24. LastMessageTypeNone = "none"
  25. LastMessageTypeText = "text"
  26. LastMessageTypeTools = "tools"
  27. LastMessageTypeThinking = "thinking"
  28. )
  29. type ClaudeConvertInfo struct {
  30. LastMessagesType string
  31. Index int
  32. Usage *dto.Usage
  33. FinishReason string
  34. Done bool
  35. ToolCallBaseIndex int
  36. ToolCallMaxIndexOffset int
  37. }
  38. type RerankerInfo struct {
  39. Documents []any
  40. ReturnDocuments bool
  41. }
  42. type BuildInToolInfo struct {
  43. ToolName string
  44. CallCount int
  45. SearchContextSize string
  46. }
  47. type ResponsesUsageInfo struct {
  48. BuiltInTools map[string]*BuildInToolInfo
  49. }
  50. type ChannelMeta struct {
  51. ChannelType int
  52. ChannelId int
  53. ChannelIsMultiKey bool
  54. ChannelMultiKeyIndex int
  55. ChannelBaseUrl string
  56. ApiType int
  57. ApiVersion string
  58. ApiKey string
  59. Organization string
  60. ChannelCreateTime int64
  61. ParamOverride map[string]interface{}
  62. HeadersOverride map[string]interface{}
  63. ChannelSetting dto.ChannelSettings
  64. ChannelOtherSettings dto.ChannelOtherSettings
  65. UpstreamModelName string
  66. IsModelMapped bool
  67. SupportStreamOptions bool // 是否支持流式选项
  68. }
  69. type TokenCountMeta struct {
  70. //promptTokens int
  71. estimatePromptTokens int
  72. }
  73. type RelayInfo struct {
  74. TokenId int
  75. TokenKey string
  76. TokenGroup string
  77. UserId int
  78. UsingGroup string // 使用的分组,当auto跨分组重试时,会变动
  79. UserGroup string // 用户所在分组
  80. TokenUnlimited bool
  81. StartTime time.Time
  82. FirstResponseTime time.Time
  83. isFirstResponse bool
  84. //SendLastReasoningResponse bool
  85. IsStream bool
  86. IsGeminiBatchEmbedding bool
  87. IsPlayground bool
  88. UsePrice bool
  89. RelayMode int
  90. OriginModelName string
  91. RequestURLPath string
  92. RequestHeaders map[string]string
  93. ShouldIncludeUsage bool
  94. DisablePing bool // 是否禁止向下游发送自定义 Ping
  95. ClientWs *websocket.Conn
  96. TargetWs *websocket.Conn
  97. InputAudioFormat string
  98. OutputAudioFormat string
  99. RealtimeTools []dto.RealTimeTool
  100. IsFirstRequest bool
  101. AudioUsage bool
  102. ReasoningEffort string
  103. UserSetting dto.UserSetting
  104. UserEmail string
  105. UserQuota int
  106. RelayFormat types.RelayFormat
  107. SendResponseCount int
  108. ReceivedResponseCount int
  109. FinalPreConsumedQuota int // 最终预消耗的配额
  110. // ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
  111. // 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
  112. // 必须在提交前锁定全额。
  113. ForcePreConsume bool
  114. // Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
  115. // 免费模型时为 nil。
  116. Billing BillingSettler
  117. // BillingSource indicates whether this request is billed from wallet quota or subscription.
  118. // "" or "wallet" => wallet; "subscription" => subscription
  119. BillingSource string
  120. // SubscriptionId is the user_subscriptions.id used when BillingSource == "subscription"
  121. SubscriptionId int
  122. // SubscriptionPreConsumed is the amount pre-consumed on subscription item (quota units or 1)
  123. SubscriptionPreConsumed int64
  124. // SubscriptionPostDelta is the post-consume delta applied to amount_used (quota units; can be negative).
  125. SubscriptionPostDelta int64
  126. // SubscriptionPlanId / SubscriptionPlanTitle are used for logging/UI display.
  127. SubscriptionPlanId int
  128. SubscriptionPlanTitle string
  129. // RequestId is used for idempotent pre-consume/refund
  130. RequestId string
  131. // SubscriptionAmountTotal / SubscriptionAmountUsedAfterPreConsume are used to compute remaining in logs.
  132. SubscriptionAmountTotal int64
  133. SubscriptionAmountUsedAfterPreConsume int64
  134. IsClaudeBetaQuery bool // /v1/messages?beta=true
  135. IsChannelTest bool // channel test request
  136. RetryIndex int
  137. LastError *types.NewAPIError
  138. RuntimeHeadersOverride map[string]interface{}
  139. UseRuntimeHeadersOverride bool
  140. ParamOverrideAudit []string
  141. PriceData types.PriceData
  142. // TieredBillingSnapshot is a frozen snapshot of tiered billing rules
  143. // captured at pre-consume time. Non-nil only when billing mode is "tiered_expr".
  144. TieredBillingSnapshot *billingexpr.BillingSnapshot
  145. BillingRequestInput *billingexpr.RequestInput
  146. Request dto.Request
  147. // RequestConversionChain records request format conversions in order, e.g.
  148. // ["openai", "openai_responses"] or ["openai", "claude"].
  149. RequestConversionChain []types.RelayFormat
  150. // 最终请求到上游的格式。可由 adaptor 显式设置;
  151. // 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
  152. FinalRequestRelayFormat types.RelayFormat
  153. StreamStatus *StreamStatus
  154. ThinkingContentInfo
  155. TokenCountMeta
  156. *ClaudeConvertInfo
  157. *RerankerInfo
  158. *ResponsesUsageInfo
  159. *ChannelMeta
  160. *TaskRelayInfo
  161. }
  162. func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
  163. channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  164. paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  165. headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
  166. apiType, _ := common.ChannelType2APIType(channelType)
  167. channelMeta := &ChannelMeta{
  168. ChannelType: channelType,
  169. ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
  170. ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
  171. ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
  172. ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
  173. ApiType: apiType,
  174. ApiVersion: c.GetString("api_version"),
  175. ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
  176. Organization: c.GetString("channel_organization"),
  177. ChannelCreateTime: c.GetInt64("channel_create_time"),
  178. ParamOverride: paramOverride,
  179. HeadersOverride: headerOverride,
  180. UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  181. IsModelMapped: false,
  182. SupportStreamOptions: false,
  183. }
  184. if channelType == constant.ChannelTypeAzure {
  185. channelMeta.ApiVersion = GetAPIVersion(c)
  186. }
  187. if channelType == constant.ChannelTypeVertexAi {
  188. channelMeta.ApiVersion = c.GetString("region")
  189. }
  190. channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
  191. if ok {
  192. channelMeta.ChannelSetting = channelSetting
  193. }
  194. channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
  195. if ok {
  196. channelMeta.ChannelOtherSettings = channelOtherSettings
  197. }
  198. if streamSupportedChannels[channelMeta.ChannelType] {
  199. channelMeta.SupportStreamOptions = true
  200. }
  201. info.ChannelMeta = channelMeta
  202. // reset some fields based on channel meta
  203. // 重置某些字段,例如模型名称等
  204. if info.Request != nil {
  205. info.Request.SetModelName(info.OriginModelName)
  206. }
  207. }
  208. func (info *RelayInfo) ToString() string {
  209. if info == nil {
  210. return "RelayInfo<nil>"
  211. }
  212. // Basic info
  213. b := &strings.Builder{}
  214. fmt.Fprintf(b, "RelayInfo{ ")
  215. fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat)
  216. fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode)
  217. fmt.Fprintf(b, "IsStream: %t, ", info.IsStream)
  218. fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
  219. fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
  220. fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
  221. fmt.Fprintf(b, "EstimatePromptTokens: %d, ", info.estimatePromptTokens)
  222. fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
  223. fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
  224. fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
  225. fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota)
  226. // User & token info (mask secrets)
  227. fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ",
  228. info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota)
  229. fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited)
  230. // Time info
  231. latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
  232. fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ",
  233. info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs)
  234. // Audio / realtime
  235. if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage {
  236. fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ",
  237. info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools))
  238. }
  239. // Reasoning
  240. if info.ReasoningEffort != "" {
  241. fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort)
  242. }
  243. // Price data (non-sensitive)
  244. if info.PriceData.UsePrice {
  245. fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting())
  246. }
  247. // Channel metadata (mask ApiKey)
  248. if info.ChannelMeta != nil {
  249. cm := info.ChannelMeta
  250. fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ",
  251. cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions)
  252. }
  253. // Responses usage info (non-sensitive)
  254. if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 {
  255. fmt.Fprintf(b, "ResponsesTools{ ")
  256. first := true
  257. for name, tool := range info.ResponsesUsageInfo.BuiltInTools {
  258. if !first {
  259. fmt.Fprintf(b, ", ")
  260. }
  261. first = false
  262. if tool != nil {
  263. fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount)
  264. } else {
  265. fmt.Fprintf(b, "%s: calls=0", name)
  266. }
  267. }
  268. fmt.Fprintf(b, " }, ")
  269. }
  270. fmt.Fprintf(b, "}")
  271. return b.String()
  272. }
  273. // 定义支持流式选项的通道类型
  274. var streamSupportedChannels = map[int]bool{
  275. constant.ChannelTypeOpenAI: true,
  276. constant.ChannelTypeAnthropic: true,
  277. constant.ChannelTypeAws: true,
  278. constant.ChannelTypeGemini: true,
  279. constant.ChannelCloudflare: true,
  280. constant.ChannelTypeAzure: true,
  281. constant.ChannelTypeVolcEngine: true,
  282. constant.ChannelTypeOllama: true,
  283. constant.ChannelTypeXai: true,
  284. constant.ChannelTypeDeepSeek: true,
  285. constant.ChannelTypeBaiduV2: true,
  286. constant.ChannelTypeZhipu_v4: true,
  287. constant.ChannelTypeAli: true,
  288. constant.ChannelTypeSubmodel: true,
  289. constant.ChannelTypeCodex: true,
  290. constant.ChannelTypeMoonshot: true,
  291. constant.ChannelTypeMiniMax: true,
  292. constant.ChannelTypeSiliconFlow: true,
  293. }
  294. func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
  295. info := genBaseRelayInfo(c, nil)
  296. info.RelayFormat = types.RelayFormatOpenAIRealtime
  297. info.ClientWs = ws
  298. info.InputAudioFormat = "pcm16"
  299. info.OutputAudioFormat = "pcm16"
  300. info.IsFirstRequest = true
  301. return info
  302. }
  303. func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
  304. info := genBaseRelayInfo(c, request)
  305. info.RelayFormat = types.RelayFormatClaude
  306. info.ShouldIncludeUsage = false
  307. info.ClaudeConvertInfo = &ClaudeConvertInfo{
  308. LastMessagesType: LastMessageTypeNone,
  309. }
  310. info.IsClaudeBetaQuery = c.Query("beta") == "true"
  311. return info
  312. }
  313. func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
  314. info := genBaseRelayInfo(c, request)
  315. info.RelayMode = relayconstant.RelayModeRerank
  316. info.RelayFormat = types.RelayFormatRerank
  317. info.RerankerInfo = &RerankerInfo{
  318. Documents: request.Documents,
  319. ReturnDocuments: request.GetReturnDocuments(),
  320. }
  321. return info
  322. }
  323. func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
  324. info := genBaseRelayInfo(c, request)
  325. info.RelayFormat = types.RelayFormatOpenAIAudio
  326. return info
  327. }
  328. func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
  329. info := genBaseRelayInfo(c, request)
  330. info.RelayFormat = types.RelayFormatEmbedding
  331. return info
  332. }
  333. func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
  334. info := genBaseRelayInfo(c, request)
  335. info.RelayMode = relayconstant.RelayModeResponses
  336. info.RelayFormat = types.RelayFormatOpenAIResponses
  337. info.ResponsesUsageInfo = &ResponsesUsageInfo{
  338. BuiltInTools: make(map[string]*BuildInToolInfo),
  339. }
  340. if len(request.Tools) > 0 {
  341. for _, tool := range request.GetToolsMap() {
  342. toolType := common.Interface2String(tool["type"])
  343. info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
  344. ToolName: toolType,
  345. CallCount: 0,
  346. }
  347. switch toolType {
  348. case dto.BuildInToolWebSearchPreview:
  349. searchContextSize := common.Interface2String(tool["search_context_size"])
  350. if searchContextSize == "" {
  351. searchContextSize = "medium"
  352. }
  353. info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
  354. }
  355. }
  356. }
  357. return info
  358. }
  359. func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
  360. info := genBaseRelayInfo(c, request)
  361. info.RelayFormat = types.RelayFormatGemini
  362. info.ShouldIncludeUsage = false
  363. return info
  364. }
  365. func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
  366. info := genBaseRelayInfo(c, request)
  367. info.RelayFormat = types.RelayFormatOpenAIImage
  368. return info
  369. }
  370. func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
  371. info := genBaseRelayInfo(c, request)
  372. info.RelayFormat = types.RelayFormatOpenAI
  373. return info
  374. }
  375. func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
  376. //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  377. //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
  378. //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  379. tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
  380. // 当令牌分组为空时,表示使用用户分组
  381. if tokenGroup == "" {
  382. tokenGroup = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
  383. }
  384. startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
  385. if startTime.IsZero() {
  386. startTime = time.Now()
  387. }
  388. isStream := false
  389. if request != nil {
  390. isStream = request.IsStream(c)
  391. }
  392. // firstResponseTime = time.Now() - 1 second
  393. reqId := common.GetContextKeyString(c, common.RequestIdKey)
  394. if reqId == "" {
  395. reqId = common.GetTimeString() + common.GetRandomString(8)
  396. }
  397. info := &RelayInfo{
  398. Request: request,
  399. RequestId: reqId,
  400. UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
  401. UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
  402. UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
  403. UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
  404. UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
  405. OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  406. TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
  407. TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
  408. TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
  409. TokenGroup: tokenGroup,
  410. isFirstResponse: true,
  411. RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
  412. RequestURLPath: c.Request.URL.String(),
  413. RequestHeaders: cloneRequestHeaders(c),
  414. IsStream: isStream,
  415. StartTime: startTime,
  416. FirstResponseTime: startTime.Add(-time.Second),
  417. ThinkingContentInfo: ThinkingContentInfo{
  418. IsFirstThinkingContent: true,
  419. SendLastThinkingContent: false,
  420. },
  421. TokenCountMeta: TokenCountMeta{
  422. //promptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
  423. estimatePromptTokens: common.GetContextKeyInt(c, constant.ContextKeyEstimatedTokens),
  424. },
  425. }
  426. if info.RelayMode == relayconstant.RelayModeUnknown {
  427. info.RelayMode = c.GetInt("relay_mode")
  428. }
  429. if strings.HasPrefix(c.Request.URL.Path, "/pg") {
  430. info.IsPlayground = true
  431. info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
  432. info.RequestURLPath = "/v1" + info.RequestURLPath
  433. }
  434. userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
  435. if ok {
  436. info.UserSetting = userSetting
  437. }
  438. return info
  439. }
  440. func cloneRequestHeaders(c *gin.Context) map[string]string {
  441. if c == nil || c.Request == nil {
  442. return nil
  443. }
  444. if len(c.Request.Header) == 0 {
  445. return nil
  446. }
  447. headers := make(map[string]string, len(c.Request.Header))
  448. for key := range c.Request.Header {
  449. value := strings.TrimSpace(c.Request.Header.Get(key))
  450. if value == "" {
  451. continue
  452. }
  453. headers[key] = value
  454. }
  455. if len(headers) == 0 {
  456. return nil
  457. }
  458. return headers
  459. }
  460. func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
  461. var info *RelayInfo
  462. var err error
  463. switch relayFormat {
  464. case types.RelayFormatOpenAI:
  465. info = GenRelayInfoOpenAI(c, request)
  466. case types.RelayFormatOpenAIAudio:
  467. info = GenRelayInfoOpenAIAudio(c, request)
  468. case types.RelayFormatOpenAIImage:
  469. info = GenRelayInfoImage(c, request)
  470. case types.RelayFormatOpenAIRealtime:
  471. info = GenRelayInfoWs(c, ws)
  472. case types.RelayFormatClaude:
  473. info = GenRelayInfoClaude(c, request)
  474. case types.RelayFormatRerank:
  475. if request, ok := request.(*dto.RerankRequest); ok {
  476. info = GenRelayInfoRerank(c, request)
  477. break
  478. }
  479. err = errors.New("request is not a RerankRequest")
  480. case types.RelayFormatGemini:
  481. info = GenRelayInfoGemini(c, request)
  482. case types.RelayFormatEmbedding:
  483. info = GenRelayInfoEmbedding(c, request)
  484. case types.RelayFormatOpenAIResponses:
  485. if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
  486. info = GenRelayInfoResponses(c, request)
  487. break
  488. }
  489. err = errors.New("request is not a OpenAIResponsesRequest")
  490. case types.RelayFormatOpenAIResponsesCompaction:
  491. if request, ok := request.(*dto.OpenAIResponsesCompactionRequest); ok {
  492. return GenRelayInfoResponsesCompaction(c, request), nil
  493. }
  494. return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
  495. case types.RelayFormatTask:
  496. info = genBaseRelayInfo(c, nil)
  497. info.TaskRelayInfo = &TaskRelayInfo{}
  498. case types.RelayFormatMjProxy:
  499. info = genBaseRelayInfo(c, nil)
  500. info.TaskRelayInfo = &TaskRelayInfo{}
  501. default:
  502. err = errors.New("invalid relay format")
  503. }
  504. if err != nil {
  505. return nil, err
  506. }
  507. if info == nil {
  508. return nil, errors.New("failed to build relay info")
  509. }
  510. info.InitRequestConversionChain()
  511. return info, nil
  512. }
  513. func (info *RelayInfo) InitRequestConversionChain() {
  514. if info == nil {
  515. return
  516. }
  517. if len(info.RequestConversionChain) > 0 {
  518. return
  519. }
  520. if info.RelayFormat == "" {
  521. return
  522. }
  523. info.RequestConversionChain = []types.RelayFormat{info.RelayFormat}
  524. }
  525. func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
  526. if info == nil {
  527. return
  528. }
  529. if format == "" {
  530. return
  531. }
  532. if len(info.RequestConversionChain) == 0 {
  533. info.RequestConversionChain = []types.RelayFormat{format}
  534. return
  535. }
  536. last := info.RequestConversionChain[len(info.RequestConversionChain)-1]
  537. if last == format {
  538. return
  539. }
  540. info.RequestConversionChain = append(info.RequestConversionChain, format)
  541. }
  542. func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
  543. if info == nil {
  544. return ""
  545. }
  546. if info.FinalRequestRelayFormat != "" {
  547. return info.FinalRequestRelayFormat
  548. }
  549. if n := len(info.RequestConversionChain); n > 0 {
  550. return info.RequestConversionChain[n-1]
  551. }
  552. return info.RelayFormat
  553. }
  554. func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
  555. info := genBaseRelayInfo(c, request)
  556. if info.RelayMode == relayconstant.RelayModeUnknown {
  557. info.RelayMode = relayconstant.RelayModeResponsesCompact
  558. }
  559. info.RelayFormat = types.RelayFormatOpenAIResponsesCompaction
  560. return info
  561. }
  562. //func (info *RelayInfo) SetPromptTokens(promptTokens int) {
  563. // info.promptTokens = promptTokens
  564. //}
  565. func (info *RelayInfo) SetEstimatePromptTokens(promptTokens int) {
  566. info.estimatePromptTokens = promptTokens
  567. }
  568. func (info *RelayInfo) GetEstimatePromptTokens() int {
  569. return info.estimatePromptTokens
  570. }
  571. func (info *RelayInfo) SetFirstResponseTime() {
  572. if info.isFirstResponse {
  573. info.FirstResponseTime = time.Now()
  574. info.isFirstResponse = false
  575. }
  576. }
  577. func (info *RelayInfo) HasSendResponse() bool {
  578. return info.FirstResponseTime.After(info.StartTime)
  579. }
  580. type TaskRelayInfo struct {
  581. Action string
  582. OriginTaskID string
  583. // PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
  584. // 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
  585. PublicTaskID string
  586. ConsumeQuota bool
  587. // LockedChannel holds the full channel object when the request is bound to
  588. // a specific channel (e.g., remix on origin task's channel). Stored as any
  589. // to avoid an import cycle with model; callers type-assert to *model.Channel.
  590. LockedChannel any
  591. }
  592. type TaskSubmitReq struct {
  593. Prompt string `json:"prompt"`
  594. Model string `json:"model,omitempty"`
  595. Mode string `json:"mode,omitempty"`
  596. Image string `json:"image,omitempty"`
  597. Images []string `json:"images,omitempty"`
  598. Size string `json:"size,omitempty"`
  599. Duration int `json:"duration,omitempty"`
  600. Seconds string `json:"seconds,omitempty"`
  601. InputReference string `json:"input_reference,omitempty"`
  602. Metadata map[string]interface{} `json:"metadata,omitempty"`
  603. }
  604. func (t *TaskSubmitReq) GetPrompt() string {
  605. return t.Prompt
  606. }
  607. func (t *TaskSubmitReq) HasImage() bool {
  608. return len(t.Images) > 0
  609. }
  610. func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
  611. type Alias TaskSubmitReq
  612. aux := &struct {
  613. Metadata json.RawMessage `json:"metadata,omitempty"`
  614. *Alias
  615. }{
  616. Alias: (*Alias)(t),
  617. }
  618. if err := common.Unmarshal(data, &aux); err != nil {
  619. return err
  620. }
  621. if len(aux.Metadata) > 0 {
  622. var metadataStr string
  623. if err := common.Unmarshal(aux.Metadata, &metadataStr); err == nil && metadataStr != "" {
  624. var metadataObj map[string]interface{}
  625. if err := common.Unmarshal([]byte(metadataStr), &metadataObj); err == nil {
  626. t.Metadata = metadataObj
  627. return nil
  628. }
  629. }
  630. var metadataObj map[string]interface{}
  631. if err := common.Unmarshal(aux.Metadata, &metadataObj); err == nil {
  632. t.Metadata = metadataObj
  633. }
  634. }
  635. return nil
  636. }
  637. func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
  638. metadata := t.Metadata
  639. if metadata != nil {
  640. metadataBytes, err := common.Marshal(metadata)
  641. if err != nil {
  642. return fmt.Errorf("marshal metadata failed: %w", err)
  643. }
  644. err = common.Unmarshal(metadataBytes, v)
  645. if err != nil {
  646. return fmt.Errorf("unmarshal metadata to target failed: %w", err)
  647. }
  648. }
  649. return nil
  650. }
  651. type TaskInfo struct {
  652. Code int `json:"code"`
  653. TaskID string `json:"task_id"`
  654. Status string `json:"status"`
  655. Reason string `json:"reason,omitempty"`
  656. Url string `json:"url,omitempty"`
  657. RemoteUrl string `json:"remote_url,omitempty"`
  658. Progress string `json:"progress,omitempty"`
  659. CompletionTokens int `json:"completion_tokens,omitempty"` // 用于按倍率计费
  660. TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费
  661. }
  662. func FailTaskInfo(reason string) *TaskInfo {
  663. return &TaskInfo{
  664. Status: "FAILURE",
  665. Reason: reason,
  666. }
  667. }
  668. // RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
  669. // service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
  670. // inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
  671. // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
  672. // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
  673. // stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
  674. func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
  675. if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
  676. return jsonData, nil
  677. }
  678. var data map[string]interface{}
  679. if err := common.Unmarshal(jsonData, &data); err != nil {
  680. common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
  681. return jsonData, nil
  682. }
  683. // 默认移除 service_tier,除非明确允许(避免额外计费风险)
  684. if !channelOtherSettings.AllowServiceTier {
  685. if _, exists := data["service_tier"]; exists {
  686. delete(data, "service_tier")
  687. }
  688. }
  689. // 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域)
  690. if !channelOtherSettings.AllowInferenceGeo {
  691. if _, exists := data["inference_geo"]; exists {
  692. delete(data, "inference_geo")
  693. }
  694. }
  695. // 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
  696. if channelOtherSettings.DisableStore {
  697. if _, exists := data["store"]; exists {
  698. delete(data, "store")
  699. }
  700. }
  701. // 默认移除 safety_identifier,除非明确允许(保护用户隐私,避免向 OpenAI 报告用户信息)
  702. if !channelOtherSettings.AllowSafetyIdentifier {
  703. if _, exists := data["safety_identifier"]; exists {
  704. delete(data, "safety_identifier")
  705. }
  706. }
  707. // 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护)
  708. if !channelOtherSettings.AllowIncludeObfuscation {
  709. if streamOptionsAny, exists := data["stream_options"]; exists {
  710. if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
  711. if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
  712. delete(streamOptions, "include_obfuscation")
  713. }
  714. if len(streamOptions) == 0 {
  715. delete(data, "stream_options")
  716. } else {
  717. data["stream_options"] = streamOptions
  718. }
  719. }
  720. }
  721. }
  722. jsonDataAfter, err := common.Marshal(data)
  723. if err != nil {
  724. common.SysError("RemoveDisabledFields Marshal error :" + err.Error())
  725. return jsonData, nil
  726. }
  727. return jsonDataAfter, nil
  728. }
  729. // RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
  730. // Currently supports removing functionResponse.id field which Vertex AI does not support
  731. func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
  732. if !model_setting.GetGeminiSettings().RemoveFunctionResponseIdEnabled {
  733. return jsonData, nil
  734. }
  735. var data map[string]interface{}
  736. if err := common.Unmarshal(jsonData, &data); err != nil {
  737. common.SysError("RemoveGeminiDisabledFields Unmarshal error: " + err.Error())
  738. return jsonData, nil
  739. }
  740. // Process contents array
  741. // Handle both camelCase (functionResponse) and snake_case (function_response)
  742. if contents, ok := data["contents"].([]interface{}); ok {
  743. for _, content := range contents {
  744. if contentMap, ok := content.(map[string]interface{}); ok {
  745. if parts, ok := contentMap["parts"].([]interface{}); ok {
  746. for _, part := range parts {
  747. if partMap, ok := part.(map[string]interface{}); ok {
  748. // Check functionResponse (camelCase)
  749. if funcResp, ok := partMap["functionResponse"].(map[string]interface{}); ok {
  750. delete(funcResp, "id")
  751. }
  752. // Check function_response (snake_case)
  753. if funcResp, ok := partMap["function_response"].(map[string]interface{}); ok {
  754. delete(funcResp, "id")
  755. }
  756. }
  757. }
  758. }
  759. }
  760. }
  761. }
  762. jsonDataAfter, err := common.Marshal(data)
  763. if err != nil {
  764. common.SysError("RemoveGeminiDisabledFields Marshal error: " + err.Error())
  765. return jsonData, nil
  766. }
  767. return jsonDataAfter, nil
  768. }