passkey.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package model
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "one-api/common"
  7. "strings"
  8. "time"
  9. "github.com/go-webauthn/webauthn/protocol"
  10. "github.com/go-webauthn/webauthn/webauthn"
  11. "gorm.io/gorm"
  12. )
  13. var (
  14. ErrPasskeyNotFound = errors.New("passkey credential not found")
  15. ErrFriendlyPasskeyNotFound = errors.New("Passkey 验证失败,请重试或联系管理员")
  16. )
  17. type PasskeyCredential struct {
  18. ID int `json:"id" gorm:"primaryKey"`
  19. UserID int `json:"user_id" gorm:"uniqueIndex;not null"`
  20. CredentialID []byte `json:"credential_id" gorm:"type:blob;uniqueIndex;not null"`
  21. PublicKey []byte `json:"public_key" gorm:"type:blob;not null"`
  22. AttestationType string `json:"attestation_type" gorm:"type:varchar(255)"`
  23. AAGUID []byte `json:"aaguid" gorm:"type:blob"`
  24. SignCount uint32 `json:"sign_count" gorm:"default:0"`
  25. CloneWarning bool `json:"clone_warning"`
  26. UserPresent bool `json:"user_present"`
  27. UserVerified bool `json:"user_verified"`
  28. BackupEligible bool `json:"backup_eligible"`
  29. BackupState bool `json:"backup_state"`
  30. Transports string `json:"transports" gorm:"type:text"`
  31. Attachment string `json:"attachment" gorm:"type:varchar(32)"`
  32. LastUsedAt *time.Time `json:"last_used_at"`
  33. CreatedAt time.Time `json:"created_at"`
  34. UpdatedAt time.Time `json:"updated_at"`
  35. DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
  36. }
  37. func (p *PasskeyCredential) TransportList() []protocol.AuthenticatorTransport {
  38. if p == nil || strings.TrimSpace(p.Transports) == "" {
  39. return nil
  40. }
  41. var transports []string
  42. if err := json.Unmarshal([]byte(p.Transports), &transports); err != nil {
  43. return nil
  44. }
  45. result := make([]protocol.AuthenticatorTransport, 0, len(transports))
  46. for _, transport := range transports {
  47. result = append(result, protocol.AuthenticatorTransport(transport))
  48. }
  49. return result
  50. }
  51. func (p *PasskeyCredential) SetTransports(list []protocol.AuthenticatorTransport) {
  52. if len(list) == 0 {
  53. p.Transports = ""
  54. return
  55. }
  56. stringList := make([]string, len(list))
  57. for i, transport := range list {
  58. stringList[i] = string(transport)
  59. }
  60. encoded, err := json.Marshal(stringList)
  61. if err != nil {
  62. return
  63. }
  64. p.Transports = string(encoded)
  65. }
  66. func (p *PasskeyCredential) ToWebAuthnCredential() webauthn.Credential {
  67. flags := webauthn.CredentialFlags{
  68. UserPresent: p.UserPresent,
  69. UserVerified: p.UserVerified,
  70. BackupEligible: p.BackupEligible,
  71. BackupState: p.BackupState,
  72. }
  73. return webauthn.Credential{
  74. ID: p.CredentialID,
  75. PublicKey: p.PublicKey,
  76. AttestationType: p.AttestationType,
  77. Transport: p.TransportList(),
  78. Flags: flags,
  79. Authenticator: webauthn.Authenticator{
  80. AAGUID: p.AAGUID,
  81. SignCount: p.SignCount,
  82. CloneWarning: p.CloneWarning,
  83. Attachment: protocol.AuthenticatorAttachment(p.Attachment),
  84. },
  85. }
  86. }
  87. func NewPasskeyCredentialFromWebAuthn(userID int, credential *webauthn.Credential) *PasskeyCredential {
  88. if credential == nil {
  89. return nil
  90. }
  91. passkey := &PasskeyCredential{
  92. UserID: userID,
  93. CredentialID: credential.ID,
  94. PublicKey: credential.PublicKey,
  95. AttestationType: credential.AttestationType,
  96. AAGUID: credential.Authenticator.AAGUID,
  97. SignCount: credential.Authenticator.SignCount,
  98. CloneWarning: credential.Authenticator.CloneWarning,
  99. UserPresent: credential.Flags.UserPresent,
  100. UserVerified: credential.Flags.UserVerified,
  101. BackupEligible: credential.Flags.BackupEligible,
  102. BackupState: credential.Flags.BackupState,
  103. Attachment: string(credential.Authenticator.Attachment),
  104. }
  105. passkey.SetTransports(credential.Transport)
  106. return passkey
  107. }
  108. func (p *PasskeyCredential) ApplyValidatedCredential(credential *webauthn.Credential) {
  109. if credential == nil || p == nil {
  110. return
  111. }
  112. p.CredentialID = credential.ID
  113. p.PublicKey = credential.PublicKey
  114. p.AttestationType = credential.AttestationType
  115. p.AAGUID = credential.Authenticator.AAGUID
  116. p.SignCount = credential.Authenticator.SignCount
  117. p.CloneWarning = credential.Authenticator.CloneWarning
  118. p.UserPresent = credential.Flags.UserPresent
  119. p.UserVerified = credential.Flags.UserVerified
  120. p.BackupEligible = credential.Flags.BackupEligible
  121. p.BackupState = credential.Flags.BackupState
  122. p.Attachment = string(credential.Authenticator.Attachment)
  123. p.SetTransports(credential.Transport)
  124. }
  125. func GetPasskeyByUserID(userID int) (*PasskeyCredential, error) {
  126. if userID == 0 {
  127. common.SysLog("GetPasskeyByUserID: empty user ID")
  128. return nil, ErrFriendlyPasskeyNotFound
  129. }
  130. var credential PasskeyCredential
  131. if err := DB.Where("user_id = ?", userID).First(&credential).Error; err != nil {
  132. if errors.Is(err, gorm.ErrRecordNotFound) {
  133. common.SysLog(fmt.Sprintf("GetPasskeyByUserID: passkey not found for user %d", userID))
  134. return nil, ErrFriendlyPasskeyNotFound
  135. }
  136. common.SysLog(fmt.Sprintf("GetPasskeyByUserID: database error for user %d: %v", userID, err))
  137. return nil, ErrFriendlyPasskeyNotFound
  138. }
  139. return &credential, nil
  140. }
  141. func GetPasskeyByCredentialID(credentialID []byte) (*PasskeyCredential, error) {
  142. if len(credentialID) == 0 {
  143. common.SysLog("GetPasskeyByCredentialID: empty credential ID")
  144. return nil, ErrFriendlyPasskeyNotFound
  145. }
  146. var credential PasskeyCredential
  147. if err := DB.Where("credential_id = ?", credentialID).First(&credential).Error; err != nil {
  148. if errors.Is(err, gorm.ErrRecordNotFound) {
  149. common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: passkey not found for credential ID length %d", len(credentialID)))
  150. return nil, ErrFriendlyPasskeyNotFound
  151. }
  152. common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: database error for credential ID: %v", err))
  153. return nil, ErrFriendlyPasskeyNotFound
  154. }
  155. return &credential, nil
  156. }
  157. func UpsertPasskeyCredential(credential *PasskeyCredential) error {
  158. if credential == nil {
  159. common.SysLog("UpsertPasskeyCredential: nil credential provided")
  160. return fmt.Errorf("Passkey 保存失败,请重试")
  161. }
  162. return DB.Transaction(func(tx *gorm.DB) error {
  163. // 使用Unscoped()进行硬删除,避免唯一索引冲突
  164. if err := tx.Unscoped().Where("user_id = ?", credential.UserID).Delete(&PasskeyCredential{}).Error; err != nil {
  165. common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to delete existing credential for user %d: %v", credential.UserID, err))
  166. return fmt.Errorf("Passkey 保存失败,请重试")
  167. }
  168. if err := tx.Create(credential).Error; err != nil {
  169. common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to create credential for user %d: %v", credential.UserID, err))
  170. return fmt.Errorf("Passkey 保存失败,请重试")
  171. }
  172. return nil
  173. })
  174. }
  175. func DeletePasskeyByUserID(userID int) error {
  176. if userID == 0 {
  177. common.SysLog("DeletePasskeyByUserID: empty user ID")
  178. return fmt.Errorf("删除失败,请重试")
  179. }
  180. // 使用Unscoped()进行硬删除,避免唯一索引冲突
  181. if err := DB.Unscoped().Where("user_id = ?", userID).Delete(&PasskeyCredential{}).Error; err != nil {
  182. common.SysLog(fmt.Sprintf("DeletePasskeyByUserID: failed to delete passkey for user %d: %v", userID, err))
  183. return fmt.Errorf("删除失败,请重试")
  184. }
  185. return nil
  186. }