mirror of
https://github.com/astaxie/beego.git
synced 2024-11-13 07:20:54 +00:00
230 lines
6.6 KiB
Go
230 lines
6.6 KiB
Go
// Copyright 2014 beego Author. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Package cors provides handlers to enable CORS support.
|
|
// Usage
|
|
// import (
|
|
// "github.com/astaxie/beego"
|
|
// "github.com/astaxie/beego/plugins/cors"
|
|
// )
|
|
//
|
|
// func main() {
|
|
// // CORS for https://*.foo.com origins, allowing:
|
|
// // - PUT and PATCH methods
|
|
// // - Origin header
|
|
// // - Credentials sharing
// // - Exposing the Content-Length header
|
|
// beego.InsertFilter("*", beego.BeforeRouter,cors.Allow(&cors.Options{
|
|
// AllowOrigins: []string{"https://*.foo.com"},
|
|
// AllowMethods: []string{"PUT", "PATCH"},
|
|
// AllowHeaders: []string{"Origin"},
|
|
// ExposeHeaders: []string{"Content-Length"},
|
|
// AllowCredentials: true,
|
|
// }))
|
|
// beego.Run()
|
|
// }
|
|
package cors
|
|
|
|
import (
|
|
"net/http"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/astaxie/beego"
|
|
"github.com/astaxie/beego/context"
|
|
)
|
|
|
|
// Response headers that the CORS filter may set on outgoing responses.
const (
	headerAllowOrigin      = "Access-Control-Allow-Origin"
	headerAllowCredentials = "Access-Control-Allow-Credentials"
	headerAllowHeaders     = "Access-Control-Allow-Headers"
	headerAllowMethods     = "Access-Control-Allow-Methods"
	headerExposeHeaders    = "Access-Control-Expose-Headers"
	headerMaxAge           = "Access-Control-Max-Age"
)

// Request headers that the CORS filter reads from incoming requests.
const (
	headerOrigin         = "Origin"
	headerRequestMethod  = "Access-Control-Request-Method"
	headerRequestHeaders = "Access-Control-Request-Headers"
)
|
|
|
|
// defaultAllowHeaders is the header list Allow falls back to when
// Options.AllowHeaders is empty.
var defaultAllowHeaders = []string{
	"Origin",
	"Accept",
	"Content-Type",
	"Authorization",
}

