You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ims/app/http/middleware/cors.go

240 lines
6.3 KiB

2 months ago
package middleware
import (
"net/http"
"regexp"
"strconv"
"strings"
"zestack.dev/slim"
)
// CORSConfig defines the config for CORS middleware.
//
// The zero value is usable: CORSWithConfig fills in defaults for
// AllowOrigins and AllowMethods when they are empty.
type CORSConfig struct {
// AllowOrigins defines a list of origins that may access the resource.
// An entry may be the literal "*" (any origin), an exact origin, a
// wildcard-subdomain origin such as "https://*.example.com", or a glob
// in which "*" matches any run of characters and "?" matches a single
// character.
// Optional. Default value []string{"*"}.
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error)
// AllowMethods defines a list of methods allowed when accessing the
// resource. This is used in response to a preflight request.
// Optional. Default value is GET, HEAD, PUT, PATCH, POST and DELETE.
AllowMethods []string
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
// When the list is empty, the preflight response echoes back whatever the
// client sent in its Access-Control-Request-Headers header.
// Optional. Default value []string{}.
AllowHeaders []string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credential flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
// NOTE: when combined with an AllowOrigins entry of "*", the middleware
// echoes the request's Origin back instead of "*", so every origin is
// allowed to make credentialed requests.
// Optional. Default value is false.
AllowCredentials bool
// ExposeHeaders defines the whitelist headers that clients are allowed to
// access. It is only sent on non-preflight responses, and only when the
// list is not empty.
// Optional. Default value []string{}.
ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached. The header is only sent when the value is greater than 0.
// Optional. Default value 0.
MaxAge int
}
// CORS returns a CORS middleware built from the zero-value CORSConfig, so
// every option falls back to the defaults applied by CORSWithConfig.
func CORS() slim.MiddlewareFunc {
	var defaults CORSConfig
	return CORSWithConfig(defaults)
}
// CORSWithConfig returns a CORS middleware built from config.
//
// Empty AllowOrigins defaults to []string{"*"} and empty AllowMethods
// defaults to GET, HEAD, PUT, PATCH, POST and DELETE.
//
// For every request the middleware adds "Origin" to the Vary header and then:
//   - without an Origin header, or with an origin that is not allowed, passes
//     non-preflight requests to next untouched and answers preflight
//     (OPTIONS) requests with 204 and no CORS headers;
//   - for an allowed origin on a non-preflight request, sets the allow-origin,
//     allow-credentials and expose-headers response headers and calls next;
//   - for an allowed origin on a preflight request, sets the preflight
//     response headers and answers with 204 without calling next.
//
// An error returned by config.AllowOriginFunc is returned from the handler.
func CORSWithConfig(config CORSConfig) slim.MiddlewareFunc {
	if len(config.AllowOrigins) == 0 {
		config.AllowOrigins = []string{"*"}
	}
	if len(config.AllowMethods) == 0 {
		config.AllowMethods = []string{
			http.MethodGet,
			http.MethodHead,
			http.MethodPut,
			http.MethodPatch,
			http.MethodPost,
			http.MethodDelete,
		}
	}

	// Turn every configured origin into an anchored regular expression
	// ("*" becomes ".*", "?" becomes ".") and compile it once here, so the
	// request path never has to compile a pattern.
	allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins))
	for _, origin := range config.AllowOrigins {
		pattern := regexp.QuoteMeta(origin)
		pattern = strings.ReplaceAll(pattern, "\\*", ".*")
		pattern = strings.ReplaceAll(pattern, "\\?", ".")
		re, err := regexp.Compile("^" + pattern + "$")
		if err != nil {
			// A pattern that cannot be compiled can never match; skip it
			// rather than failing middleware construction.
			continue
		}
		allowOriginPatterns = append(allowOriginPatterns, re)
	}

	// Header values are static, so build them once.
	allowMethods := strings.Join(config.AllowMethods, ",")
	allowHeaders := strings.Join(config.AllowHeaders, ",")
	exposeHeaders := strings.Join(config.ExposeHeaders, ",")
	maxAge := strconv.Itoa(config.MaxAge)

	return func(c slim.Context, next slim.HandlerFunc) error {
		req := c.Request()
		res := c.Response()
		origin := req.Header.Get(slim.HeaderOrigin)
		allowOrigin := ""
		preflight := req.Method == http.MethodOptions

		// The response depends on the Origin header, so caches must key on it.
		res.Header().Add(slim.HeaderVary, slim.HeaderOrigin)

		// No Origin provided
		if origin == "" {
			if !preflight {
				return next(c)
			}
			return c.NoContent(http.StatusNoContent)
		}

		if config.AllowOriginFunc != nil {
			allowed, err := config.AllowOriginFunc(origin)
			if err != nil {
				return err
			}
			if allowed {
				allowOrigin = origin
			}
		} else {
			// Check allowed origins
			for _, o := range config.AllowOrigins {
				if o == "*" && config.AllowCredentials {
					// NOTE(security): "*" cannot be combined with
					// credentials, so the request origin is reflected
					// instead. This lets ANY origin make credentialed
					// requests; behavior kept as-is for compatibility.
					allowOrigin = origin
					break
				}
				if o == "*" || o == origin {
					allowOrigin = o
					break
				}
				if matchSubdomain(origin, o) {
					allowOrigin = origin
					break
				}
			}

			// Check allowed origin patterns. Only origins shaped like
			// scheme://authority are considered, and an authority longer
			// than 253 bytes is rejected up front to bound the regex cost.
			if allowOrigin == "" {
				if idx := strings.Index(origin, "://"); idx != -1 && len(origin[idx+3:]) <= 253 {
					for _, re := range allowOriginPatterns {
						if re.MatchString(origin) {
							allowOrigin = origin
							break
						}
					}
				}
			}
		}

		// Origin isn't allowed
		if allowOrigin == "" {
			if !preflight {
				return next(c)
			}
			return c.NoContent(http.StatusNoContent)
		}

		// Simple request
		if !preflight {
			res.Header().Set(slim.HeaderAccessControlAllowOrigin, allowOrigin)
			if config.AllowCredentials {
				res.Header().Set(slim.HeaderAccessControlAllowCredentials, "true")
			}
			if exposeHeaders != "" {
				res.Header().Set(slim.HeaderAccessControlExposeHeaders, exposeHeaders)
			}
			return next(c)
		}

		// Preflight request
		res.Header().Add(slim.HeaderVary, slim.HeaderAccessControlRequestMethod)
		res.Header().Add(slim.HeaderVary, slim.HeaderAccessControlRequestHeaders)
		res.Header().Set(slim.HeaderAccessControlAllowOrigin, allowOrigin)
		res.Header().Set(slim.HeaderAccessControlAllowMethods, allowMethods)
		if config.AllowCredentials {
			res.Header().Set(slim.HeaderAccessControlAllowCredentials, "true")
		}
		if allowHeaders != "" {
			res.Header().Set(slim.HeaderAccessControlAllowHeaders, allowHeaders)
		} else {
			// No explicit allow-list: echo back the headers the client asked for.
			h := req.Header.Get(slim.HeaderAccessControlRequestHeaders)
			if h != "" {
				res.Header().Set(slim.HeaderAccessControlAllowHeaders, h)
			}
		}
		if config.MaxAge > 0 {
			res.Header().Set(slim.HeaderAccessControlMaxAge, maxAge)
		}
		return c.NoContent(http.StatusNoContent)
	}
}
// matchScheme reports whether domain and pattern share the same scheme, i.e.
// the text before the first ":" in each. It returns false when either
// argument contains no ":" at all.
func matchScheme(domain, pattern string) bool {
	domEnd := strings.IndexByte(domain, ':')
	if domEnd < 0 {
		return false
	}
	patEnd := strings.IndexByte(pattern, ':')
	if patEnd < 0 {
		return false
	}
	return domain[:domEnd] == pattern[:patEnd]
}
// matchSubdomain reports whether domain (an origin such as
// "https://api.example.com") is covered by a wildcard pattern (such as
// "https://*.example.com").
//
// Both arguments must share the same scheme and contain "://". The parts
// after "://" are split on "." and compared label by label from the right;
// the first "*" label met in the pattern accepts the domain. A pattern with
// no "*" never matches here, and neither does a domain with fewer labels
// than the pattern needs to reach its "*".
func matchSubdomain(domain, pattern string) bool {
	if !matchScheme(domain, pattern) {
		return false
	}

	const sep = "://"
	domStart := strings.Index(domain, sep)
	patStart := strings.Index(pattern, sep)
	if domStart < 0 || patStart < 0 {
		return false
	}

	host := domain[domStart+len(sep):]
	// Refuse overly long input so the label walk below stays cheap.
	if len(host) > 253 {
		return false
	}

	hostLabels := strings.Split(host, ".")
	patLabels := strings.Split(pattern[patStart+len(sep):], ".")

	// Walk both label lists from the rightmost label towards the left.
	h, p := len(hostLabels)-1, len(patLabels)-1
	for ; h >= 0; h, p = h-1, p-1 {
		if p < 0 {
			// The domain has more labels than the pattern and no wildcard was seen.
			return false
		}
		if patLabels[p] == "*" {
			return true
		}
		if patLabels[p] != hostLabels[h] {
			return false
		}
	}
	return false
}