config.go (2948B)
1 package cors 2 3 import ( 4 "net/http" 5 "strings" 6 7 "github.com/gin-gonic/gin" 8 ) 9 10 type cors struct { 11 allowAllOrigins bool 12 allowCredentials bool 13 allowOriginFunc func(string) bool 14 allowOrigins []string 15 normalHeaders http.Header 16 preflightHeaders http.Header 17 wildcardOrigins [][]string 18 } 19 20 var ( 21 DefaultSchemas = []string{ 22 "http://", 23 "https://", 24 } 25 ExtensionSchemas = []string{ 26 "chrome-extension://", 27 "safari-extension://", 28 "moz-extension://", 29 "ms-browser-extension://", 30 } 31 FileSchemas = []string{ 32 "file://", 33 } 34 WebSocketSchemas = []string{ 35 "ws://", 36 "wss://", 37 } 38 ) 39 40 func newCors(config Config) *cors { 41 if err := config.Validate(); err != nil { 42 panic(err.Error()) 43 } 44 45 for _, origin := range config.AllowOrigins { 46 if origin == "*" { 47 config.AllowAllOrigins = true 48 } 49 } 50 51 return &cors{ 52 allowOriginFunc: config.AllowOriginFunc, 53 allowAllOrigins: config.AllowAllOrigins, 54 allowCredentials: config.AllowCredentials, 55 allowOrigins: normalize(config.AllowOrigins), 56 normalHeaders: generateNormalHeaders(config), 57 preflightHeaders: generatePreflightHeaders(config), 58 wildcardOrigins: config.parseWildcardRules(), 59 } 60 } 61 62 func (cors *cors) applyCors(c *gin.Context) { 63 origin := c.Request.Header.Get("Origin") 64 if len(origin) == 0 { 65 // request is not a CORS request 66 return 67 } 68 host := c.Request.Host 69 70 if origin == "http://"+host || origin == "https://"+host { 71 // request is not a CORS request but have origin header. 72 // for example, use fetch api 73 return 74 } 75 76 if !cors.validateOrigin(origin) { 77 c.AbortWithStatus(http.StatusForbidden) 78 return 79 } 80 81 if c.Request.Method == "OPTIONS" { 82 cors.handlePreflight(c) 83 defer c.AbortWithStatus(http.StatusNoContent) // Using 204 is better than 200 when the request status is OPTIONS 84 } else { 85 cors.handleNormal(c) 86 } 87 88 if !cors.allowAllOrigins { 89 c.Header("Access-Control-Allow-Origin", origin) 90 } 91 } 92 93 func (cors *cors) validateWildcardOrigin(origin string) bool { 94 for _, w := range cors.wildcardOrigins { 95 if w[0] == "*" && strings.HasSuffix(origin, w[1]) { 96 return true 97 } 98 if w[1] == "*" && strings.HasPrefix(origin, w[0]) { 99 return true 100 } 101 if strings.HasPrefix(origin, w[0]) && strings.HasSuffix(origin, w[1]) { 102 return true 103 } 104 } 105 106 return false 107 } 108 109 func (cors *cors) validateOrigin(origin string) bool { 110 if cors.allowAllOrigins { 111 return true 112 } 113 for _, value := range cors.allowOrigins { 114 if value == origin { 115 return true 116 } 117 } 118 if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) { 119 return true 120 } 121 if cors.allowOriginFunc != nil { 122 return cors.allowOriginFunc(origin) 123 } 124 return false 125 } 126 127 func (cors *cors) handlePreflight(c *gin.Context) { 128 header := c.Writer.Header() 129 for key, value := range cors.preflightHeaders { 130 header[key] = value 131 } 132 } 133 134 func (cors *cors) handleNormal(c *gin.Context) { 135 header := c.Writer.Header() 136 for key, value := range cors.normalHeaders { 137 header[key] = value 138 } 139 }