// allowOriginPatterns collects the anchored regular expressions that Allow
// builds from Options.AllowOrigins. It is package-level state, so it is
// shared by every Allow call in the process.
var allowOriginPatterns = make([]string, 0)
|
|
|
|
// Options represents Access Control options. It is consumed by Allow and by
// the Header, PreflightHeader and IsOriginAllowed methods.
type Options struct {
	// If set, all origins are allowed and Access-Control-Allow-Origin is
	// sent as "*" instead of echoing the request's origin.
	AllowAllOrigins bool
	// A list of allowed origins. Wild cards and FQDNs are supported: each
	// entry is matched against the whole origin, "*" matches any run of
	// characters, "?" matches exactly one character, and every other
	// character is compared literally.
	AllowOrigins []string
	// If set, allows to share auth credentials such as cookies. The value is
	// always reported in Access-Control-Allow-Credentials as "true" or "false".
	AllowCredentials bool
	// A list of allowed HTTP methods, joined with "," into
	// Access-Control-Allow-Methods. On preflight responses that header is
	// only sent when the requested method equals one of these entries
	// (case-sensitive comparison).
	AllowMethods []string
	// A list of allowed HTTP headers. Allow replaces an empty list with a
	// default of Origin, Accept, Content-Type and Authorization.
	AllowHeaders []string
	// A list of exposed HTTP headers, joined with "," into
	// Access-Control-Expose-Headers.
	ExposeHeaders []string
	// Max age of the CORS headers. Sent as Access-Control-Max-Age in whole
	// seconds (any sub-second remainder is truncated); zero or negative
	// values omit the header.
	MaxAge time.Duration
}
|
|
|
|
// Header converts options into CORS headers.
|
|
func (o *Options) Header(origin string) (headers map[string]string) {
|
|
headers = make(map[string]string)
|
|
// if origin is not allowed, don't extend the headers
|
|
// with CORS headers.
|
|
if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
|
|
return
|
|
}
|
|
|
|
// add allow origin
|
|
if o.AllowAllOrigins {
|
|
headers[headerAllowOrigin] = "*"
|
|
} else {
|
|
headers[headerAllowOrigin] = origin
|
|
}
|
|
|
|
// add allow credentials
|
|
headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
|
|
|
|
// add allow methods
|
|
if len(o.AllowMethods) > 0 {
|
|
headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
|
|
}
|
|
|
|
// add allow headers
|
|
if len(o.AllowHeaders) > 0 {
|
|
headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",")
|
|
}
|
|
|
|
// add exposed header
|
|
if len(o.ExposeHeaders) > 0 {
|
|
headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
|
|
}
|
|
// add a max age header
|
|
if o.MaxAge > time.Duration(0) {
|
|
headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
|
|
}
|
|
return
|
|
}
|
|
|
|
// PreflightHeader converts options into CORS headers for a preflight response.
|
|
func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) {
|
|
headers = make(map[string]string)
|
|
if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
|
|
return
|
|
}
|
|
// verify if requested method is allowed
|
|
for _, method := range o.AllowMethods {
|
|
if method == rMethod {
|
|
headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
|
|
break
|
|
}
|
|
}
|
|
|
|
// verify if requested headers are allowed
|
|
var allowed []string
|
|
for _, rHeader := range strings.Split(rHeaders, ",") {
|
|
rHeader = strings.TrimSpace(rHeader)
|
|
lookupLoop:
|
|
for _, allowedHeader := range o.AllowHeaders {
|
|
if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) {
|
|
allowed = append(allowed, rHeader)
|
|
break lookupLoop
|
|
}
|
|
}
|
|
}
|
|
|
|
headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
|
|
// add allow origin
|
|
if o.AllowAllOrigins {
|
|
headers[headerAllowOrigin] = "*"
|
|
} else {
|
|
headers[headerAllowOrigin] = origin
|
|
}
|
|
|
|
// add allowed headers
|
|
if len(allowed) > 0 {
|
|
headers[headerAllowHeaders] = strings.Join(allowed, ",")
|
|
}
|
|
|
|
// add exposed headers
|
|
if len(o.ExposeHeaders) > 0 {
|
|
headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
|
|
}
|
|
// add a max age header
|
|
if o.MaxAge > time.Duration(0) {
|
|
headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
|
|
}
|
|
return
|
|
}
|
|
|
|
// IsOriginAllowed looks up if the origin matches one of the patterns
|
|
// generated from Options.AllowOrigins patterns.
|
|
func (o *Options) IsOriginAllowed(origin string) (allowed bool) {
|
|
for _, pattern := range allowOriginPatterns {
|
|
allowed, _ = regexp.MatchString(pattern, origin)
|
|
if allowed {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Allow enables CORS for requests those match the provided options.
|
|
func Allow(opts *Options) beego.FilterFunc {
|
|
// Allow default headers if nothing is specified.
|
|
if len(opts.AllowHeaders) == 0 {
|
|
opts.AllowHeaders = defaultAllowHeaders
|
|
}
|
|
|
|
for _, origin := range opts.AllowOrigins {
|
|
pattern := regexp.QuoteMeta(origin)
|
|
pattern = strings.Replace(pattern, "\\*", ".*", -1)
|
|
pattern = strings.Replace(pattern, "\\?", ".", -1)
|
|
allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$")
|
|
}
|
|
|
|
return func(ctx *context.Context) {
|
|
var (
|
|
origin = ctx.Input.Header(headerOrigin)
|
|
requestedMethod = ctx.Input.Header(headerRequestMethod)
|
|
requestedHeaders = ctx.Input.Header(headerRequestHeaders)
|
|
// additional headers to be added
|
|
// to the response.
|
|
headers map[string]string
|
|
)
|
|
|
|
if ctx.Input.Method() == "OPTIONS" &&
|
|
(requestedMethod != "" || requestedHeaders != "") {
|
|
headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders)
|
|
for key, value := range headers {
|
|
ctx.Output.Header(key, value)
|
|
}
|
|
ctx.Output.SetStatus(http.StatusOK)
|
|
ctx.WriteString("")
|
|
return
|
|
}
|
|
headers = opts.Header(origin)
|
|
|
|
for key, value := range headers {
|
|
ctx.Output.Header(key, value)
|
|
}
|
|
}
|
|
}
|