utils.go (2308B)
1 package cors 2 3 import ( 4 "net/http" 5 "strconv" 6 "strings" 7 "time" 8 ) 9 10 type converter func(string) string 11 12 func generateNormalHeaders(c Config) http.Header { 13 headers := make(http.Header) 14 if c.AllowCredentials { 15 headers.Set("Access-Control-Allow-Credentials", "true") 16 } 17 if len(c.ExposeHeaders) > 0 { 18 exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey) 19 headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ",")) 20 } 21 if c.AllowAllOrigins { 22 headers.Set("Access-Control-Allow-Origin", "*") 23 } else { 24 headers.Set("Vary", "Origin") 25 } 26 return headers 27 } 28 29 func generatePreflightHeaders(c Config) http.Header { 30 headers := make(http.Header) 31 if c.AllowCredentials { 32 headers.Set("Access-Control-Allow-Credentials", "true") 33 } 34 if len(c.AllowMethods) > 0 { 35 allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) 36 value := strings.Join(allowMethods, ",") 37 headers.Set("Access-Control-Allow-Methods", value) 38 } 39 if len(c.AllowHeaders) > 0 { 40 allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey) 41 value := strings.Join(allowHeaders, ",") 42 headers.Set("Access-Control-Allow-Headers", value) 43 } 44 if c.MaxAge > time.Duration(0) { 45 value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) 46 headers.Set("Access-Control-Max-Age", value) 47 } 48 if c.AllowAllOrigins { 49 headers.Set("Access-Control-Allow-Origin", "*") 50 } else { 51 // Always set Vary headers 52 // see https://github.com/rs/cors/issues/10, 53 // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 54 55 headers.Add("Vary", "Origin") 56 headers.Add("Vary", "Access-Control-Request-Method") 57 headers.Add("Vary", "Access-Control-Request-Headers") 58 } 59 return headers 60 } 61 62 func normalize(values []string) []string { 63 if values == nil { 64 return nil 65 } 66 distinctMap := make(map[string]bool, len(values)) 67 normalized := make([]string, 0, len(values)) 68 for _, value := range values { 69 value = strings.TrimSpace(value) 70 value = strings.ToLower(value) 71 if _, seen := distinctMap[value]; !seen { 72 normalized = append(normalized, value) 73 distinctMap[value] = true 74 } 75 } 76 return normalized 77 } 78 79 func convert(s []string, c converter) []string { 80 var out []string 81 for _, i := range s { 82 out = append(out, c(i)) 83 } 84 return out 85 }