normalized builtin middlewares to return hook.Handler
This commit is contained in:
		
							parent
							
								
									47d5ea3ce2
								
							
						
					
					
						commit
						f9ee710cdd
					
				|  | @ -19,6 +19,7 @@ import ( | |||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/pocketbase/pocketbase/core" | ||||
| 	"github.com/pocketbase/pocketbase/tools/hook" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -124,7 +125,7 @@ var DefaultCORSConfig = CORSConfig{ | |||
| } | ||||
| 
 | ||||
| // CORSWithConfig returns a CORS middleware with config.
 | ||||
| func CORSWithConfig(config CORSConfig) func(e *core.RequestEvent) error { | ||||
| func CORSWithConfig(config CORSConfig) *hook.Handler[*core.RequestEvent] { | ||||
| 	// Defaults
 | ||||
| 	if len(config.AllowOrigins) == 0 { | ||||
| 		config.AllowOrigins = DefaultCORSConfig.AllowOrigins | ||||
|  | @ -151,108 +152,112 @@ func CORSWithConfig(config CORSConfig) func(e *core.RequestEvent) error { | |||
| 		maxAge = strconv.Itoa(config.MaxAge) | ||||
| 	} | ||||
| 
 | ||||
| 	return func(e *core.RequestEvent) error { | ||||
| 		req := e.Request | ||||
| 		res := e.Response | ||||
| 		origin := req.Header.Get("Origin") | ||||
| 		allowOrigin := "" | ||||
| 	return &hook.Handler[*core.RequestEvent]{ | ||||
| 		Id:       DefaultCorsMiddlewareId, | ||||
| 		Priority: DefaultCorsMiddlewarePriority, | ||||
| 		Func: func(e *core.RequestEvent) error { | ||||
| 			req := e.Request | ||||
| 			res := e.Response | ||||
| 			origin := req.Header.Get("Origin") | ||||
| 			allowOrigin := "" | ||||
| 
 | ||||
| 		res.Header().Add("Vary", "Origin") | ||||
| 			res.Header().Add("Vary", "Origin") | ||||
| 
 | ||||
| 		// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
 | ||||
| 		// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
 | ||||
| 		// For simplicity we just consider method type and later `Origin` header.
 | ||||
| 		preflight := req.Method == http.MethodOptions | ||||
| 			// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
 | ||||
| 			// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
 | ||||
| 			// For simplicity we just consider method type and later `Origin` header.
 | ||||
| 			preflight := req.Method == http.MethodOptions | ||||
| 
 | ||||
| 		// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
 | ||||
| 		if origin == "" { | ||||
| 			if !preflight { | ||||
| 				return e.Next() | ||||
| 			// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
 | ||||
| 			if origin == "" { | ||||
| 				if !preflight { | ||||
| 					return e.Next() | ||||
| 				} | ||||
| 				return e.NoContent(http.StatusNoContent) | ||||
| 			} | ||||
| 			return e.NoContent(http.StatusNoContent) | ||||
| 		} | ||||
| 
 | ||||
| 		if config.AllowOriginFunc != nil { | ||||
| 			allowed, err := config.AllowOriginFunc(origin) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if allowed { | ||||
| 				allowOrigin = origin | ||||
| 			} | ||||
| 		} else { | ||||
| 			// Check allowed origins
 | ||||
| 			for _, o := range config.AllowOrigins { | ||||
| 				if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { | ||||
| 			if config.AllowOriginFunc != nil { | ||||
| 				allowed, err := config.AllowOriginFunc(origin) | ||||
| 				if err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 				if allowed { | ||||
| 					allowOrigin = origin | ||||
| 					break | ||||
| 				} | ||||
| 				if o == "*" || o == origin { | ||||
| 					allowOrigin = o | ||||
| 					break | ||||
| 				} | ||||
| 				if matchSubdomain(origin, o) { | ||||
| 					allowOrigin = origin | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			checkPatterns := false | ||||
| 			if allowOrigin == "" { | ||||
| 				// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
 | ||||
| 				if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { | ||||
| 					checkPatterns = true | ||||
| 				} | ||||
| 			} | ||||
| 			if checkPatterns { | ||||
| 				for _, re := range allowOriginPatterns { | ||||
| 					if match, _ := regexp.MatchString(re, origin); match { | ||||
| 			} else { | ||||
| 				// Check allowed origins
 | ||||
| 				for _, o := range config.AllowOrigins { | ||||
| 					if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { | ||||
| 						allowOrigin = origin | ||||
| 						break | ||||
| 					} | ||||
| 					if o == "*" || o == origin { | ||||
| 						allowOrigin = o | ||||
| 						break | ||||
| 					} | ||||
| 					if matchSubdomain(origin, o) { | ||||
| 						allowOrigin = origin | ||||
| 						break | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// Origin not allowed
 | ||||
| 		if allowOrigin == "" { | ||||
| 				checkPatterns := false | ||||
| 				if allowOrigin == "" { | ||||
| 					// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
 | ||||
| 					if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { | ||||
| 						checkPatterns = true | ||||
| 					} | ||||
| 				} | ||||
| 				if checkPatterns { | ||||
| 					for _, re := range allowOriginPatterns { | ||||
| 						if match, _ := regexp.MatchString(re, origin); match { | ||||
| 							allowOrigin = origin | ||||
| 							break | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			// Origin not allowed
 | ||||
| 			if allowOrigin == "" { | ||||
| 				if !preflight { | ||||
| 					return e.Next() | ||||
| 				} | ||||
| 				return e.NoContent(http.StatusNoContent) | ||||
| 			} | ||||
| 
 | ||||
| 			res.Header().Set("Access-Control-Allow-Origin", allowOrigin) | ||||
| 			if config.AllowCredentials { | ||||
| 				res.Header().Set("Access-Control-Allow-Credentials", "true") | ||||
| 			} | ||||
| 
 | ||||
| 			// Simple request
 | ||||
| 			if !preflight { | ||||
| 				if exposeHeaders != "" { | ||||
| 					res.Header().Set("Access-Control-Expose-Headers", exposeHeaders) | ||||
| 				} | ||||
| 				return e.Next() | ||||
| 			} | ||||
| 
 | ||||
| 			// Preflight request
 | ||||
| 			res.Header().Add("Vary", "Access-Control-Request-Method") | ||||
| 			res.Header().Add("Vary", "Access-Control-Request-Headers") | ||||
| 			res.Header().Set("Access-Control-Allow-Methods", allowMethods) | ||||
| 
 | ||||
| 			if allowHeaders != "" { | ||||
| 				res.Header().Set("Access-Control-Allow-Headers", allowHeaders) | ||||
| 			} else { | ||||
| 				h := req.Header.Get("Access-Control-Request-Headers") | ||||
| 				if h != "" { | ||||
| 					res.Header().Set("Access-Control-Allow-Headers", h) | ||||
| 				} | ||||
| 			} | ||||
| 			if config.MaxAge != 0 { | ||||
| 				res.Header().Set("Access-Control-Max-Age", maxAge) | ||||
| 			} | ||||
| 
 | ||||
| 			return e.NoContent(http.StatusNoContent) | ||||
| 		} | ||||
| 
 | ||||
| 		res.Header().Set("Access-Control-Allow-Origin", allowOrigin) | ||||
| 		if config.AllowCredentials { | ||||
| 			res.Header().Set("Access-Control-Allow-Credentials", "true") | ||||
| 		} | ||||
| 
 | ||||
| 		// Simple request
 | ||||
| 		if !preflight { | ||||
| 			if exposeHeaders != "" { | ||||
| 				res.Header().Set("Access-Control-Expose-Headers", exposeHeaders) | ||||
| 			} | ||||
| 			return e.Next() | ||||
| 		} | ||||
| 
 | ||||
| 		// Preflight request
 | ||||
| 		res.Header().Add("Vary", "Access-Control-Request-Method") | ||||
| 		res.Header().Add("Vary", "Access-Control-Request-Headers") | ||||
| 		res.Header().Set("Access-Control-Allow-Methods", allowMethods) | ||||
| 
 | ||||
| 		if allowHeaders != "" { | ||||
| 			res.Header().Set("Access-Control-Allow-Headers", allowHeaders) | ||||
| 		} else { | ||||
| 			h := req.Header.Get("Access-Control-Request-Headers") | ||||
| 			if h != "" { | ||||
| 				res.Header().Set("Access-Control-Allow-Headers", h) | ||||
| 			} | ||||
| 		} | ||||
| 		if config.MaxAge != 0 { | ||||
| 			res.Header().Set("Access-Control-Max-Age", maxAge) | ||||
| 		} | ||||
| 
 | ||||
| 		return e.NoContent(http.StatusNoContent) | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -18,6 +18,7 @@ import ( | |||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/pocketbase/pocketbase/core" | ||||
| 	"github.com/pocketbase/pocketbase/tools/hook" | ||||
| 	"github.com/pocketbase/pocketbase/tools/router" | ||||
| ) | ||||
| 
 | ||||
|  | @ -25,6 +26,10 @@ const ( | |||
| 	gzipScheme = "gzip" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	DefaultGzipMiddlewareId = "pbGzip" | ||||
| ) | ||||
| 
 | ||||
| // GzipConfig defines the config for Gzip middleware.
 | ||||
| type GzipConfig struct { | ||||
| 	// Gzip compression level.
 | ||||
|  | @ -46,12 +51,12 @@ type GzipConfig struct { | |||
| } | ||||
| 
 | ||||
| // Gzip returns a middleware which compresses HTTP response using Gzip compression scheme.
 | ||||
| func Gzip() func(*core.RequestEvent) error { | ||||
| func Gzip() *hook.Handler[*core.RequestEvent] { | ||||
| 	return GzipWithConfig(GzipConfig{}) | ||||
| } | ||||
| 
 | ||||
| // GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
 | ||||
| func GzipWithConfig(config GzipConfig) func(*core.RequestEvent) error { | ||||
| func GzipWithConfig(config GzipConfig) *hook.Handler[*core.RequestEvent] { | ||||
| 	if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
 | ||||
| 		panic(errors.New("invalid gzip level")) | ||||
| 	} | ||||
|  | @ -79,54 +84,57 @@ func GzipWithConfig(config GzipConfig) func(*core.RequestEvent) error { | |||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	return func(e *core.RequestEvent) error { | ||||
| 		e.Response.Header().Add("Vary", "Accept-Encoding") | ||||
| 		if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) { | ||||
| 			w, ok := pool.Get().(*gzip.Writer) | ||||
| 			if !ok { | ||||
| 				return e.InternalServerError("", errors.New("failed to get gzip.Writer")) | ||||
| 	return &hook.Handler[*core.RequestEvent]{ | ||||
| 		Id: DefaultGzipMiddlewareId, | ||||
| 		Func: func(e *core.RequestEvent) error { | ||||
| 			e.Response.Header().Add("Vary", "Accept-Encoding") | ||||
| 			if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) { | ||||
| 				w, ok := pool.Get().(*gzip.Writer) | ||||
| 				if !ok { | ||||
| 					return e.InternalServerError("", errors.New("failed to get gzip.Writer")) | ||||
| 				} | ||||
| 
 | ||||
| 				rw := e.Response | ||||
| 				w.Reset(rw) | ||||
| 
 | ||||
| 				buf := bpool.Get().(*bytes.Buffer) | ||||
| 				buf.Reset() | ||||
| 
 | ||||
| 				grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} | ||||
| 				defer func() { | ||||
| 					// There are different reasons for cases when we have not yet written response to the client and now need to do so.
 | ||||
| 					// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
 | ||||
| 					// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
 | ||||
| 					if !grw.wroteBody { | ||||
| 						if rw.Header().Get("Content-Encoding") == gzipScheme { | ||||
| 							rw.Header().Del("Content-Encoding") | ||||
| 						} | ||||
| 						if grw.wroteHeader { | ||||
| 							rw.WriteHeader(grw.code) | ||||
| 						} | ||||
| 						// We have to reset response to it's pristine state when
 | ||||
| 						// nothing is written to body or error is returned.
 | ||||
| 						// See issue echo#424, echo#407.
 | ||||
| 						e.Response = rw | ||||
| 						w.Reset(io.Discard) | ||||
| 					} else if !grw.minLengthExceeded { | ||||
| 						// Write uncompressed response
 | ||||
| 						e.Response = rw | ||||
| 						if grw.wroteHeader { | ||||
| 							rw.WriteHeader(grw.code) | ||||
| 						} | ||||
| 						grw.buffer.WriteTo(rw) | ||||
| 						w.Reset(io.Discard) | ||||
| 					} | ||||
| 					w.Close() | ||||
| 					bpool.Put(buf) | ||||
| 					pool.Put(w) | ||||
| 				}() | ||||
| 				e.Response = grw | ||||
| 			} | ||||
| 
 | ||||
| 			rw := e.Response | ||||
| 			w.Reset(rw) | ||||
| 
 | ||||
| 			buf := bpool.Get().(*bytes.Buffer) | ||||
| 			buf.Reset() | ||||
| 
 | ||||
| 			grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} | ||||
| 			defer func() { | ||||
| 				// There are different reasons for cases when we have not yet written response to the client and now need to do so.
 | ||||
| 				// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
 | ||||
| 				// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
 | ||||
| 				if !grw.wroteBody { | ||||
| 					if rw.Header().Get("Content-Encoding") == gzipScheme { | ||||
| 						rw.Header().Del("Content-Encoding") | ||||
| 					} | ||||
| 					if grw.wroteHeader { | ||||
| 						rw.WriteHeader(grw.code) | ||||
| 					} | ||||
| 					// We have to reset response to it's pristine state when
 | ||||
| 					// nothing is written to body or error is returned.
 | ||||
| 					// See issue echo#424, echo#407.
 | ||||
| 					e.Response = rw | ||||
| 					w.Reset(io.Discard) | ||||
| 				} else if !grw.minLengthExceeded { | ||||
| 					// Write uncompressed response
 | ||||
| 					e.Response = rw | ||||
| 					if grw.wroteHeader { | ||||
| 						rw.WriteHeader(grw.code) | ||||
| 					} | ||||
| 					grw.buffer.WriteTo(rw) | ||||
| 					w.Reset(io.Discard) | ||||
| 				} | ||||
| 				w.Close() | ||||
| 				bpool.Put(buf) | ||||
| 				pool.Put(w) | ||||
| 			}() | ||||
| 			e.Response = grw | ||||
| 		} | ||||
| 
 | ||||
| 		return e.Next() | ||||
| 			return e.Next() | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -83,21 +83,17 @@ func Serve(app core.App, config ServeConfig) error { | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	pbRouter.Bind(&hook.Handler[*core.RequestEvent]{ | ||||
| 		Id: DefaultCorsMiddlewareId, | ||||
| 		Func: CORSWithConfig(CORSConfig{ | ||||
| 			AllowOrigins: config.AllowedOrigins, | ||||
| 			AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, | ||||
| 		}), | ||||
| 		Priority: DefaultCorsMiddlewarePriority, | ||||
| 	}) | ||||
| 	pbRouter.Bind(CORSWithConfig(CORSConfig{ | ||||
| 		AllowOrigins: config.AllowedOrigins, | ||||
| 		AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, | ||||
| 	})) | ||||
| 
 | ||||
| 	pbRouter.BindFunc(installerRedirect(app, config.DashboardPath)) | ||||
| 
 | ||||
| 	pbRouter.GET(config.DashboardPath, Static(ui.DistDirFS, false)). | ||||
| 		BindFunc(dashboardRemoveInstallerParam()). | ||||
| 		BindFunc(dashboardCacheControl()). | ||||
| 		BindFunc(Gzip()) | ||||
| 		Bind(Gzip()) | ||||
| 
 | ||||
| 	// start http server
 | ||||
| 	// ---
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue