relay_info.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. package common
  2. import (
  3. "errors"
  4. "one-api/common"
  5. "one-api/constant"
  6. "one-api/dto"
  7. relayconstant "one-api/relay/constant"
  8. "one-api/types"
  9. "strings"
  10. "time"
  11. "github.com/gin-gonic/gin"
  12. "github.com/gorilla/websocket"
  13. )
// ThinkingContentInfo holds flags about "thinking" content emitted while
// relaying a response. genBaseRelayInfo initializes IsFirstThinkingContent to
// true and SendLastThinkingContent to false; HasSentThinkingContent starts at
// its zero value (false).
// NOTE(review): the handlers that update these flags are outside this file —
// confirm their exact semantics there.
type ThinkingContentInfo struct {
	IsFirstThinkingContent  bool
	SendLastThinkingContent bool
	HasSentThinkingContent  bool
}

// String values for ClaudeConvertInfo.LastMessagesType. GenRelayInfoClaude
// initializes a new ClaudeConvertInfo with LastMessageTypeNone.
const (
	LastMessageTypeNone     = "none"
	LastMessageTypeText     = "text"
	LastMessageTypeTools    = "tools"
	LastMessageTypeThinking = "thinking"
)
// ClaudeConvertInfo holds state used when relaying in the Claude format.
// GenRelayInfoClaude allocates it with LastMessagesType set to
// LastMessageTypeNone; the other fields start at their zero values.
// NOTE(review): Index, Usage, FinishReason and Done are updated by conversion
// code outside this file — confirm their meaning there.
type ClaudeConvertInfo struct {
	LastMessagesType string
	Index            int
	Usage            *dto.Usage
	FinishReason     string
	Done             bool
}

// RerankerInfo carries rerank request details that GenRelayInfoRerank copies
// from the incoming *dto.RerankRequest (its Documents and the value of
// GetReturnDocuments()).
type RerankerInfo struct {
	Documents       []any
	ReturnDocuments bool
}

// BuildInToolInfo records one tool declared in an OpenAI Responses request.
// GenRelayInfoResponses creates one entry per tool type with CallCount 0.
type BuildInToolInfo struct {
	ToolName  string
	CallCount int
	// SearchContextSize is only filled in for dto.BuildInToolWebSearchPreview
	// tools, defaulting to "medium" when the request does not specify it.
	SearchContextSize string
}

