gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

config.go (2948B)


      1 package cors
      2 
      3 import (
      4 	"net/http"
      5 	"strings"
      6 
      7 	"github.com/gin-gonic/gin"
      8 )
      9 
     10 type cors struct {
     11 	allowAllOrigins  bool
     12 	allowCredentials bool
     13 	allowOriginFunc  func(string) bool
     14 	allowOrigins     []string
     15 	normalHeaders    http.Header
     16 	preflightHeaders http.Header
     17 	wildcardOrigins  [][]string
     18 }
     19 
     20 var (
     21 	DefaultSchemas = []string{
     22 		"http://",
     23 		"https://",
     24 	}
     25 	ExtensionSchemas = []string{
     26 		"chrome-extension://",
     27 		"safari-extension://",
     28 		"moz-extension://",
     29 		"ms-browser-extension://",
     30 	}
     31 	FileSchemas = []string{
     32 		"file://",
     33 	}
     34 	WebSocketSchemas = []string{
     35 		"ws://",
     36 		"wss://",
     37 	}
     38 )
     39 
     40 func newCors(config Config) *cors {
     41 	if err := config.Validate(); err != nil {
     42 		panic(err.Error())
     43 	}
     44 
     45 	for _, origin := range config.AllowOrigins {
     46 		if origin == "*" {
     47 			config.AllowAllOrigins = true
     48 		}
     49 	}
     50 
     51 	return &cors{
     52 		allowOriginFunc:  config.AllowOriginFunc,
     53 		allowAllOrigins:  config.AllowAllOrigins,
     54 		allowCredentials: config.AllowCredentials,
     55 		allowOrigins:     normalize(config.AllowOrigins),
     56 		normalHeaders:    generateNormalHeaders(config),
     57 		preflightHeaders: generatePreflightHeaders(config),
     58 		wildcardOrigins:  config.parseWildcardRules(),
     59 	}
     60 }
     61 
     62 func (cors *cors) applyCors(c *gin.Context) {
     63 	origin := c.Request.Header.Get("Origin")
     64 	if len(origin) == 0 {
     65 		// request is not a CORS request
     66 		return
     67 	}
     68 	host := c.Request.Host
     69 
     70 	if origin == "http://"+host || origin == "https://"+host {
     71 		// request is not a CORS request but have origin header.
     72 		// for example, use fetch api
     73 		return
     74 	}
     75 
     76 	if !cors.validateOrigin(origin) {
     77 		c.AbortWithStatus(http.StatusForbidden)
     78 		return
     79 	}
     80 
     81 	if c.Request.Method == "OPTIONS" {
     82 		cors.handlePreflight(c)
     83 		defer c.AbortWithStatus(http.StatusNoContent) // Using 204 is better than 200 when the request status is OPTIONS
     84 	} else {
     85 		cors.handleNormal(c)
     86 	}
     87 
     88 	if !cors.allowAllOrigins {
     89 		c.Header("Access-Control-Allow-Origin", origin)
     90 	}
     91 }
     92 
     93 func (cors *cors) validateWildcardOrigin(origin string) bool {
     94 	for _, w := range cors.wildcardOrigins {
     95 		if w[0] == "*" && strings.HasSuffix(origin, w[1]) {
     96 			return true
     97 		}
     98 		if w[1] == "*" && strings.HasPrefix(origin, w[0]) {
     99 			return true
    100 		}
    101 		if strings.HasPrefix(origin, w[0]) && strings.HasSuffix(origin, w[1]) {
    102 			return true
    103 		}
    104 	}
    105 
    106 	return false
    107 }
    108 
    109 func (cors *cors) validateOrigin(origin string) bool {
    110 	if cors.allowAllOrigins {
    111 		return true
    112 	}
    113 	for _, value := range cors.allowOrigins {
    114 		if value == origin {
    115 			return true
    116 		}
    117 	}
    118 	if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) {
    119 		return true
    120 	}
    121 	if cors.allowOriginFunc != nil {
    122 		return cors.allowOriginFunc(origin)
    123 	}
    124 	return false
    125 }
    126 
    127 func (cors *cors) handlePreflight(c *gin.Context) {
    128 	header := c.Writer.Header()
    129 	for key, value := range cors.preflightHeaders {
    130 		header[key] = value
    131 	}
    132 }
    133 
    134 func (cors *cors) handleNormal(c *gin.Context) {
    135 	header := c.Writer.Header()
    136 	for key, value := range cors.normalHeaders {
    137 		header[key] = value
    138 	}
    139 }