updated mfa defaults and errors check

This commit is contained in:
Gani Georgiev 2024-11-13 20:14:27 +02:00
parent 396aa0f97c
commit cc833ad643
1 changed files with 10 additions and 11 deletions

View File

@ -1,7 +1,6 @@
package apis package apis
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -122,7 +121,8 @@ func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token str
}) })
} }
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule. // wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule
// (note: returns true even in case of an error as a safer default).
func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) { func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
rule := record.Collection().MFA.Rule rule := record.Collection().MFA.Rule
if rule == "" { if rule == "" {
@ -131,7 +131,7 @@ func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
requestInfo, err := e.RequestInfo() requestInfo, err := e.RequestInfo()
if err != nil { if err != nil {
return false, err return true, err
} }
var exists bool var exists bool
@ -144,13 +144,13 @@ func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true) resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true)
expr, err := search.FilterData(rule).BuildExpr(resolver) expr, err := search.FilterData(rule).BuildExpr(resolver)
if err != nil { if err != nil {
return false, err return true, err
} }
resolver.UpdateQuery(query) resolver.UpdateQuery(query)
err = query.AndWhere(expr).Limit(1).Row(&exists) err = query.AndWhere(expr).Limit(1).Row(&exists)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil {
return false, err return true, err
} }
return exists, nil return exists, nil
@ -166,11 +166,10 @@ func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod s
} }
ok, err := wantsMFA(e, authRecord) ok, err := wantsMFA(e, authRecord)
if !ok {
if err != nil { if err != nil {
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err)) return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
} }
if !ok {
return "", nil // no mfa needed for this auth record return "", nil // no mfa needed for this auth record
} }
@ -214,7 +213,7 @@ func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod s
} }
if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) { if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) {
deleteMFA() deleteMFA()
return "", firstApiError(err, e.BadRequestError("Invalid or expired MFA session.", err)) return "", e.BadRequestError("Invalid or expired MFA session.", err)
} }
if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id { if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {