relay_info.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896
  1. package common
import (
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/pkg/billingexpr"
	relayconstant "github.com/QuantumNous/new-api/relay/constant"
	"github.com/QuantumNous/new-api/setting/model_setting"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)
  19. type ThinkingContentInfo struct {
  20. IsFirstThinkingContent bool
  21. SendLastThinkingContent bool
  22. HasSentThinkingContent bool
  23. }
  24. const (
  25. LastMessageTypeNone = "none"
  26. LastMessageTypeText = "text"
  27. LastMessageTypeTools = "tools"
  28. LastMessageTypeThinking = "thinking"
  29. )
  30. type ClaudeConvertInfo struct {
  31. LastMessagesType string
  32. Index int
  33. Usage *dto.Usage
  34. FinishReason string
  35. Done bool
  36. ToolCallBaseIndex int
  37. ToolCallMaxIndexOffset int
  38. }
  39. type RerankerInfo struct {
  40. Documents []any
  41. ReturnDocuments bool
  42. }
  43. type BuildInToolInfo struct {
  44. ToolName string
  45. CallCount int
  46. SearchContextSize string
  47. }
  48. type ResponsesUsageInfo struct {
  49. BuiltInTools map[string]*BuildInToolInfo
  50. }
  51. type ChannelMeta struct {
  52. ChannelType int
  53. ChannelId int
  54. ChannelIsMultiKey bool
  55. ChannelMultiKeyIndex int
  56. ChannelBaseUrl string
  57. ApiType int
  58. ApiVersion string
  59. ApiKey string
  60. Organization string
  61. ChannelCreateTime int64
  62. ParamOverride map[string]interface{}
  63. HeadersOverride map[string]interface{}
  64. ChannelSetting dto.ChannelSettings
  65. ChannelOtherSettings dto.ChannelOtherSettings
  66. UpstreamModelName string
  67. IsModelMapped bool
  68. SupportStreamOptions bool // 是否支持流式选项
  69. }
  70. type TokenCountMeta struct {
  71. //promptTokens int
  72. estimatePromptTokens int
  73. }
  74. type RelayInfo struct {
  75. TokenId int
  76. TokenKey string
  77. TokenGroup string
  78. UserId int
  79. UsingGroup string // 使用的分组,当auto跨分组重试时,会变动
  80. UserGroup string // 用户所在分组
  81. TokenUnlimited bool
  82. StartTime time.Time
  83. FirstResponseTime time.Time
  84. isFirstResponse bool
  85. //SendLastReasoningResponse bool
  86. IsStream bool
  87. IsGeminiBatchEmbedding bool
  88. IsPlayground bool
  89. UsePrice bool
  90. RelayMode int
  91. OriginModelName string
  92. RequestURLPath string
  93. RequestHeaders map[string]string
  94. ShouldIncludeUsage bool
  95. DisablePing bool // 是否禁止向下游发送自定义 Ping
  96. ClientWs *websocket.Conn
  97. TargetWs *websocket.Conn
  98. InputAudioFormat string
  99. OutputAudioFormat string
  100. RealtimeTools []dto.RealTimeTool
  101. IsFirstRequest bool
  102. AudioUsage bool
  103. ReasoningEffort string
  104. UserSetting dto.UserSetting
  105. UserEmail string
  106. UserQuota int
  107. RelayFormat types.RelayFormat
  108. SendResponseCount int
  109. ReceivedResponseCount int
  110. FinalPreConsumedQuota int // 最终预消耗的配额
  111. // ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
  112. // 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
  113. // 必须在提交前锁定全额。
  114. ForcePreConsume bool
  115. // Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
  116. // 免费模型时为 nil。
  117. Billing BillingSettler
  118. // BillingSource indicates whether this request is billed from wallet quota or subscription.
  119. // "" or "wallet" => wallet; "subscription" => subscription
  120. BillingSource string
  121. // SubscriptionId is the user_subscriptions.id used when BillingSource == "subscription"
  122. SubscriptionId int
  123. // SubscriptionPreConsumed is the amount pre-consumed on subscription item (quota units or 1)
  124. SubscriptionPreConsumed int64
  125. // SubscriptionPostDelta is the post-consume delta applied to amount_used (quota units; can be negative).
  126. SubscriptionPostDelta int64
  127. // SubscriptionPlanId / SubscriptionPlanTitle are used for logging/UI display.
  128. SubscriptionPlanId int
  129. SubscriptionPlanTitle string
  130. // RequestId is used for idempotent pre-consume/refund
  131. RequestId string
  132. // SubscriptionAmountTotal / SubscriptionAmountUsedAfterPreConsume are used to compute remaining in logs.
  133. SubscriptionAmountTotal int64
  134. SubscriptionAmountUsedAfterPreConsume int64
  135. IsClaudeBetaQuery bool // /v1/messages?beta=true
  136. IsChannelTest bool // channel test request
  137. RetryIndex int
  138. LastError *types.NewAPIError
  139. RuntimeHeadersOverride map[string]interface{}
  140. UseRuntimeHeadersOverride bool
  141. ParamOverrideAudit []string
  142. PriceData types.PriceData
  143. // TieredBillingSnapshot is a frozen snapshot of tiered billing rules
  144. // captured at pre-consume time. Non-nil only when billing mode is "tiered_expr".
  145. TieredBillingSnapshot *billingexpr.BillingSnapshot
  146. BillingRequestInput *billingexpr.RequestInput
  147. Request dto.Request
  148. // RequestConversionChain records request format conversions in order, e.g.
  149. // ["openai", "openai_responses"] or ["openai", "claude"].
  150. RequestConversionChain []types.RelayFormat
  151. // 最终请求到上游的格式。可由 adaptor 显式设置;
  152. // 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
  153. FinalRequestRelayFormat types.RelayFormat
  154. StreamStatus *StreamStatus
  155. ThinkingContentInfo
  156. TokenCountMeta
  157. *ClaudeConvertInfo
  158. *RerankerInfo
  159. *ResponsesUsageInfo
  160. *ChannelMeta
  161. *TaskRelayInfo
  162. }
  163. func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
  164. channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  165. paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  166. headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
  167. apiType, _ := common.ChannelType2APIType(channelType)
  168. channelMeta := &ChannelMeta{
  169. ChannelType: channelType,
  170. ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
  171. ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
  172. ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
  173. ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
  174. ApiType: apiType,
  175. ApiVersion: c.GetString("api_version"),
  176. ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
  177. Organization: c.GetString("channel_organization"),
  178. ChannelCreateTime: c.GetInt64("channel_create_time"),
  179. ParamOverride: paramOverride,
  180. HeadersOverride: headerOverride,
  181. UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  182. IsModelMapped: false,
  183. SupportStreamOptions: false,
  184. }
  185. if channelType == constant.ChannelTypeAzure {
  186. channelMeta.ApiVersion = GetAPIVersion(c)
  187. }
  188. if channelType == constant.ChannelTypeVertexAi {
  189. channelMeta.ApiVersion = c.GetString("region")
  190. }
  191. channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
  192. if ok {
  193. channelMeta.ChannelSetting = channelSetting
  194. }
  195. channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
  196. if ok {
  197. channelMeta.ChannelOtherSettings = channelOtherSettings
  198. }
  199. if streamSupportedChannels[channelMeta.ChannelType] {
  200. channelMeta.SupportStreamOptions = true
  201. }
  202. info.ChannelMeta = channelMeta
  203. // reset some fields based on channel meta
  204. // 重置某些字段,例如模型名称等
  205. if info.Request != nil {
  206. info.Request.SetModelName(info.OriginModelName)
  207. }
  208. }
  209. func (info *RelayInfo) ToString() string {
  210. if info == nil {
  211. return "RelayInfo<nil>"
  212. }
  213. // Basic info
  214. b := &strings.Builder{}
  215. fmt.Fprintf(b, "RelayInfo{ ")
  216. fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat)
  217. fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode)
  218. fmt.Fprintf(b, "IsStream: %t, ", info.IsStream)
  219. fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
  220. fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
  221. fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
  222. fmt.Fprintf(b, "EstimatePromptTokens: %d, ", info.estimatePromptTokens)
  223. fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
  224. fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
  225. fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
  226. fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota)
  227. // User & token info (mask secrets)
  228. fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ",
  229. info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota)
  230. fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited)
  231. // Time info
  232. latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
  233. fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ",
  234. info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs)
  235. // Audio / realtime
  236. if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage {
  237. fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ",
  238. info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools))
  239. }
  240. // Reasoning
  241. if info.ReasoningEffort != "" {
  242. fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort)
  243. }
  244. // Price data (non-sensitive)
  245. if info.PriceData.UsePrice {
  246. fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting())
  247. }
  248. // Channel metadata (mask ApiKey)
  249. if info.ChannelMeta != nil {
  250. cm := info.ChannelMeta
  251. fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ",
  252. cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions)
  253. }
  254. // Responses usage info (non-sensitive)
  255. if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 {
  256. fmt.Fprintf(b, "ResponsesTools{ ")
  257. first := true
  258. for name, tool := range info.ResponsesUsageInfo.BuiltInTools {
  259. if !first {
  260. fmt.Fprintf(b, ", ")
  261. }
  262. first = false
  263. if tool != nil {
  264. fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount)
  265. } else {
  266. fmt.Fprintf(b, "%s: calls=0", name)
  267. }
  268. }
  269. fmt.Fprintf(b, " }, ")
  270. }
  271. fmt.Fprintf(b, "}")
  272. return b.String()
  273. }
  274. // 定义支持流式选项的通道类型
  275. var streamSupportedChannels = map[int]bool{
  276. constant.ChannelTypeOpenAI: true,
  277. constant.ChannelTypeAnthropic: true,
  278. constant.ChannelTypeAws: true,
  279. constant.ChannelTypeGemini: true,
  280. constant.ChannelCloudflare: true,
  281. constant.ChannelTypeAzure: true,
  282. constant.ChannelTypeVolcEngine: true,
  283. constant.ChannelTypeOllama: true,
  284. constant.ChannelTypeXai: true,
  285. constant.ChannelTypeDeepSeek: true,
  286. constant.ChannelTypeBaiduV2: true,
  287. constant.ChannelTypeZhipu_v4: true,
  288. constant.ChannelTypeAli: true,
  289. constant.ChannelTypeSubmodel: true,
  290. constant.ChannelTypeCodex: true,
  291. constant.ChannelTypeMoonshot: true,
  292. constant.ChannelTypeMiniMax: true,
  293. constant.ChannelTypeSiliconFlow: true,
  294. }
  295. func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
  296. info := genBaseRelayInfo(c, nil)
  297. info.RelayFormat = types.RelayFormatOpenAIRealtime
  298. info.ClientWs = ws
  299. info.InputAudioFormat = "pcm16"
  300. info.OutputAudioFormat = "pcm16"
  301. info.IsFirstRequest = true
  302. return info
  303. }
  304. func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
  305. info := genBaseRelayInfo(c, request)
  306. info.RelayFormat = types.RelayFormatClaude
  307. info.ShouldIncludeUsage = false
  308. info.ClaudeConvertInfo = &ClaudeConvertInfo{
  309. LastMessagesType: LastMessageTypeNone,
  310. }
  311. info.IsClaudeBetaQuery = c.Query("beta") == "true"
  312. return info
  313. }
  314. func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
  315. info := genBaseRelayInfo(c, request)
  316. info.RelayMode = relayconstant.RelayModeRerank
  317. info.RelayFormat = types.RelayFormatRerank
  318. info.RerankerInfo = &RerankerInfo{
  319. Documents: request.Documents,
  320. ReturnDocuments: request.GetReturnDocuments(),
  321. }
  322. return info
  323. }
  324. func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
  325. info := genBaseRelayInfo(c, request)
  326. info.RelayFormat = types.RelayFormatOpenAIAudio
  327. return info
  328. }
  329. func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
  330. info := genBaseRelayInfo(c, request)
  331. info.RelayFormat = types.RelayFormatEmbedding
  332. return info
  333. }
  334. func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
  335. info := genBaseRelayInfo(c, request)
  336. info.RelayMode = relayconstant.RelayModeResponses
  337. info.RelayFormat = types.RelayFormatOpenAIResponses
  338. info.ResponsesUsageInfo = &ResponsesUsageInfo{
  339. BuiltInTools: make(map[string]*BuildInToolInfo),
  340. }
  341. if len(request.Tools) > 0 {
  342. for _, tool := range request.GetToolsMap() {
  343. toolType := common.Interface2String(tool["type"])
  344. info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
  345. ToolName: toolType,
  346. CallCount: 0,
  347. }
  348. switch toolType {
  349. case dto.BuildInToolWebSearchPreview:
  350. searchContextSize := common.Interface2String(tool["search_context_size"])
  351. if searchContextSize == "" {
  352. searchContextSize = "medium"
  353. }
  354. info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
  355. }
  356. }
  357. }
  358. return info
  359. }
  360. func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
  361. info := genBaseRelayInfo(c, request)
  362. info.RelayFormat = types.RelayFormatGemini
  363. info.ShouldIncludeUsage = false
  364. return info
  365. }
  366. func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
  367. info := genBaseRelayInfo(c, request)
  368. info.RelayFormat = types.RelayFormatOpenAIImage
  369. return info
  370. }
  371. func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
  372. info := genBaseRelayInfo(c, request)
  373. info.RelayFormat = types.RelayFormatOpenAI
  374. return info
  375. }
  376. func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
  377. //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  378. //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
  379. //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  380. tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
  381. // 当令牌分组为空时,表示使用用户分组
  382. if tokenGroup == "" {
  383. tokenGroup = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
  384. }
  385. startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
  386. if startTime.IsZero() {
  387. startTime = time.Now()
  388. }
  389. isStream := false
  390. if request != nil {
  391. isStream = request.IsStream(c)
  392. }
  393. c.Set(string(constant.ContextKeyIsStream), isStream)
  394. // firstResponseTime = time.Now() - 1 second
  395. reqId := common.GetContextKeyString(c, common.RequestIdKey)
  396. if reqId == "" {
  397. reqId = common.GetTimeString() + common.GetRandomString(8)
  398. }
  399. info := &RelayInfo{
  400. Request: request,
  401. RequestId: reqId,
  402. UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
  403. UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
  404. UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
  405. UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
  406. UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
  407. OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  408. TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
  409. TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
  410. TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
  411. TokenGroup: tokenGroup,
  412. isFirstResponse: true,
  413. RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
  414. RequestURLPath: c.Request.URL.String(),
  415. RequestHeaders: cloneRequestHeaders(c),
  416. IsStream: isStream,
  417. StartTime: startTime,
  418. FirstResponseTime: startTime.Add(-time.Second),
  419. ThinkingContentInfo: ThinkingContentInfo{
  420. IsFirstThinkingContent: true,
  421. SendLastThinkingContent: false,
  422. },
  423. TokenCountMeta: TokenCountMeta{
  424. //promptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
  425. estimatePromptTokens: common.GetContextKeyInt(c, constant.ContextKeyEstimatedTokens),
  426. },
  427. }
  428. if info.RelayMode == relayconstant.RelayModeUnknown {
  429. info.RelayMode = c.GetInt("relay_mode")
  430. }
  431. if strings.HasPrefix(c.Request.URL.Path, "/pg") {
  432. info.IsPlayground = true
  433. info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
  434. info.RequestURLPath = "/v1" + info.RequestURLPath
  435. }
  436. userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
  437. if ok {
  438. info.UserSetting = userSetting
  439. }
  440. return info
  441. }
  442. func cloneRequestHeaders(c *gin.Context) map[string]string {
  443. if c == nil || c.Request == nil {
  444. return nil
  445. }
  446. if len(c.Request.Header) == 0 {
  447. return nil
  448. }
  449. headers := make(map[string]string, len(c.Request.Header))
  450. for key := range c.Request.Header {
  451. value := strings.TrimSpace(c.Request.Header.Get(key))
  452. if value == "" {
  453. continue
  454. }
  455. headers[key] = value
  456. }
  457. if len(headers) == 0 {
  458. return nil
  459. }
  460. return headers
  461. }
  462. func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
  463. var info *RelayInfo
  464. var err error
  465. switch relayFormat {
  466. case types.RelayFormatOpenAI:
  467. info = GenRelayInfoOpenAI(c, request)
  468. case types.RelayFormatOpenAIAudio:
  469. info = GenRelayInfoOpenAIAudio(c, request)
  470. case types.RelayFormatOpenAIImage:
  471. info = GenRelayInfoImage(c, request)
  472. case types.RelayFormatOpenAIRealtime:
  473. info = GenRelayInfoWs(c, ws)
  474. case types.RelayFormatClaude:
  475. info = GenRelayInfoClaude(c, request)
  476. case types.RelayFormatRerank:
  477. if request, ok := request.(*dto.RerankRequest); ok {
  478. info = GenRelayInfoRerank(c, request)
  479. break
  480. }
  481. err = errors.New("request is not a RerankRequest")
  482. case types.RelayFormatGemini:
  483. info = GenRelayInfoGemini(c, request)
  484. case types.RelayFormatEmbedding:
  485. info = GenRelayInfoEmbedding(c, request)
  486. case types.RelayFormatOpenAIResponses:
  487. if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
  488. info = GenRelayInfoResponses(c, request)
  489. break
  490. }
  491. err = errors.New("request is not a OpenAIResponsesRequest")
  492. case types.RelayFormatOpenAIResponsesCompaction:
  493. if request, ok := request.(*dto.OpenAIResponsesCompactionRequest); ok {
  494. return GenRelayInfoResponsesCompaction(c, request), nil
  495. }
  496. return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
  497. case types.RelayFormatTask:
  498. info = genBaseRelayInfo(c, nil)
  499. info.TaskRelayInfo = &TaskRelayInfo{}
  500. case types.RelayFormatMjProxy:
  501. info = genBaseRelayInfo(c, nil)
  502. info.TaskRelayInfo = &TaskRelayInfo{}
  503. default:
  504. err = errors.New("invalid relay format")
  505. }
  506. if err != nil {
  507. return nil, err
  508. }
  509. if info == nil {
  510. return nil, errors.New("failed to build relay info")
  511. }
  512. info.InitRequestConversionChain()
  513. return info, nil
  514. }
  515. func (info *RelayInfo) InitRequestConversionChain() {
  516. if info == nil {
  517. return
  518. }
  519. if len(info.RequestConversionChain) > 0 {
  520. return
  521. }
  522. if info.RelayFormat == "" {
  523. return
  524. }
  525. info.RequestConversionChain = []types.RelayFormat{info.RelayFormat}
  526. }
  527. func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
  528. if info == nil {
  529. return
  530. }
  531. if format == "" {
  532. return
  533. }
  534. if len(info.RequestConversionChain) == 0 {
  535. info.RequestConversionChain = []types.RelayFormat{format}
  536. return
  537. }
  538. last := info.RequestConversionChain[len(info.RequestConversionChain)-1]
  539. if last == format {
  540. return
  541. }
  542. info.RequestConversionChain = append(info.RequestConversionChain, format)
  543. }
  544. func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
  545. if info == nil {
  546. return ""
  547. }
  548. if info.FinalRequestRelayFormat != "" {
  549. return info.FinalRequestRelayFormat
  550. }
  551. if n := len(info.RequestConversionChain); n > 0 {
  552. return info.RequestConversionChain[n-1]
  553. }
  554. return info.RelayFormat
  555. }
  556. func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
  557. info := genBaseRelayInfo(c, request)
  558. if info.RelayMode == relayconstant.RelayModeUnknown {
  559. info.RelayMode = relayconstant.RelayModeResponsesCompact
  560. }
  561. info.RelayFormat = types.RelayFormatOpenAIResponsesCompaction
  562. return info
  563. }
  564. //func (info *RelayInfo) SetPromptTokens(promptTokens int) {
  565. // info.promptTokens = promptTokens
  566. //}
  567. func (info *RelayInfo) SetEstimatePromptTokens(promptTokens int) {
  568. info.estimatePromptTokens = promptTokens
  569. }
  570. func (info *RelayInfo) GetEstimatePromptTokens() int {
  571. return info.estimatePromptTokens
  572. }
  573. func (info *RelayInfo) SetFirstResponseTime() {
  574. if info.isFirstResponse {
  575. info.FirstResponseTime = time.Now()
  576. info.isFirstResponse = false
  577. }
  578. }
  579. func (info *RelayInfo) HasSendResponse() bool {
  580. return info.FirstResponseTime.After(info.StartTime)
  581. }
  582. type TaskRelayInfo struct {
  583. Action string
  584. OriginTaskID string
  585. // PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
  586. // 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
  587. PublicTaskID string
  588. ConsumeQuota bool
  589. // LockedChannel holds the full channel object when the request is bound to
  590. // a specific channel (e.g., remix on origin task's channel). Stored as any
  591. // to avoid an import cycle with model; callers type-assert to *model.Channel.
  592. LockedChannel any
  593. }
  594. type TaskSubmitReq struct {
  595. Prompt string `json:"prompt"`
  596. Model string `json:"model,omitempty"`
  597. Mode string `json:"mode,omitempty"`
  598. Image string `json:"image,omitempty"`
  599. Images []string `json:"images,omitempty"`
  600. Size string `json:"size,omitempty"`
  601. Duration int `json:"duration,omitempty"`
  602. Seconds string `json:"seconds,omitempty"`
  603. InputReference string `json:"input_reference,omitempty"`
  604. Metadata map[string]interface{} `json:"metadata,omitempty"`
  605. }
  606. func (t *TaskSubmitReq) GetPrompt() string {
  607. return t.Prompt
  608. }
  609. func (t *TaskSubmitReq) HasImage() bool {
  610. return len(t.Images) > 0
  611. }
  612. func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
  613. type Alias TaskSubmitReq
  614. aux := &struct {
  615. Metadata json.RawMessage `json:"metadata,omitempty"`
  616. Duration json.RawMessage `json:"duration,omitempty"`
  617. *Alias
  618. }{
  619. Alias: (*Alias)(t),
  620. }
  621. if err := common.Unmarshal(data, &aux); err != nil {
  622. return err
  623. }
  624. if len(aux.Duration) > 0 {
  625. var durationInt int
  626. if err := common.Unmarshal(aux.Duration, &durationInt); err == nil {
  627. t.Duration = durationInt
  628. } else {
  629. var durationStr string
  630. if err := common.Unmarshal(aux.Duration, &durationStr); err == nil && durationStr != "" {
  631. if v, err := strconv.Atoi(durationStr); err == nil {
  632. t.Duration = v
  633. }
  634. }
  635. }
  636. }
  637. if len(aux.Metadata) > 0 {
  638. var metadataStr string
  639. if err := common.Unmarshal(aux.Metadata, &metadataStr); err == nil && metadataStr != "" {
  640. var metadataObj map[string]interface{}
  641. if err := common.Unmarshal([]byte(metadataStr), &metadataObj); err == nil {
  642. t.Metadata = metadataObj
  643. return nil
  644. }
  645. }
  646. var metadataObj map[string]interface{}
  647. if err := common.Unmarshal(aux.Metadata, &metadataObj); err == nil {
  648. t.Metadata = metadataObj
  649. }
  650. }
  651. return nil
  652. }
  653. func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
  654. metadata := t.Metadata
  655. if metadata != nil {
  656. metadataBytes, err := common.Marshal(metadata)
  657. if err != nil {
  658. return fmt.Errorf("marshal metadata failed: %w", err)
  659. }
  660. err = common.Unmarshal(metadataBytes, v)
  661. if err != nil {
  662. return fmt.Errorf("unmarshal metadata to target failed: %w", err)
  663. }
  664. }
  665. return nil
  666. }
  667. type TaskInfo struct {
  668. Code int `json:"code"`
  669. TaskID string `json:"task_id"`
  670. Status string `json:"status"`
  671. Reason string `json:"reason,omitempty"`
  672. Url string `json:"url,omitempty"`
  673. RemoteUrl string `json:"remote_url,omitempty"`
  674. Progress string `json:"progress,omitempty"`
  675. CompletionTokens int `json:"completion_tokens,omitempty"` // 用于按倍率计费
  676. TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费
  677. }
  678. func FailTaskInfo(reason string) *TaskInfo {
  679. return &TaskInfo{
  680. Status: "FAILURE",
  681. Reason: reason,
  682. }
  683. }
  684. // RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
  685. // service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
  686. // inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
  687. // speed: Claude 推理速度模式字段(仅 Claude 支持,默认过滤)
  688. // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
  689. // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
  690. // stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
  691. func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
  692. if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
  693. return jsonData, nil
  694. }
  695. var data map[string]interface{}
  696. if err := common.Unmarshal(jsonData, &data); err != nil {
  697. common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
  698. return jsonData, nil
  699. }
  700. // 默认移除 service_tier,除非明确允许(避免额外计费风险)
  701. if !channelOtherSettings.AllowServiceTier {
  702. if _, exists := data["service_tier"]; exists {
  703. delete(data, "service_tier")
  704. }
  705. }
  706. // 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域)
  707. if !channelOtherSettings.AllowInferenceGeo {
  708. if _, exists := data["inference_geo"]; exists {
  709. delete(data, "inference_geo")
  710. }
  711. }
  712. // 默认移除 speed,除非明确允许(避免意外切换 Claude 推理速度模式)
  713. if !channelOtherSettings.AllowSpeed {
  714. if _, exists := data["speed"]; exists {
  715. delete(data, "speed")
  716. }
  717. }
  718. // 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
  719. if channelOtherSettings.DisableStore {
  720. if _, exists := data["store"]; exists {
  721. delete(data, "store")
  722. }
  723. }
  724. // 默认移除 safety_identifier,除非明确允许(保护用户隐私,避免向 OpenAI 报告用户信息)
  725. if !channelOtherSettings.AllowSafetyIdentifier {
  726. if _, exists := data["safety_identifier"]; exists {
  727. delete(data, "safety_identifier")
  728. }
  729. }
  730. // 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护)
  731. if !channelOtherSettings.AllowIncludeObfuscation {
  732. if streamOptionsAny, exists := data["stream_options"]; exists {
  733. if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
  734. if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
  735. delete(streamOptions, "include_obfuscation")
  736. }
  737. if len(streamOptions) == 0 {
  738. delete(data, "stream_options")
  739. } else {
  740. data["stream_options"] = streamOptions
  741. }
  742. }
  743. }
  744. }
  745. jsonDataAfter, err := common.Marshal(data)
  746. if err != nil {
  747. common.SysError("RemoveDisabledFields Marshal error :" + err.Error())
  748. return jsonData, nil
  749. }
  750. return jsonDataAfter, nil
  751. }
  752. // RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
  753. // Currently supports removing functionResponse.id field which Vertex AI does not support
  754. func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
  755. if !model_setting.GetGeminiSettings().RemoveFunctionResponseIdEnabled {
  756. return jsonData, nil
  757. }
  758. var data map[string]interface{}
  759. if err := common.Unmarshal(jsonData, &data); err != nil {
  760. common.SysError("RemoveGeminiDisabledFields Unmarshal error: " + err.Error())
  761. return jsonData, nil
  762. }
  763. // Process contents array
  764. // Handle both camelCase (functionResponse) and snake_case (function_response)
  765. if contents, ok := data["contents"].([]interface{}); ok {
  766. for _, content := range contents {
  767. if contentMap, ok := content.(map[string]interface{}); ok {
  768. if parts, ok := contentMap["parts"].([]interface{}); ok {
  769. for _, part := range parts {
  770. if partMap, ok := part.(map[string]interface{}); ok {
  771. // Check functionResponse (camelCase)
  772. if funcResp, ok := partMap["functionResponse"].(map[string]interface{}); ok {
  773. delete(funcResp, "id")
  774. }
  775. // Check function_response (snake_case)
  776. if funcResp, ok := partMap["function_response"].(map[string]interface{}); ok {
  777. delete(funcResp, "id")
  778. }
  779. }
  780. }
  781. }
  782. }
  783. }
  784. }
  785. jsonDataAfter, err := common.Marshal(data)
  786. if err != nil {
  787. common.SysError("RemoveGeminiDisabledFields Marshal error: " + err.Error())
  788. return jsonData, nil
  789. }
  790. return jsonDataAfter, nil
  791. }