relay_info.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "one-api/common"
  6. "one-api/constant"
  7. "one-api/dto"
  8. relayconstant "one-api/relay/constant"
  9. "one-api/types"
  10. "strings"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/gorilla/websocket"
  14. )
  15. type ThinkingContentInfo struct {
  16. IsFirstThinkingContent bool
  17. SendLastThinkingContent bool
  18. HasSentThinkingContent bool
  19. }
  20. const (
  21. LastMessageTypeNone = "none"
  22. LastMessageTypeText = "text"
  23. LastMessageTypeTools = "tools"
  24. LastMessageTypeThinking = "thinking"
  25. )
  26. type ClaudeConvertInfo struct {
  27. LastMessagesType string
  28. Index int
  29. Usage *dto.Usage
  30. FinishReason string
  31. Done bool
  32. }
  33. type RerankerInfo struct {
  34. Documents []any
  35. ReturnDocuments bool
  36. }
  37. type BuildInToolInfo struct {
  38. ToolName string
  39. CallCount int
  40. SearchContextSize string
  41. }
  42. type ResponsesUsageInfo struct {
  43. BuiltInTools map[string]*BuildInToolInfo
  44. }
  45. type ChannelMeta struct {
  46. ChannelType int
  47. ChannelId int
  48. ChannelIsMultiKey bool
  49. ChannelMultiKeyIndex int
  50. ChannelBaseUrl string
  51. ApiType int
  52. ApiVersion string
  53. ApiKey string
  54. Organization string
  55. ChannelCreateTime int64
  56. ParamOverride map[string]interface{}
  57. ChannelSetting dto.ChannelSettings
  58. ChannelOtherSettings dto.ChannelOtherSettings
  59. UpstreamModelName string
  60. IsModelMapped bool
  61. SupportStreamOptions bool // 是否支持流式选项
  62. }
  63. type RelayInfo struct {
  64. TokenId int
  65. TokenKey string
  66. UserId int
  67. UsingGroup string // 使用的分组
  68. UserGroup string // 用户所在分组
  69. TokenUnlimited bool
  70. StartTime time.Time
  71. FirstResponseTime time.Time
  72. isFirstResponse bool
  73. //SendLastReasoningResponse bool
  74. IsStream bool
  75. IsGeminiBatchEmbedding bool
  76. IsPlayground bool
  77. UsePrice bool
  78. RelayMode int
  79. OriginModelName string
  80. RequestURLPath string
  81. PromptTokens int
  82. ShouldIncludeUsage bool
  83. DisablePing bool // 是否禁止向下游发送自定义 Ping
  84. ClientWs *websocket.Conn
  85. TargetWs *websocket.Conn
  86. InputAudioFormat string
  87. OutputAudioFormat string
  88. RealtimeTools []dto.RealTimeTool
  89. IsFirstRequest bool
  90. AudioUsage bool
  91. ReasoningEffort string
  92. UserSetting dto.UserSetting
  93. UserEmail string
  94. UserQuota int
  95. RelayFormat types.RelayFormat
  96. SendResponseCount int
  97. FinalPreConsumedQuota int // 最终预消耗的配额
  98. PriceData types.PriceData
  99. Request dto.Request
  100. ThinkingContentInfo
  101. *ClaudeConvertInfo
  102. *RerankerInfo
  103. *ResponsesUsageInfo
  104. *ChannelMeta
  105. }
  106. func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
  107. channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  108. paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  109. apiType, _ := common.ChannelType2APIType(channelType)
  110. channelMeta := &ChannelMeta{
  111. ChannelType: channelType,
  112. ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
  113. ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
  114. ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
  115. ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
  116. ApiType: apiType,
  117. ApiVersion: c.GetString("api_version"),
  118. ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
  119. Organization: c.GetString("channel_organization"),
  120. ChannelCreateTime: c.GetInt64("channel_create_time"),
  121. ParamOverride: paramOverride,
  122. UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  123. IsModelMapped: false,
  124. SupportStreamOptions: false,
  125. }
  126. if channelType == constant.ChannelTypeAzure {
  127. channelMeta.ApiVersion = GetAPIVersion(c)
  128. }
  129. if channelType == constant.ChannelTypeVertexAi {
  130. channelMeta.ApiVersion = c.GetString("region")
  131. }
  132. channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
  133. if ok {
  134. channelMeta.ChannelSetting = channelSetting
  135. }
  136. channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
  137. if ok {
  138. channelMeta.ChannelOtherSettings = channelOtherSettings
  139. }
  140. if streamSupportedChannels[channelMeta.ChannelType] {
  141. channelMeta.SupportStreamOptions = true
  142. }
  143. info.ChannelMeta = channelMeta
  144. // reset some fields based on channel meta
  145. // 重置某些字段,例如模型名称等
  146. if info.Request != nil {
  147. info.Request.SetModelName(info.OriginModelName)
  148. }
  149. }
  150. func (info *RelayInfo) ToString() string {
  151. if info == nil {
  152. return "RelayInfo<nil>"
  153. }
  154. // Basic info
  155. b := &strings.Builder{}
  156. fmt.Fprintf(b, "RelayInfo{ ")
  157. fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat)
  158. fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode)
  159. fmt.Fprintf(b, "IsStream: %t, ", info.IsStream)
  160. fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
  161. fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
  162. fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
  163. fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens)
  164. fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
  165. fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
  166. fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
  167. fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota)
  168. // User & token info (mask secrets)
  169. fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ",
  170. info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota)
  171. fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited)
  172. // Time info
  173. latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
  174. fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ",
  175. info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs)
  176. // Audio / realtime
  177. if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage {
  178. fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ",
  179. info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools))
  180. }
  181. // Reasoning
  182. if info.ReasoningEffort != "" {
  183. fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort)
  184. }
  185. // Price data (non-sensitive)
  186. if info.PriceData.UsePrice {
  187. fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting())
  188. }
  189. // Channel metadata (mask ApiKey)
  190. if info.ChannelMeta != nil {
  191. cm := info.ChannelMeta
  192. fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ",
  193. cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions)
  194. }
  195. // Responses usage info (non-sensitive)
  196. if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 {
  197. fmt.Fprintf(b, "ResponsesTools{ ")
  198. first := true
  199. for name, tool := range info.ResponsesUsageInfo.BuiltInTools {
  200. if !first {
  201. fmt.Fprintf(b, ", ")
  202. }
  203. first = false
  204. if tool != nil {
  205. fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount)
  206. } else {
  207. fmt.Fprintf(b, "%s: calls=0", name)
  208. }
  209. }
  210. fmt.Fprintf(b, " }, ")
  211. }
  212. fmt.Fprintf(b, "}")
  213. return b.String()
  214. }
  215. // 定义支持流式选项的通道类型
  216. var streamSupportedChannels = map[int]bool{
  217. constant.ChannelTypeOpenAI: true,
  218. constant.ChannelTypeAnthropic: true,
  219. constant.ChannelTypeAws: true,
  220. constant.ChannelTypeGemini: true,
  221. constant.ChannelCloudflare: true,
  222. constant.ChannelTypeAzure: true,
  223. constant.ChannelTypeVolcEngine: true,
  224. constant.ChannelTypeOllama: true,
  225. constant.ChannelTypeXai: true,
  226. constant.ChannelTypeDeepSeek: true,
  227. constant.ChannelTypeBaiduV2: true,
  228. }
  229. func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
  230. info := genBaseRelayInfo(c, nil)
  231. info.RelayFormat = types.RelayFormatOpenAIRealtime
  232. info.ClientWs = ws
  233. info.InputAudioFormat = "pcm16"
  234. info.OutputAudioFormat = "pcm16"
  235. info.IsFirstRequest = true
  236. return info
  237. }
  238. func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
  239. info := genBaseRelayInfo(c, request)
  240. info.RelayFormat = types.RelayFormatClaude
  241. info.ShouldIncludeUsage = false
  242. info.ClaudeConvertInfo = &ClaudeConvertInfo{
  243. LastMessagesType: LastMessageTypeNone,
  244. }
  245. return info
  246. }
  247. func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
  248. info := genBaseRelayInfo(c, request)
  249. info.RelayMode = relayconstant.RelayModeRerank
  250. info.RelayFormat = types.RelayFormatRerank
  251. info.RerankerInfo = &RerankerInfo{
  252. Documents: request.Documents,
  253. ReturnDocuments: request.GetReturnDocuments(),
  254. }
  255. return info
  256. }
  257. func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
  258. info := genBaseRelayInfo(c, request)
  259. info.RelayFormat = types.RelayFormatOpenAIAudio
  260. return info
  261. }
  262. func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
  263. info := genBaseRelayInfo(c, request)
  264. info.RelayFormat = types.RelayFormatEmbedding
  265. return info
  266. }
  267. func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
  268. info := genBaseRelayInfo(c, request)
  269. info.RelayMode = relayconstant.RelayModeResponses
  270. info.RelayFormat = types.RelayFormatOpenAIResponses
  271. info.ResponsesUsageInfo = &ResponsesUsageInfo{
  272. BuiltInTools: make(map[string]*BuildInToolInfo),
  273. }
  274. if len(request.Tools) > 0 {
  275. for _, tool := range request.Tools {
  276. toolType := common.Interface2String(tool["type"])
  277. info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
  278. ToolName: toolType,
  279. CallCount: 0,
  280. }
  281. switch toolType {
  282. case dto.BuildInToolWebSearchPreview:
  283. searchContextSize := common.Interface2String(tool["search_context_size"])
  284. if searchContextSize == "" {
  285. searchContextSize = "medium"
  286. }
  287. info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
  288. }
  289. }
  290. }
  291. return info
  292. }
  293. func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
  294. info := genBaseRelayInfo(c, request)
  295. info.RelayFormat = types.RelayFormatGemini
  296. info.ShouldIncludeUsage = false
  297. return info
  298. }
  299. func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
  300. info := genBaseRelayInfo(c, request)
  301. info.RelayFormat = types.RelayFormatOpenAIImage
  302. return info
  303. }
  304. func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
  305. info := genBaseRelayInfo(c, request)
  306. info.RelayFormat = types.RelayFormatOpenAI
  307. return info
  308. }
  309. func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
  310. //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  311. //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
  312. //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  313. startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
  314. if startTime.IsZero() {
  315. startTime = time.Now()
  316. }
  317. isStream := false
  318. if request != nil {
  319. isStream = request.IsStream(c)
  320. }
  321. // firstResponseTime = time.Now() - 1 second
  322. info := &RelayInfo{
  323. Request: request,
  324. UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
  325. UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
  326. UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
  327. UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
  328. UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
  329. OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  330. PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
  331. TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
  332. TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
  333. TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
  334. isFirstResponse: true,
  335. RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
  336. RequestURLPath: c.Request.URL.String(),
  337. IsStream: isStream,
  338. StartTime: startTime,
  339. FirstResponseTime: startTime.Add(-time.Second),
  340. ThinkingContentInfo: ThinkingContentInfo{
  341. IsFirstThinkingContent: true,
  342. SendLastThinkingContent: false,
  343. },
  344. }
  345. if strings.HasPrefix(c.Request.URL.Path, "/pg") {
  346. info.IsPlayground = true
  347. info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
  348. info.RequestURLPath = "/v1" + info.RequestURLPath
  349. }
  350. userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
  351. if ok {
  352. info.UserSetting = userSetting
  353. }
  354. return info
  355. }
  356. func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
  357. switch relayFormat {
  358. case types.RelayFormatOpenAI:
  359. return GenRelayInfoOpenAI(c, request), nil
  360. case types.RelayFormatOpenAIAudio:
  361. return GenRelayInfoOpenAIAudio(c, request), nil
  362. case types.RelayFormatOpenAIImage:
  363. return GenRelayInfoImage(c, request), nil
  364. case types.RelayFormatOpenAIRealtime:
  365. return GenRelayInfoWs(c, ws), nil
  366. case types.RelayFormatClaude:
  367. return GenRelayInfoClaude(c, request), nil
  368. case types.RelayFormatRerank:
  369. if request, ok := request.(*dto.RerankRequest); ok {
  370. return GenRelayInfoRerank(c, request), nil
  371. }
  372. return nil, errors.New("request is not a RerankRequest")
  373. case types.RelayFormatGemini:
  374. return GenRelayInfoGemini(c, request), nil
  375. case types.RelayFormatEmbedding:
  376. return GenRelayInfoEmbedding(c, request), nil
  377. case types.RelayFormatOpenAIResponses:
  378. if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
  379. return GenRelayInfoResponses(c, request), nil
  380. }
  381. return nil, errors.New("request is not a OpenAIResponsesRequest")
  382. case types.RelayFormatTask:
  383. return genBaseRelayInfo(c, nil), nil
  384. case types.RelayFormatMjProxy:
  385. return genBaseRelayInfo(c, nil), nil
  386. default:
  387. return nil, errors.New("invalid relay format")
  388. }
  389. }
  390. func (info *RelayInfo) SetPromptTokens(promptTokens int) {
  391. info.PromptTokens = promptTokens
  392. }
  393. func (info *RelayInfo) SetFirstResponseTime() {
  394. if info.isFirstResponse {
  395. info.FirstResponseTime = time.Now()
  396. info.isFirstResponse = false
  397. }
  398. }
  399. func (info *RelayInfo) HasSendResponse() bool {
  400. return info.FirstResponseTime.After(info.StartTime)
  401. }
  402. type TaskRelayInfo struct {
  403. *RelayInfo
  404. Action string
  405. OriginTaskID string
  406. ConsumeQuota bool
  407. }
  408. func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) {
  409. relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil)
  410. if err != nil {
  411. return nil, err
  412. }
  413. info := &TaskRelayInfo{
  414. RelayInfo: relayInfo,
  415. }
  416. return info, nil
  417. }
  418. type TaskSubmitReq struct {
  419. Prompt string `json:"prompt"`
  420. Model string `json:"model,omitempty"`
  421. Mode string `json:"mode,omitempty"`
  422. Image string `json:"image,omitempty"`
  423. Size string `json:"size,omitempty"`
  424. Duration int `json:"duration,omitempty"`
  425. Metadata map[string]interface{} `json:"metadata,omitempty"`
  426. }
  427. type TaskInfo struct {
  428. Code int `json:"code"`
  429. TaskID string `json:"task_id"`
  430. Status string `json:"status"`
  431. Reason string `json:"reason,omitempty"`
  432. Url string `json:"url,omitempty"`
  433. Progress string `json:"progress,omitempty"`
  434. }