From f9ee710cddcc524398fd6d84da9a7675d10b8e2c Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Mon, 14 Oct 2024 18:17:31 +0300 Subject: [PATCH] normalized builtin middlewares to return hook.Handler --- apis/middlewares_cors.go | 179 ++++++++++++++++++++------------------- apis/middlewares_gzip.go | 104 ++++++++++++----------- apis/serve.go | 14 ++- 3 files changed, 153 insertions(+), 144 deletions(-) diff --git a/apis/middlewares_cors.go b/apis/middlewares_cors.go index f8711b2b..1c4e2c6d 100644 --- a/apis/middlewares_cors.go +++ b/apis/middlewares_cors.go @@ -19,6 +19,7 @@ import ( "strings" "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tools/hook" ) const ( @@ -124,7 +125,7 @@ var DefaultCORSConfig = CORSConfig{ } // CORSWithConfig returns a CORS middleware with config. -func CORSWithConfig(config CORSConfig) func(e *core.RequestEvent) error { +func CORSWithConfig(config CORSConfig) *hook.Handler[*core.RequestEvent] { // Defaults if len(config.AllowOrigins) == 0 { config.AllowOrigins = DefaultCORSConfig.AllowOrigins @@ -151,108 +152,112 @@ func CORSWithConfig(config CORSConfig) func(e *core.RequestEvent) error { maxAge = strconv.Itoa(config.MaxAge) } - return func(e *core.RequestEvent) error { - req := e.Request - res := e.Response - origin := req.Header.Get("Origin") - allowOrigin := "" + return &hook.Handler[*core.RequestEvent]{ + Id: DefaultCorsMiddlewareId, + Priority: DefaultCorsMiddlewarePriority, + Func: func(e *core.RequestEvent) error { + req := e.Request + res := e.Response + origin := req.Header.Get("Origin") + allowOrigin := "" - res.Header().Add("Vary", "Origin") + res.Header().Add("Vary", "Origin") - // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method, - // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request - // For simplicity we just consider method type and later `Origin` header. - preflight := req.Method == http.MethodOptions + // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method, + // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request + // For simplicity we just consider method type and later `Origin` header. + preflight := req.Method == http.MethodOptions - // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain - if origin == "" { - if !preflight { - return e.Next() + // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain + if origin == "" { + if !preflight { + return e.Next() + } + return e.NoContent(http.StatusNoContent) } - return e.NoContent(http.StatusNoContent) - } - if config.AllowOriginFunc != nil { - allowed, err := config.AllowOriginFunc(origin) - if err != nil { - return err - } - if allowed { - allowOrigin = origin - } - } else { - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { + if config.AllowOriginFunc != nil { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err + } + if allowed { allowOrigin = origin - break } - if o == "*" || o == origin { - allowOrigin = o - break - } - if matchSubdomain(origin, o) { - allowOrigin = origin - break - } - } - - checkPatterns := false - if allowOrigin == "" { - // to avoid regex cost by invalid (long) domains (253 is domain name max limit) - if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { - checkPatterns = true - } - } - if checkPatterns { - for _, re := range allowOriginPatterns { - if match, _ := regexp.MatchString(re, origin); match { + } else { + // Check allowed origins + for _, o := range config.AllowOrigins { + if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { + allowOrigin = origin + break + } + if o == "*" || o == origin { + allowOrigin = o + break + } + if matchSubdomain(origin, o) { allowOrigin = origin break } } - } - } - // Origin not allowed - if allowOrigin == "" { + checkPatterns := false + if allowOrigin == "" { + // to avoid regex cost by invalid (long) domains (253 is domain name max limit) + if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { + checkPatterns = true + } + } + if checkPatterns { + for _, re := range allowOriginPatterns { + if match, _ := regexp.MatchString(re, origin); match { + allowOrigin = origin + break + } + } + } + } + + // Origin not allowed + if allowOrigin == "" { + if !preflight { + return e.Next() + } + return e.NoContent(http.StatusNoContent) + } + + res.Header().Set("Access-Control-Allow-Origin", allowOrigin) + if config.AllowCredentials { + res.Header().Set("Access-Control-Allow-Credentials", "true") + } + + // Simple request if !preflight { + if exposeHeaders != "" { + res.Header().Set("Access-Control-Expose-Headers", exposeHeaders) + } return e.Next() } + + // Preflight request + res.Header().Add("Vary", "Access-Control-Request-Method") + res.Header().Add("Vary", "Access-Control-Request-Headers") + res.Header().Set("Access-Control-Allow-Methods", allowMethods) + + if allowHeaders != "" { + res.Header().Set("Access-Control-Allow-Headers", allowHeaders) + } else { + h := req.Header.Get("Access-Control-Request-Headers") + if h != "" { + res.Header().Set("Access-Control-Allow-Headers", h) + } + } + if config.MaxAge != 0 { + res.Header().Set("Access-Control-Max-Age", maxAge) + } + return e.NoContent(http.StatusNoContent) - } - - res.Header().Set("Access-Control-Allow-Origin", allowOrigin) - if config.AllowCredentials { - res.Header().Set("Access-Control-Allow-Credentials", "true") - } - - // Simple request - if !preflight { - if exposeHeaders != "" { - res.Header().Set("Access-Control-Expose-Headers", exposeHeaders) - } - return e.Next() - } - - // Preflight request - res.Header().Add("Vary", "Access-Control-Request-Method") - res.Header().Add("Vary", "Access-Control-Request-Headers") - res.Header().Set("Access-Control-Allow-Methods", allowMethods) - - if allowHeaders != "" { - res.Header().Set("Access-Control-Allow-Headers", allowHeaders) - } else { - h := req.Header.Get("Access-Control-Request-Headers") - if h != "" { - res.Header().Set("Access-Control-Allow-Headers", h) - } - } - if config.MaxAge != 0 { - res.Header().Set("Access-Control-Max-Age", maxAge) - } - - return e.NoContent(http.StatusNoContent) + }, } } diff --git a/apis/middlewares_gzip.go b/apis/middlewares_gzip.go index 68923e6e..23934d4e 100644 --- a/apis/middlewares_gzip.go +++ b/apis/middlewares_gzip.go @@ -18,6 +18,7 @@ import ( "sync" "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tools/hook" "github.com/pocketbase/pocketbase/tools/router" ) @@ -25,6 +26,10 @@ const ( gzipScheme = "gzip" ) +const ( + DefaultGzipMiddlewareId = "pbGzip" +) + // GzipConfig defines the config for Gzip middleware. type GzipConfig struct { // Gzip compression level. @@ -46,12 +51,12 @@ type GzipConfig struct { } // Gzip returns a middleware which compresses HTTP response using Gzip compression scheme. -func Gzip() func(*core.RequestEvent) error { +func Gzip() *hook.Handler[*core.RequestEvent] { return GzipWithConfig(GzipConfig{}) } // GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme. -func GzipWithConfig(config GzipConfig) func(*core.RequestEvent) error { +func GzipWithConfig(config GzipConfig) *hook.Handler[*core.RequestEvent] { if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression panic(errors.New("invalid gzip level")) } @@ -79,54 +84,57 @@ func GzipWithConfig(config GzipConfig) func(*core.RequestEvent) error { }, } - return func(e *core.RequestEvent) error { - e.Response.Header().Add("Vary", "Accept-Encoding") - if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) { - w, ok := pool.Get().(*gzip.Writer) - if !ok { - return e.InternalServerError("", errors.New("failed to get gzip.Writer")) + return &hook.Handler[*core.RequestEvent]{ + Id: DefaultGzipMiddlewareId, + Func: func(e *core.RequestEvent) error { + e.Response.Header().Add("Vary", "Accept-Encoding") + if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) { + w, ok := pool.Get().(*gzip.Writer) + if !ok { + return e.InternalServerError("", errors.New("failed to get gzip.Writer")) + } + + rw := e.Response + w.Reset(rw) + + buf := bpool.Get().(*bytes.Buffer) + buf.Reset() + + grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} + defer func() { + // There are different reasons for cases when we have not yet written response to the client and now need to do so. + // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. + // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written + if !grw.wroteBody { + if rw.Header().Get("Content-Encoding") == gzipScheme { + rw.Header().Del("Content-Encoding") + } + if grw.wroteHeader { + rw.WriteHeader(grw.code) + } + // We have to reset response to it's pristine state when + // nothing is written to body or error is returned. + // See issue echo#424, echo#407. + e.Response = rw + w.Reset(io.Discard) + } else if !grw.minLengthExceeded { + // Write uncompressed response + e.Response = rw + if grw.wroteHeader { + rw.WriteHeader(grw.code) + } + grw.buffer.WriteTo(rw) + w.Reset(io.Discard) + } + w.Close() + bpool.Put(buf) + pool.Put(w) + }() + e.Response = grw } - rw := e.Response - w.Reset(rw) - - buf := bpool.Get().(*bytes.Buffer) - buf.Reset() - - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} - defer func() { - // There are different reasons for cases when we have not yet written response to the client and now need to do so. - // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. - // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written - if !grw.wroteBody { - if rw.Header().Get("Content-Encoding") == gzipScheme { - rw.Header().Del("Content-Encoding") - } - if grw.wroteHeader { - rw.WriteHeader(grw.code) - } - // We have to reset response to it's pristine state when - // nothing is written to body or error is returned. - // See issue echo#424, echo#407. - e.Response = rw - w.Reset(io.Discard) - } else if !grw.minLengthExceeded { - // Write uncompressed response - e.Response = rw - if grw.wroteHeader { - rw.WriteHeader(grw.code) - } - grw.buffer.WriteTo(rw) - w.Reset(io.Discard) - } - w.Close() - bpool.Put(buf) - pool.Put(w) - }() - e.Response = grw - } - - return e.Next() + return e.Next() + }, } } diff --git a/apis/serve.go b/apis/serve.go index 5706d619..3ddc8762 100644 --- a/apis/serve.go +++ b/apis/serve.go @@ -83,21 +83,17 @@ func Serve(app core.App, config ServeConfig) error { return err } - pbRouter.Bind(&hook.Handler[*core.RequestEvent]{ - Id: DefaultCorsMiddlewareId, - Func: CORSWithConfig(CORSConfig{ - AllowOrigins: config.AllowedOrigins, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - }), - Priority: DefaultCorsMiddlewarePriority, - }) + pbRouter.Bind(CORSWithConfig(CORSConfig{ + AllowOrigins: config.AllowedOrigins, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, + })) pbRouter.BindFunc(installerRedirect(app, config.DashboardPath)) pbRouter.GET(config.DashboardPath, Static(ui.DistDirFS, false)). BindFunc(dashboardRemoveInstallerParam()). BindFunc(dashboardCacheControl()). - BindFunc(Gzip()) + Bind(Gzip()) // start http server // ---