From 7ee6b11e9de313dd830d4f10381407536469c1a4 Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Thu, 21 Nov 2024 11:12:25 +0200 Subject: [PATCH] return an error in case of required MFA so that external handlers can react if necessary --- apis/record_auth_with_otp.go | 10 +++++----- apis/record_helpers.go | 7 ++++++- apis/record_helpers_test.go | 9 +++++---- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/apis/record_auth_with_otp.go b/apis/record_auth_with_otp.go index ff19085f..a4021db1 100644 --- a/apis/record_auth_with_otp.go +++ b/apis/record_auth_with_otp.go @@ -79,17 +79,17 @@ func recordAuthWithOTP(e *core.RequestEvent) error { } } - err = RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil) - if err != nil { - return err - } - // try to delete the used otp err = e.App.Delete(e.OTP) if err != nil { e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id) } + err = RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil) + if err != nil { + return err + } + return nil }) } diff --git a/apis/record_helpers.go b/apis/record_helpers.go index 2e66a608..9ed686c4 100644 --- a/apis/record_helpers.go +++ b/apis/record_helpers.go @@ -20,6 +20,8 @@ const ( fieldsQueryParam = "fields" ) +var ErrMFA = errors.New("mfa required") + // RecordAuthResponse writes standardized json record auth response // into the specified request context. // @@ -70,9 +72,12 @@ func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token str // require additional authentication if mfaId != "" { - return e.JSON(http.StatusUnauthorized, map[string]string{ + // eagerly write the mfa response and return an err so that + // external middlewars are aware that the auth response requires an extra step + e.JSON(http.StatusUnauthorized, map[string]string{ "mfaId": mfaId, }) + return ErrMFA } // --- diff --git a/apis/record_helpers_test.go b/apis/record_helpers_test.go index b5a98631..d0458d3f 100644 --- a/apis/record_helpers_test.go +++ b/apis/record_helpers_test.go @@ -2,6 +2,7 @@ package apis_test import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" @@ -584,8 +585,8 @@ func TestRecordAuthResponseMFACheck(t *testing.T) { user.Collection().MFA.Rule = "1=1" err = apis.RecordAuthResponse(event, user, "example", nil) - if err != nil { - t.Fatalf("Expected nil, got error: %v", err) + if !errors.Is(err, apis.ErrMFA) { + t.Fatalf("Expected ErrMFA, got: %v", err) } body := rec.Body.String() @@ -602,8 +603,8 @@ func TestRecordAuthResponseMFACheck(t *testing.T) { resetMFAs(user) err = apis.RecordAuthResponse(event, user, "example", nil) - if err != nil { - t.Fatalf("Expected nil, got error: %v", err) + if !errors.Is(err, apis.ErrMFA) { + t.Fatalf("Expected ErrMFA, got: %v", err) } body := rec.Body.String()