diff --git a/internal/db/initialize.go b/internal/db/initialize.go index 12c248b..b039c17 100644 --- a/internal/db/initialize.go +++ b/internal/db/initialize.go @@ -46,6 +46,11 @@ func InitializeDB() error { return fmt.Errorf("[DB] couldn't create database: [%w]", err) } + _, err = models.DB.NewCreateTable().IfNotExists().Model((*models.MFAScratchCode)(nil)).Exec(context.Background()) + if err != nil { + return fmt.Errorf("[DB] couldn't create database: [%w]", err) + } + err = doMigrations() if err != nil { return fmt.Errorf("[DB] Error during Migrations: [%w]", err) diff --git a/internal/db/multifactor.go b/internal/db/multifactor.go index 1a9c1ef..040fb7a 100644 --- a/internal/db/multifactor.go +++ b/internal/db/multifactor.go @@ -18,3 +18,18 @@ func UserHasMFA(user models.User) (bool, error) { } return false, nil } + +// ScratchCodeUnique checks the database if the generated scratch code +// is unique (not in the database yet) +func ScratchCodeIsUnique(scratchcode string) bool { + var dbitem models.MFAScratchCode + numrows, err := models.DB.NewSelect().Model(&dbitem).Where("code = ?", scratchcode).Count(context.Background()) + if err != nil { + return false + } + + if numrows != 0 { + return false + } + return true +} diff --git a/internal/web/multifactor-login.go b/internal/web/multifactor-login.go index 66879ed..cae4480 100644 --- a/internal/web/multifactor-login.go +++ b/internal/web/multifactor-login.go @@ -2,6 +2,7 @@ package web import ( "context" + "fmt" "log" "slices" @@ -45,20 +46,28 @@ func HandleAdminLoginMFAPost(c *fiber.Ctx) error { } // check token/scratch validity - var passcodeIsValid bool + var scratchcodeIsValid bool if istotp { - passcodeIsValid = checkTotpIsValid(token.Token, user) + scratchcodeIsValid = checkTotpIsValid(token.Token, user) } else { - passcodeIsValid = checkScratchIsValid(token.Token, user) + scratchcodeIsValid = checkScratchIsValid(token.Token, user) } - if passcodeIsValid { + if scratchcodeIsValid { err = misc.SetLoginCookie(c, user, constants.LoginCookieExpiryDuration) if err != nil { log.Printf("[HandleAdminLoginPost] Error setting cookie: %q\n", err) return misc.New500Error() } + _, err = models.DB.NewUpdate(). + Model(&models.MFAScratchCode{Code: token.Token, IsUsed: true}). + OmitZero().WherePK().Exec(context.Background()) + if err != nil { + fmt.Printf("[HandleAdminLoginMFAPost] Error marking scratch code as used: %v\n", err.Error()) + return fiber.NewError(fiber.StatusInternalServerError, "500 Internal Server Error") + } + c.Status(fiber.StatusOK) return nil } @@ -81,16 +90,21 @@ func checkTotpIsValid(passcode string, user models.User) bool { } func checkScratchIsValid(scratch string, user models.User) bool { - var mfaconfig models.MFAConfig + var scratchcodes []models.MFAScratchCode - err := models.DB.NewSelect().Model(&mfaconfig).Where("username = ?", user.UserName).Scan(context.Background()) + err := models.DB.NewSelect().Model(&scratchcodes).Where("username = ?", user.UserName).Where("is_used = ?", false).Scan(context.Background()) if err != nil { log.Printf("Error getting MFA config for %v from DB: %v\n", user.UserName, err.Error()) // TODO: Debug logging return false } - return slices.Contains(mfaconfig.RecoveryCodes, scratch) + var scratchcodeSlice []string + for _, code := range scratchcodes { + scratchcodeSlice = append(scratchcodeSlice, code.Code) + } + + return slices.Contains(scratchcodeSlice, scratch) } func HandleAdminLoginMFAGet(c *fiber.Ctx) error { diff --git a/internal/web/setup-multifactor.go b/internal/web/setup-multifactor.go index 9b3067e..082ad23 100644 --- a/internal/web/setup-multifactor.go +++ b/internal/web/setup-multifactor.go @@ -5,6 +5,7 @@ import ( "context" "encoding/base64" "encoding/json" + "fmt" "image/png" "log" "strconv" @@ -18,12 +19,6 @@ import ( "github.com/pquerna/otp/totp" ) -type mfaSetupResponse struct { - Error bool `json:"error"` - Message string `json:"message,omitempty"` - RecoveryTokens []string `json:"recoverytokens,omitempty"` -} - func HandleAdminAccountMFASetupGet(c *fiber.Ctx) error { if !db.IsCookieValid(c.Cookies(constants.LoginCookieName, "")) { c.Location("/admin/") @@ -36,15 +31,33 @@ func HandleAdminAccountMFASetupGet(c *fiber.Ctx) error { mfaconfig.Active = false mfaconfig.ExpiresAt = time.Now().Add(15 * time.Minute) - for i := 0; i < 6; i++ { - mfaconfig.RecoveryCodes = append(mfaconfig.RecoveryCodes, misc.RandomString(8)) - } - user, err := db.GetUserFromCookie(c.Cookies(constants.LoginCookieName)) if err != nil { log.Println(err) fiber.NewError(fiber.StatusInternalServerError, "500 Internal Server Error") } + + scratchcodes := []models.MFAScratchCode{} + scratchcodeFailed := 0 + + // generate four unique(!) scratch codes for the user + for len(scratchcodes) != 4 { + if scratchcodeFailed > 15 { + //TODO: structurized error logging + fmt.Println("[HandleAdminAccountMFASetupPost] Failed to generate unique scratch code 15 times! Aborting") + return misc.New500Error() + } + + code := misc.RandomString(8) + + if db.ScratchCodeIsUnique(code) { + scratchcodes = append(scratchcodes, models.MFAScratchCode{ + IsUsed: false, Code: code, UserName: user.UserName}) + } else { + scratchcodeFailed++ + } + } + mfaconfig.UserName = user.UserName key, err := totp.Generate(totp.GenerateOpts{ @@ -71,6 +84,12 @@ func HandleAdminAccountMFASetupGet(c *fiber.Ctx) error { mfaobject.Image = base64img + _, err = models.DB.NewInsert().Model(&scratchcodes).Exec(context.Background()) + if err != nil { + log.Printf("[HandleAdminAccountMFASetupGet] Error inserting scratch codes to DB: %q\n", err) + fiber.NewError(fiber.StatusInternalServerError, "500 Internal Server Error") + } + _, err = models.DB.NewInsert().Model(&mfaconfig).Exec(context.Background()) if err != nil { log.Printf("[HandleAdminAccountMFASetupGet] Error inserting mfaconfig to DB: %q\n", err) @@ -88,11 +107,12 @@ func HandleAdminAccountMFASetupPost(c *fiber.Ctx) error { return nil } - var response mfaSetupResponse + var response models.MFASetupResponse response.Error = true var token models.TokenRequest var config models.MFAConfig + var scratchcodes []models.MFAScratchCode //var user models.User err := json.Unmarshal(c.Body(), &token) @@ -115,7 +135,7 @@ func HandleAdminAccountMFASetupPost(c *fiber.Ctx) error { err = models.DB.NewSelect().Model(&config).Where("id = ?", setupcookie).Scan(context.Background()) if err != nil { - log.Printf("[HandleAdminAccountMFASetupGet] Error getting MFAConfig from DB: %q\n", err) + log.Printf("[HandleAdminAccountMFASetupPost] Error getting MFAConfig from DB: %q\n", err) return fiber.NewError(fiber.StatusInternalServerError, "500 Internal Server Error") } @@ -123,7 +143,19 @@ func HandleAdminAccountMFASetupPost(c *fiber.Ctx) error { if totpvalid { response.Error = false response.Message = "Multifactor authentication was successfully set up!" - response.RecoveryTokens = config.RecoveryCodes + + err = models.DB.NewSelect().Model(&scratchcodes).Where("username = ?", config.UserName).Scan(context.Background()) + if err != nil { + log.Printf("[HandleAdminAccountMFASetupPost] Error getting MFA scratch codes from DB: %q\n", err) + return fiber.NewError(fiber.StatusInternalServerError, "500 Internal Server Error") + } + + var scratchcodeSlice []string + for _, code := range scratchcodes { + scratchcodeSlice = append(scratchcodeSlice, code.Code) + } + + response.RecoveryTokens = scratchcodeSlice config.Active = true diff --git a/models/multifactor.go b/models/multifactor.go index 2612302..19e2983 100644 --- a/models/multifactor.go +++ b/models/multifactor.go @@ -26,12 +26,24 @@ type MFAConfig struct { ID int64 `bun:"id,pk,autoincrement"` UserName string `bun:"username,notnull"` TOTPSecret string `bun:"totpurl,notnull"` - RecoveryCodes []string `bun:"recoverycodes,notnull,array"` ExpiresAt time.Time `bun:"expiresat,notnull"` Active bool `bun:"active,notnull"` } +type MFAScratchCode struct { + bun.BaseModel `bun:"table:multifactor_scratchcodes"` + Code string `bun:"code,pk"` + UserName string `bun:"username,notnull"` + IsUsed bool `bun:"is_used,notnull"` +} + type MFATemplateObject struct { Key string Image string } + +type MFASetupResponse struct { + Error bool `json:"error"` + Message string `json:"message,omitempty"` + RecoveryTokens []string `json:"recovery_tokens,omitempty"` +} diff --git a/views/account.tmpl b/views/account.tmpl index df1c5f8..a36284e 100644 --- a/views/account.tmpl +++ b/views/account.tmpl @@ -50,11 +50,11 @@
Created at: {{ .Created }}