From cc833ad643ac061d1af4c1c491bd291e114355a7 Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Wed, 13 Nov 2024 20:14:27 +0200 Subject: [PATCH] updated mfa defaults and errors check --- apis/record_helpers.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/apis/record_helpers.go b/apis/record_helpers.go index f980da94..5e2ebce0 100644 --- a/apis/record_helpers.go +++ b/apis/record_helpers.go @@ -1,7 +1,6 @@ package apis import ( - "database/sql" "errors" "fmt" "net/http" @@ -122,7 +121,8 @@ func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token str }) } -// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule. +// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule +// (note: returns true even in case of an error as a safer default). func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) { rule := record.Collection().MFA.Rule if rule == "" { @@ -131,7 +131,7 @@ func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) { requestInfo, err := e.RequestInfo() if err != nil { - return false, err + return true, err } var exists bool @@ -144,13 +144,13 @@ func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) { resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true) expr, err := search.FilterData(rule).BuildExpr(resolver) if err != nil { - return false, err + return true, err } resolver.UpdateQuery(query) err = query.AndWhere(expr).Limit(1).Row(&exists) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return false, err + if err != nil { + return true, err } return exists, nil @@ -166,11 +166,10 @@ func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod s } ok, err := wantsMFA(e, authRecord) + if err != nil { + return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err)) + } if !ok { - if err != nil { - return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err)) - } - return "", nil // no mfa needed for this auth record } @@ -214,7 +213,7 @@ func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod s } if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) { deleteMFA() - return "", firstApiError(err, e.BadRequestError("Invalid or expired MFA session.", err)) + return "", e.BadRequestError("Invalid or expired MFA session.", err) } if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {