From e09e51bb52d513abcbbf53352a5d8ff68eb5364a Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Fri, 4 Aug 2023 11:51:18 -0400 Subject: [PATCH] fix: wrap Register api in an atomic transaction to avoid dead locks --- service/account/account.go | 80 ++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 46 deletions(-) diff --git a/service/account/account.go b/service/account/account.go index 4c4d3ff..11a1bbc 100644 --- a/service/account/account.go +++ b/service/account/account.go @@ -7,7 +7,6 @@ import ( "git.lumeweb.com/LumeWeb/portal/model" "go.uber.org/zap" "gorm.io/gorm" - "strings" ) var ( @@ -19,53 +18,42 @@ var ( ) func Register(email string, password string, pubkey string) error { - // Check if an account with the same email address already exists. - existingAccount := model.Account{} - err := db.Get().Where("email = ?", email).First(&existingAccount).Error - if err == nil { - logger.Get().Debug(ErrEmailExists.Error(), zap.Error(err), zap.String("email", email)) - // An account with the same email address already exists. - // Return an error response to the client. - return ErrEmailExists - } else if !errors.Is(err, gorm.ErrRecordNotFound) { - logger.Get().Error(ErrQueryingAcct.Error(), zap.Error(err)) - return ErrQueryingAcct - } - - if len(pubkey) > 0 { - pubkey = strings.ToLower(pubkey) - var count int64 - err := db.Get().Model(&model.Key{}).Where("pubkey = ?", pubkey).Count(&count).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - logger.Get().Error(ErrQueryingAcct.Error(), zap.Error(err), zap.String("pubkey", pubkey)) - return ErrQueryingAcct - } - if count > 0 { - logger.Get().Debug(ErrPubkeyExists.Error(), zap.Error(err), zap.String("pubkey", pubkey)) - // An account with the same pubkey already exists. - // Return an error response to the client. - return ErrPubkeyExists - } - - } - - // Create a new Account model with the provided email and hashed password. - account := model.Account{ - Email: email, - } - - // Hash the password before saving it to the database. - if len(password) > 0 { - hashedPassword, err := hashPassword(password) - if err != nil { + err := db.Get().Transaction(func(tx *gorm.DB) error { + existingAccount := model.Account{} + err := tx.Where("email = ?", email).First(&existingAccount).Error + if err == nil { + return ErrEmailExists + } else if !errors.Is(err, gorm.ErrRecordNotFound) { return err } - account.Password = &hashedPassword - } + if len(pubkey) > 0 { + var count int64 + err := tx.Model(&model.Key{}).Where("pubkey = ?", pubkey).Count(&count).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if count > 0 { + // An account with the same pubkey already exists. + // Return an error response to the client. + return ErrPubkeyExists + } + } + + // Create a new Account model with the provided email and hashed password. + account := model.Account{ + Email: email, + } + + // Hash the password before saving it to the database. + if len(password) > 0 { + hashedPassword, err := hashPassword(password) + if err != nil { + return err + } + account.Password = &hashedPassword + } - err = db.Get().Transaction(func(tx *gorm.DB) error { - // do some database operations in the transaction (use 'tx' from this point, not 'db') if err := tx.Create(&account).Error; err != nil { return err } @@ -76,12 +64,12 @@ func Register(email string, password string, pubkey string) error { } } - // return nil will commit the whole transaction return nil }) + if err != nil { logger.Get().Error(ErrFailedCreateAccount.Error(), zap.Error(err)) - return ErrFailedCreateAccount + return err } return nil