// ResponsesUsageInfo tracks the tools declared in an OpenAI Responses request.
type ResponsesUsageInfo struct {
	BuiltInTools map[string]*BuildInToolInfo // keyed by the tool's "type" value
}
// ChannelMeta holds details about the upstream channel selected for a
// request. InitChannelMeta builds it from values stored on the gin context.
type ChannelMeta struct {
	ChannelType          int
	ChannelId            int
	ChannelIsMultiKey    bool
	ChannelMultiKeyIndex int
	ChannelBaseUrl       string
	ApiType              int // derived from ChannelType via common.ChannelType2APIType
	ApiVersion           string
	ApiKey               string
	Organization         string
	ChannelCreateTime    int64
	ParamOverride        map[string]interface{}
	ChannelSetting       dto.ChannelSettings
	ChannelOtherSettings dto.ChannelOtherSettings
	// UpstreamModelName is seeded by InitChannelMeta from the context's
	// original model name; IsModelMapped starts false.
	UpstreamModelName    string
	IsModelMapped        bool
	SupportStreamOptions bool // whether the channel supports stream options
}
// RelayInfo carries the per-request state used while relaying a request to an
// upstream channel. genBaseRelayInfo (reached through GenRelayInfo or one of
// the GenRelayInfo* constructors) fills in the common fields, and
// InitChannelMeta adds the channel details.
//
// The embedded pointers stay nil until something assigns them:
// ClaudeConvertInfo (GenRelayInfoClaude), RerankerInfo (GenRelayInfoRerank),
// ResponsesUsageInfo (GenRelayInfoResponses) and ChannelMeta
// (InitChannelMeta). Reading or writing a field promoted through a nil
// embedded pointer (for example info.ApiKey before InitChannelMeta) panics.
type RelayInfo struct {
	TokenId        int
	TokenKey       string
	UserId         int
	UsingGroup     string // group used for this request
	UserGroup      string // group the user belongs to
	TokenUnlimited bool
	// StartTime is the request start time stored on the context, or
	// time.Now() when the context has none.
	StartTime time.Time
	// FirstResponseTime is seeded to StartTime minus one second and set to
	// time.Now() by the first SetFirstResponseTime call; HasSendResponse
	// compares it against StartTime.
	FirstResponseTime time.Time
	// isFirstResponse is true until SetFirstResponseTime has run once.
	isFirstResponse bool
	//SendLastReasoningResponse bool
	IsStream               bool
	IsGeminiBatchEmbedding bool
	IsPlayground           bool // set when the request path starts with "/pg"
	UsePrice               bool
	RelayMode              int
	OriginModelName        string
	//RecodeModelName string
	// RequestURLPath starts as c.Request.URL.String(); playground requests
	// have it rewritten from "/pg..." to "/v1...".
	RequestURLPath string
	PromptTokens   int
	//SupportStreamOptions bool
	ShouldIncludeUsage    bool
	DisablePing           bool // whether to stop sending custom pings to the downstream client
	ClientWs              *websocket.Conn
	TargetWs              *websocket.Conn
	InputAudioFormat      string
	OutputAudioFormat     string
	RealtimeTools         []dto.RealTimeTool
	IsFirstRequest        bool
	AudioUsage            bool
	ReasoningEffort       string
	UserSetting           dto.UserSetting
	UserEmail             string
	UserQuota             int
	RelayFormat           types.RelayFormat
	SendResponseCount     int
	FinalPreConsumedQuota int // final pre-consumed quota
	PriceData             types.PriceData
	Request               dto.Request
	ThinkingContentInfo
	*ClaudeConvertInfo
	*RerankerInfo
	*ResponsesUsageInfo
	*ChannelMeta
}
  107. func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
  108. channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  109. paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  110. apiType, _ := common.ChannelType2APIType(channelType)
  111. channelMeta := &ChannelMeta{
  112. ChannelType: channelType,
  113. ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
  114. ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
  115. ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
  116. ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
  117. ApiType: apiType,
  118. ApiVersion: c.GetString("api_version"),
  119. ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
  120. Organization: c.GetString("channel_organization"),
  121. ChannelCreateTime: c.GetInt64("channel_create_time"),
  122. ParamOverride: paramOverride,
  123. UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  124. IsModelMapped: false,
  125. SupportStreamOptions: false,
  126. }
  127. channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
  128. if ok {
  129. channelMeta.ChannelSetting = channelSetting
  130. }
  131. channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
  132. if ok {
  133. channelMeta.ChannelOtherSettings = channelOtherSettings
  134. }
  135. if streamSupportedChannels[channelMeta.ChannelType] {
  136. channelMeta.SupportStreamOptions = true
  137. }
  138. info.ChannelMeta = channelMeta
  139. }
  140. // 定义支持流式选项的通道类型
  141. var streamSupportedChannels = map[int]bool{
  142. constant.ChannelTypeOpenAI: true,
  143. constant.ChannelTypeAnthropic: true,
  144. constant.ChannelTypeAws: true,
  145. constant.ChannelTypeGemini: true,
  146. constant.ChannelCloudflare: true,
  147. constant.ChannelTypeAzure: true,
  148. constant.ChannelTypeVolcEngine: true,
  149. constant.ChannelTypeOllama: true,
  150. constant.ChannelTypeXai: true,
  151. constant.ChannelTypeDeepSeek: true,
  152. constant.ChannelTypeBaiduV2: true,
  153. }
  154. func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
  155. info := genBaseRelayInfo(c, nil)
  156. info.RelayFormat = types.RelayFormatOpenAIRealtime
  157. info.ClientWs = ws
  158. info.InputAudioFormat = "pcm16"
  159. info.OutputAudioFormat = "pcm16"
  160. info.IsFirstRequest = true
  161. return info
  162. }
  163. func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
  164. info := genBaseRelayInfo(c, request)
  165. info.RelayFormat = types.RelayFormatClaude
  166. info.ShouldIncludeUsage = false
  167. info.ClaudeConvertInfo = &ClaudeConvertInfo{
  168. LastMessagesType: LastMessageTypeNone,
  169. }
  170. return info
  171. }
  172. func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
  173. info := genBaseRelayInfo(c, request)
  174. info.RelayMode = relayconstant.RelayModeRerank
  175. info.RelayFormat = types.RelayFormatRerank
  176. info.RerankerInfo = &RerankerInfo{
  177. Documents: request.Documents,
  178. ReturnDocuments: request.GetReturnDocuments(),
  179. }
  180. return info
  181. }
  182. func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
  183. info := genBaseRelayInfo(c, request)
  184. info.RelayFormat = types.RelayFormatOpenAIAudio
  185. return info
  186. }
  187. func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
  188. info := genBaseRelayInfo(c, request)
  189. info.RelayFormat = types.RelayFormatEmbedding
  190. return info
  191. }
  192. func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
  193. info := genBaseRelayInfo(c, request)
  194. info.RelayMode = relayconstant.RelayModeResponses
  195. info.RelayFormat = types.RelayFormatOpenAIResponses
  196. info.SupportStreamOptions = false
  197. info.ResponsesUsageInfo = &ResponsesUsageInfo{
  198. BuiltInTools: make(map[string]*BuildInToolInfo),
  199. }
  200. if len(request.Tools) > 0 {
  201. for _, tool := range request.Tools {
  202. toolType := common.Interface2String(tool["type"])
  203. info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
  204. ToolName: toolType,
  205. CallCount: 0,
  206. }
  207. switch toolType {
  208. case dto.BuildInToolWebSearchPreview:
  209. searchContextSize := common.Interface2String(tool["search_context_size"])
  210. if searchContextSize == "" {
  211. searchContextSize = "medium"
  212. }
  213. info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
  214. }
  215. }
  216. }
  217. return info
  218. }
  219. func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
  220. info := genBaseRelayInfo(c, request)
  221. info.RelayFormat = types.RelayFormatGemini
  222. info.ShouldIncludeUsage = false
  223. return info
  224. }
  225. func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
  226. info := genBaseRelayInfo(c, request)
  227. info.RelayFormat = types.RelayFormatOpenAIImage
  228. return info
  229. }
  230. func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
  231. info := genBaseRelayInfo(c, request)
  232. info.RelayFormat = types.RelayFormatOpenAI
  233. return info
  234. }
  235. func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
  236. //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
  237. //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
  238. //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
  239. startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
  240. if startTime.IsZero() {
  241. startTime = time.Now()
  242. }
  243. isStream := false
  244. if request != nil {
  245. isStream = request.IsStream(c)
  246. }
  247. // firstResponseTime = time.Now() - 1 second
  248. info := &RelayInfo{
  249. Request: request,
  250. UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
  251. UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
  252. UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
  253. UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
  254. UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
  255. OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
  256. PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
  257. TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
  258. TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
  259. TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
  260. isFirstResponse: true,
  261. RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
  262. RequestURLPath: c.Request.URL.String(),
  263. IsStream: isStream,
  264. StartTime: startTime,
  265. FirstResponseTime: startTime.Add(-time.Second),
  266. ThinkingContentInfo: ThinkingContentInfo{
  267. IsFirstThinkingContent: true,
  268. SendLastThinkingContent: false,
  269. },
  270. }
  271. if strings.HasPrefix(c.Request.URL.Path, "/pg") {
  272. info.IsPlayground = true
  273. info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
  274. info.RequestURLPath = "/v1" + info.RequestURLPath
  275. }
  276. userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
  277. if ok {
  278. info.UserSetting = userSetting
  279. }
  280. return info
  281. }
  282. func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
  283. switch relayFormat {
  284. case types.RelayFormatOpenAI:
  285. return GenRelayInfoOpenAI(c, request), nil
  286. case types.RelayFormatOpenAIAudio:
  287. return GenRelayInfoOpenAIAudio(c, request), nil
  288. case types.RelayFormatOpenAIImage:
  289. return GenRelayInfoImage(c, request), nil
  290. case types.RelayFormatOpenAIRealtime:
  291. return GenRelayInfoWs(c, ws), nil
  292. case types.RelayFormatClaude:
  293. return GenRelayInfoClaude(c, request), nil
  294. case types.RelayFormatRerank:
  295. if request, ok := request.(*dto.RerankRequest); ok {
  296. return GenRelayInfoRerank(c, request), nil
  297. }
  298. return nil, errors.New("request is not a RerankRequest")
  299. case types.RelayFormatGemini:
  300. return GenRelayInfoGemini(c, request), nil
  301. case types.RelayFormatEmbedding:
  302. return GenRelayInfoEmbedding(c, request), nil
  303. case types.RelayFormatOpenAIResponses:
  304. if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
  305. return GenRelayInfoResponses(c, request), nil
  306. }
  307. return nil, errors.New("request is not a OpenAIResponsesRequest")
  308. case types.RelayFormatTask:
  309. return genBaseRelayInfo(c, nil), nil
  310. case types.RelayFormatMjProxy:
  311. return genBaseRelayInfo(c, nil), nil
  312. default:
  313. return nil, errors.New("invalid relay format")
  314. }
  315. }
  316. func (info *RelayInfo) SetPromptTokens(promptTokens int) {
  317. info.PromptTokens = promptTokens
  318. }
  319. func (info *RelayInfo) SetFirstResponseTime() {
  320. if info.isFirstResponse {
  321. info.FirstResponseTime = time.Now()
  322. info.isFirstResponse = false
  323. }
  324. }
  325. func (info *RelayInfo) HasSendResponse() bool {
  326. return info.FirstResponseTime.After(info.StartTime)
  327. }
