diff --git a/account/account.go b/account/account.go index 0795a19..1d2bd2d 100644 --- a/account/account.go +++ b/account/account.go @@ -11,6 +11,10 @@ import ( "time" ) +var ( + ErrInvalidOTPCode = errors.New("Invalid OTP code") +) + type AccountServiceParams struct { fx.In Db *gorm.DB @@ -107,7 +111,7 @@ func (s AccountServiceDefault) AddPubkeyToAccount(user models.User, pubkey strin return nil } func (s AccountServiceDefault) LoginPassword(email string, password string, ip string) (string, *models.User, error) { - valid, user, err := s.ValidLogin(email, password) + valid, user, err := s.ValidLoginByEmail(email, password) if err != nil { return "", nil, err @@ -126,7 +130,33 @@ func (s AccountServiceDefault) LoginPassword(email string, password string, ip s return token, user, nil } -func (s AccountServiceDefault) ValidLogin(email string, password string) (bool, *models.User, error) { +func (s AccountServiceDefault) LoginOTP(userId uint, code string) (string, error) { + valid, err := s.OTPVerify(userId, code) + + if err != nil { + return "", err + } + + if !valid { + return "", ErrInvalidOTPCode + } + + var user models.User + user.ID = userId + + token, err := JWTGenerateToken(s.config.GetString("core.domain"), s.identity, user.ID, JWTPurposeLogin) + if err != nil { + return "", err + } + + return token, nil +} + +func (s AccountServiceDefault) ValidLoginByUserObj(user *models.User, password string) bool { + return s.validPassword(user, password) +} + +func (s AccountServiceDefault) ValidLoginByEmail(email string, password string) (bool, *models.User, error) { var user models.User result := s.db.Model(&models.User{}).Where(&models.User{Email: email}).First(&user) @@ -135,14 +165,35 @@ func (s AccountServiceDefault) ValidLogin(email string, password string) (bool, return false, nil, result.Error } - err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) - if err != nil { - return false, nil, err + valid := s.ValidLoginByUserObj(&user, password) + + if !valid { + return false, nil, nil } return true, nil, nil } +func (s AccountServiceDefault) ValidLoginByUserID(id uint, password string) (bool, *models.User, error) { + var user models.User + + user.ID = id + + result := s.db.Model(&user).Where(&user).First(&user) + + if result.RowsAffected == 0 || result.Error != nil { + return false, nil, result.Error + } + + valid := s.ValidLoginByUserObj(&user, password) + + if !valid { + return false, nil, nil + } + + return true, &user, nil +} + func (s AccountServiceDefault) LoginPubkey(pubkey string) (string, error) { var model models.PublicKey @@ -252,8 +303,62 @@ func (s AccountServiceDefault) PinByID(uploadId uint, accountID uint) error { return nil } +func (s AccountServiceDefault) OTPGenerate(userId uint) (string, error) { + exists, user, err := s.AccountExists(userId) + + if !exists || err != nil { + return "", err + } + + otp, err := TOTPGenerate(user.Email, s.config.GetString("core.domain")) + if err != nil { + return "", err + } + + err = s.updateAccountInfo(user.ID, models.User{OTPSecret: otp}) + return otp, nil +} + +func (s AccountServiceDefault) OTPVerify(userId uint, code string) (bool, error) { + exists, user, err := s.AccountExists(userId) + + if !exists || err != nil { + return false, err + } + + valid := TOTPValidate(user.OTPSecret, code) + if !valid { + return false, nil + } + + return true, nil +} + +func (s AccountServiceDefault) OTPEnable(userId uint, code string) error { + verify, err := s.OTPVerify(userId, code) + if err != nil { + return err + } + + if !verify { + return ErrInvalidOTPCode + } + + return s.updateAccountInfo(userId, models.User{OTPEnabled: true}) +} + +func (s AccountServiceDefault) OTPDisable(userId uint) error { + return s.updateAccountInfo(userId, models.User{OTPEnabled: false, OTPSecret: ""}) +} + func (s AccountServiceDefault) doLogin(user *models.User, ip string) (string, error) { - token, err := JWTGenerateToken(s.config.GetString("core.domain"), s.identity, user.ID, JWTPurposeLogin) + purpose := JWTPurposeLogin + + if user.OTPEnabled { + purpose = JWTPurpose2FA + } + + token, err := JWTGenerateToken(s.config.GetString("core.domain"), s.identity, user.ID, purpose) if err != nil { return "", err } @@ -295,3 +400,9 @@ func (s AccountServiceDefault) exists(model interface{}, conditions map[string]i return exists, model, result.Error } + +func (s AccountServiceDefault) validPassword(user *models.User, password string) bool { + err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) + + return err == nil +}