diff --git a/apis/record_helpers.go b/apis/record_helpers.go index f980da94..5e2ebce0 100644 --- a/apis/record_helpers.go +++ b/apis/record_helpers.go @@ -1,7 +1,6 @@ package apis import ( - "database/sql" "errors" "fmt" "net/http" @@ -122,7 +121,8 @@ func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token str }) } -// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule. +// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule +// (note: returns true even in case of an error as a safer default). func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) { rule := record.Collection().MFA.Rule if rule == "" { @@ -131,7 +131,7 @@ func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) { requestInfo, err := e.RequestInfo() if err != nil { - return false, err + return true, err } var exists bool @@ -144,13 +144,13 @@ func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) { resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true) expr, err := search.FilterData(rule).BuildExpr(resolver) if err != nil { - return false, err + return true, err } resolver.UpdateQuery(query) err = query.AndWhere(expr).Limit(1).Row(&exists) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return false, err + if err != nil { + return true, err } return exists, nil @@ -166,11 +166,10 @@ func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod s } ok, err := wantsMFA(e, authRecord) + if err != nil { + return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err)) + } if !ok { - if err != nil { - return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err)) - } - return "", nil // no mfa needed for this auth record } @@ -214,7 +213,7 @@ func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod s } if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) { deleteMFA() - return "", firstApiError(err, e.BadRequestError("Invalid or expired MFA session.", err)) + return "", e.BadRequestError("Invalid or expired MFA session.", err) } if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {