subscription.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. package controller
  2. import (
  3. "strconv"
  4. "strings"
  5. "github.com/QuantumNous/new-api/common"
  6. "github.com/QuantumNous/new-api/model"
  7. "github.com/QuantumNous/new-api/setting/ratio_setting"
  8. "github.com/gin-gonic/gin"
  9. "gorm.io/gorm"
  10. )
  11. // ---- Shared types ----
  12. type SubscriptionPlanDTO struct {
  13. Plan model.SubscriptionPlan `json:"plan"`
  14. }
  15. type BillingPreferenceRequest struct {
  16. BillingPreference string `json:"billing_preference"`
  17. }
  18. // ---- User APIs ----
  19. func GetSubscriptionPlans(c *gin.Context) {
  20. var plans []model.SubscriptionPlan
  21. if err := model.DB.Where("enabled = ?", true).Order("sort_order desc, id desc").Find(&plans).Error; err != nil {
  22. common.ApiError(c, err)
  23. return
  24. }
  25. result := make([]SubscriptionPlanDTO, 0, len(plans))
  26. for _, p := range plans {
  27. result = append(result, SubscriptionPlanDTO{
  28. Plan: p,
  29. })
  30. }
  31. common.ApiSuccess(c, result)
  32. }
  33. func GetSubscriptionSelf(c *gin.Context) {
  34. userId := c.GetInt("id")
  35. settingMap, _ := model.GetUserSetting(userId, false)
  36. pref := common.NormalizeBillingPreference(settingMap.BillingPreference)
  37. // Get all subscriptions (including expired)
  38. allSubscriptions, err := model.GetAllUserSubscriptions(userId)
  39. if err != nil {
  40. allSubscriptions = []model.SubscriptionSummary{}
  41. }
  42. // Get active subscriptions for backward compatibility
  43. activeSubscriptions, err := model.GetAllActiveUserSubscriptions(userId)
  44. if err != nil {
  45. activeSubscriptions = []model.SubscriptionSummary{}
  46. }
  47. common.ApiSuccess(c, gin.H{
  48. "billing_preference": pref,
  49. "subscriptions": activeSubscriptions, // all active subscriptions
  50. "all_subscriptions": allSubscriptions, // all subscriptions including expired
  51. })
  52. }
  53. func UpdateSubscriptionPreference(c *gin.Context) {
  54. userId := c.GetInt("id")
  55. var req BillingPreferenceRequest
  56. if err := c.ShouldBindJSON(&req); err != nil {
  57. common.ApiErrorMsg(c, "参数错误")
  58. return
  59. }
  60. pref := common.NormalizeBillingPreference(req.BillingPreference)
  61. user, err := model.GetUserById(userId, true)
  62. if err != nil {
  63. common.ApiError(c, err)
  64. return
  65. }
  66. current := user.GetSetting()
  67. current.BillingPreference = pref
  68. user.SetSetting(current)
  69. if err := user.Update(false); err != nil {
  70. common.ApiError(c, err)
  71. return
  72. }
  73. common.ApiSuccess(c, gin.H{"billing_preference": pref})
  74. }
  75. // ---- Admin APIs ----
  76. func AdminListSubscriptionPlans(c *gin.Context) {
  77. var plans []model.SubscriptionPlan
  78. if err := model.DB.Order("sort_order desc, id desc").Find(&plans).Error; err != nil {
  79. common.ApiError(c, err)
  80. return
  81. }
  82. result := make([]SubscriptionPlanDTO, 0, len(plans))
  83. for _, p := range plans {
  84. result = append(result, SubscriptionPlanDTO{
  85. Plan: p,
  86. })
  87. }
  88. common.ApiSuccess(c, result)
  89. }
  90. type AdminUpsertSubscriptionPlanRequest struct {
  91. Plan model.SubscriptionPlan `json:"plan"`
  92. }
  93. func AdminCreateSubscriptionPlan(c *gin.Context) {
  94. var req AdminUpsertSubscriptionPlanRequest
  95. if err := c.ShouldBindJSON(&req); err != nil {
  96. common.ApiErrorMsg(c, "参数错误")
  97. return
  98. }
  99. req.Plan.Id = 0
  100. if strings.TrimSpace(req.Plan.Title) == "" {
  101. common.ApiErrorMsg(c, "套餐标题不能为空")
  102. return
  103. }
  104. if req.Plan.Currency == "" {
  105. req.Plan.Currency = "USD"
  106. }
  107. req.Plan.Currency = "USD"
  108. if req.Plan.DurationUnit == "" {
  109. req.Plan.DurationUnit = model.SubscriptionDurationMonth
  110. }
  111. if req.Plan.DurationValue <= 0 && req.Plan.DurationUnit != model.SubscriptionDurationCustom {
  112. req.Plan.DurationValue = 1
  113. }
  114. if req.Plan.MaxPurchasePerUser < 0 {
  115. common.ApiErrorMsg(c, "购买上限不能为负数")
  116. return
  117. }
  118. if req.Plan.TotalAmount < 0 {
  119. common.ApiErrorMsg(c, "总额度不能为负数")
  120. return
  121. }
  122. req.Plan.UpgradeGroup = strings.TrimSpace(req.Plan.UpgradeGroup)
  123. if req.Plan.UpgradeGroup != "" {
  124. if _, ok := ratio_setting.GetGroupRatioCopy()[req.Plan.UpgradeGroup]; !ok {
  125. common.ApiErrorMsg(c, "升级分组不存在")
  126. return
  127. }
  128. }
  129. req.Plan.QuotaResetPeriod = model.NormalizeResetPeriod(req.Plan.QuotaResetPeriod)
  130. if req.Plan.QuotaResetPeriod == model.SubscriptionResetCustom && req.Plan.QuotaResetCustomSeconds <= 0 {
  131. common.ApiErrorMsg(c, "自定义重置周期需大于0秒")
  132. return
  133. }
  134. err := model.DB.Create(&req.Plan).Error
  135. if err != nil {
  136. common.ApiError(c, err)
  137. return
  138. }
  139. model.InvalidateSubscriptionPlanCache(req.Plan.Id)
  140. common.ApiSuccess(c, req.Plan)
  141. }
  142. func AdminUpdateSubscriptionPlan(c *gin.Context) {
  143. id, _ := strconv.Atoi(c.Param("id"))
  144. if id <= 0 {
  145. common.ApiErrorMsg(c, "无效的ID")
  146. return
  147. }
  148. var req AdminUpsertSubscriptionPlanRequest
  149. if err := c.ShouldBindJSON(&req); err != nil {
  150. common.ApiErrorMsg(c, "参数错误")
  151. return
  152. }
  153. if strings.TrimSpace(req.Plan.Title) == "" {
  154. common.ApiErrorMsg(c, "套餐标题不能为空")
  155. return
  156. }
  157. req.Plan.Id = id
  158. if req.Plan.Currency == "" {
  159. req.Plan.Currency = "USD"
  160. }
  161. req.Plan.Currency = "USD"
  162. if req.Plan.DurationUnit == "" {
  163. req.Plan.DurationUnit = model.SubscriptionDurationMonth
  164. }
  165. if req.Plan.DurationValue <= 0 && req.Plan.DurationUnit != model.SubscriptionDurationCustom {
  166. req.Plan.DurationValue = 1
  167. }
  168. if req.Plan.MaxPurchasePerUser < 0 {
  169. common.ApiErrorMsg(c, "购买上限不能为负数")
  170. return
  171. }
  172. if req.Plan.TotalAmount < 0 {
  173. common.ApiErrorMsg(c, "总额度不能为负数")
  174. return
  175. }
  176. req.Plan.UpgradeGroup = strings.TrimSpace(req.Plan.UpgradeGroup)
  177. if req.Plan.UpgradeGroup != "" {
  178. if _, ok := ratio_setting.GetGroupRatioCopy()[req.Plan.UpgradeGroup]; !ok {
  179. common.ApiErrorMsg(c, "升级分组不存在")
  180. return
  181. }
  182. }
  183. req.Plan.QuotaResetPeriod = model.NormalizeResetPeriod(req.Plan.QuotaResetPeriod)
  184. if req.Plan.QuotaResetPeriod == model.SubscriptionResetCustom && req.Plan.QuotaResetCustomSeconds <= 0 {
  185. common.ApiErrorMsg(c, "自定义重置周期需大于0秒")
  186. return
  187. }
  188. err := model.DB.Transaction(func(tx *gorm.DB) error {
  189. // update plan (allow zero values updates with map)
  190. updateMap := map[string]interface{}{
  191. "title": req.Plan.Title,
  192. "subtitle": req.Plan.Subtitle,
  193. "price_amount": req.Plan.PriceAmount,
  194. "currency": req.Plan.Currency,
  195. "duration_unit": req.Plan.DurationUnit,
  196. "duration_value": req.Plan.DurationValue,
  197. "custom_seconds": req.Plan.CustomSeconds,
  198. "enabled": req.Plan.Enabled,
  199. "sort_order": req.Plan.SortOrder,
  200. "stripe_price_id": req.Plan.StripePriceId,
  201. "creem_product_id": req.Plan.CreemProductId,
  202. "max_purchase_per_user": req.Plan.MaxPurchasePerUser,
  203. "total_amount": req.Plan.TotalAmount,
  204. "upgrade_group": req.Plan.UpgradeGroup,
  205. "quota_reset_period": req.Plan.QuotaResetPeriod,
  206. "quota_reset_custom_seconds": req.Plan.QuotaResetCustomSeconds,
  207. "updated_at": common.GetTimestamp(),
  208. }
  209. if err := tx.Model(&model.SubscriptionPlan{}).Where("id = ?", id).Updates(updateMap).Error; err != nil {
  210. return err
  211. }
  212. return nil
  213. })
  214. if err != nil {
  215. common.ApiError(c, err)
  216. return
  217. }
  218. model.InvalidateSubscriptionPlanCache(id)
  219. common.ApiSuccess(c, nil)
  220. }
  221. type AdminUpdateSubscriptionPlanStatusRequest struct {
  222. Enabled *bool `json:"enabled"`
  223. }
  224. func AdminUpdateSubscriptionPlanStatus(c *gin.Context) {
  225. id, _ := strconv.Atoi(c.Param("id"))
  226. if id <= 0 {
  227. common.ApiErrorMsg(c, "无效的ID")
  228. return
  229. }
  230. var req AdminUpdateSubscriptionPlanStatusRequest
  231. if err := c.ShouldBindJSON(&req); err != nil || req.Enabled == nil {
  232. common.ApiErrorMsg(c, "参数错误")
  233. return
  234. }
  235. if err := model.DB.Model(&model.SubscriptionPlan{}).Where("id = ?", id).Update("enabled", *req.Enabled).Error; err != nil {
  236. common.ApiError(c, err)
  237. return
  238. }
  239. model.InvalidateSubscriptionPlanCache(id)
  240. common.ApiSuccess(c, nil)
  241. }
  242. type AdminBindSubscriptionRequest struct {
  243. UserId int `json:"user_id"`
  244. PlanId int `json:"plan_id"`
  245. }
  246. func AdminBindSubscription(c *gin.Context) {
  247. var req AdminBindSubscriptionRequest
  248. if err := c.ShouldBindJSON(&req); err != nil || req.UserId <= 0 || req.PlanId <= 0 {
  249. common.ApiErrorMsg(c, "参数错误")
  250. return
  251. }
  252. msg, err := model.AdminBindSubscription(req.UserId, req.PlanId, "")
  253. if err != nil {
  254. common.ApiError(c, err)
  255. return
  256. }
  257. if msg != "" {
  258. common.ApiSuccess(c, gin.H{"message": msg})
  259. return
  260. }
  261. common.ApiSuccess(c, nil)
  262. }
  263. // ---- Admin: user subscription management ----
  264. func AdminListUserSubscriptions(c *gin.Context) {
  265. userId, _ := strconv.Atoi(c.Param("id"))
  266. if userId <= 0 {
  267. common.ApiErrorMsg(c, "无效的用户ID")
  268. return
  269. }
  270. subs, err := model.GetAllUserSubscriptions(userId)
  271. if err != nil {
  272. common.ApiError(c, err)
  273. return
  274. }
  275. common.ApiSuccess(c, subs)
  276. }
  277. type AdminCreateUserSubscriptionRequest struct {
  278. PlanId int `json:"plan_id"`
  279. }
  280. // AdminCreateUserSubscription creates a new user subscription from a plan (no payment).
  281. func AdminCreateUserSubscription(c *gin.Context) {
  282. userId, _ := strconv.Atoi(c.Param("id"))
  283. if userId <= 0 {
  284. common.ApiErrorMsg(c, "无效的用户ID")
  285. return
  286. }
  287. var req AdminCreateUserSubscriptionRequest
  288. if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
  289. common.ApiErrorMsg(c, "参数错误")
  290. return
  291. }
  292. msg, err := model.AdminBindSubscription(userId, req.PlanId, "")
  293. if err != nil {
  294. common.ApiError(c, err)
  295. return
  296. }
  297. if msg != "" {
  298. common.ApiSuccess(c, gin.H{"message": msg})
  299. return
  300. }
  301. common.ApiSuccess(c, nil)
  302. }
  303. // AdminInvalidateUserSubscription cancels a user subscription immediately.
  304. func AdminInvalidateUserSubscription(c *gin.Context) {
  305. subId, _ := strconv.Atoi(c.Param("id"))
  306. if subId <= 0 {
  307. common.ApiErrorMsg(c, "无效的订阅ID")
  308. return
  309. }
  310. msg, err := model.AdminInvalidateUserSubscription(subId)
  311. if err != nil {
  312. common.ApiError(c, err)
  313. return
  314. }
  315. if msg != "" {
  316. common.ApiSuccess(c, gin.H{"message": msg})
  317. return
  318. }
  319. common.ApiSuccess(c, nil)
  320. }
  321. // AdminDeleteUserSubscription hard-deletes a user subscription.
  322. func AdminDeleteUserSubscription(c *gin.Context) {
  323. subId, _ := strconv.Atoi(c.Param("id"))
  324. if subId <= 0 {
  325. common.ApiErrorMsg(c, "无效的订阅ID")
  326. return
  327. }
  328. msg, err := model.AdminDeleteUserSubscription(subId)
  329. if err != nil {
  330. common.ApiError(c, err)
  331. return
  332. }
  333. if msg != "" {
  334. common.ApiSuccess(c, gin.H{"message": msg})
  335. return
  336. }
  337. common.ApiSuccess(c, nil)
  338. }