// TaskRelayInfo extends RelayInfo with task-specific fields. GenTaskRelayInfo
// creates it with only the embedded RelayInfo set; Action, OriginTaskID and
// ConsumeQuota start at their zero values.
type TaskRelayInfo struct {
	*RelayInfo
	Action       string
	OriginTaskID string
	ConsumeQuota bool
}
  334. func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) {
  335. relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil)
  336. if err != nil {
  337. return nil, err
  338. }
  339. info := &TaskRelayInfo{
  340. RelayInfo: relayInfo,
  341. }
  342. return info, nil
  343. }
// TaskSubmitReq is a task submission payload. JSON keys follow the struct
// tags; every field except Prompt is omitted from JSON output when empty.
type TaskSubmitReq struct {
	Prompt   string                 `json:"prompt"`
	Model    string                 `json:"model,omitempty"`
	Mode     string                 `json:"mode,omitempty"`
	Image    string                 `json:"image,omitempty"`
	Size     string                 `json:"size,omitempty"`
	Duration int                    `json:"duration,omitempty"` // NOTE(review): unit not shown here — confirm (seconds?)
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

// TaskInfo describes a task's state in JSON form. code, task_id and status are
// always emitted; reason, url and progress are omitted when empty.
type TaskInfo struct {
	Code     int    `json:"code"`
	TaskID   string `json:"task_id"`
	Status   string `json:"status"`
	Reason   string `json:"reason,omitempty"`
	Url      string `json:"url,omitempty"`
	Progress string `json:"progress,omitempty"`
}