gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

cache.go (7692B)


      1 package validator
      2 
      3 import (
      4 	"fmt"
      5 	"reflect"
      6 	"strings"
      7 	"sync"
      8 	"sync/atomic"
      9 )
     10 
     11 type tagType uint8
     12 
     13 const (
     14 	typeDefault tagType = iota
     15 	typeOmitEmpty
     16 	typeIsDefault
     17 	typeNoStructLevel
     18 	typeStructOnly
     19 	typeDive
     20 	typeOr
     21 	typeKeys
     22 	typeEndKeys
     23 )
     24 
     25 const (
     26 	invalidValidation   = "Invalid validation tag on field '%s'"
     27 	undefinedValidation = "Undefined validation function '%s' on field '%s'"
     28 	keysTagNotDefined   = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
     29 )
     30 
     31 type structCache struct {
     32 	lock sync.Mutex
     33 	m    atomic.Value // map[reflect.Type]*cStruct
     34 }
     35 
     36 func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
     37 	c, found = sc.m.Load().(map[reflect.Type]*cStruct)[key]
     38 	return
     39 }
     40 
     41 func (sc *structCache) Set(key reflect.Type, value *cStruct) {
     42 	m := sc.m.Load().(map[reflect.Type]*cStruct)
     43 	nm := make(map[reflect.Type]*cStruct, len(m)+1)
     44 	for k, v := range m {
     45 		nm[k] = v
     46 	}
     47 	nm[key] = value
     48 	sc.m.Store(nm)
     49 }
     50 
     51 type tagCache struct {
     52 	lock sync.Mutex
     53 	m    atomic.Value // map[string]*cTag
     54 }
     55 
     56 func (tc *tagCache) Get(key string) (c *cTag, found bool) {
     57 	c, found = tc.m.Load().(map[string]*cTag)[key]
     58 	return
     59 }
     60 
     61 func (tc *tagCache) Set(key string, value *cTag) {
     62 	m := tc.m.Load().(map[string]*cTag)
     63 	nm := make(map[string]*cTag, len(m)+1)
     64 	for k, v := range m {
     65 		nm[k] = v
     66 	}
     67 	nm[key] = value
     68 	tc.m.Store(nm)
     69 }
     70 
// cStruct is the cached, parsed form of a struct type, built by
// extractStructCache and stored in the struct cache.
type cStruct struct {
	// name is the struct name passed to extractStructCache as sName.
	name string
	// fields holds the kept fields in declaration order; unexported
	// non-embedded fields and fields tagged with skipValidationTag are absent.
	fields []*cField
	// fn is the entry of v.structLevelFuncs for this type (the map's zero
	// value when none is registered).
	fn StructLevelFuncCtx
}
     76 
// cField is the cached metadata for one struct field, built by
// extractStructCache.
type cField struct {
	// idx is the field's index in the struct type (as passed to Type.Field).
	idx int
	// name is the Go field name.
	name string
	// altName is the non-empty name returned by v.tagNameFunc when one is
	// configured, otherwise the Go field name.
	altName string
	// namesEqual records whether name == altName.
	namesEqual bool
	// cTags is the head of the field's parsed tag chain; an empty cTag when
	// the field has no validation tag.
	cTags *cTag
}
     84 
// cTag is one node of a parsed validation-tag chain produced by
// parseFieldTagsRecursive; nodes are linked through next.
type cTag struct {
	// tag is the validation name (the text before the first tagKeySeparator).
	tag string
	// aliasTag is the alias name when the node came from alias expansion,
	// otherwise the node's own tag name.
	aliasTag string
	// actualAliasTag is the full tag element (including any param) from the
	// alias definition; set only while expanding an alias.
	actualAliasTag string
	// param is the text after the first tagKeySeparator, with utf8HexComma and
	// utf8Pipe replaced by ',' and '|'.
	param string
	keys  *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation
	// next is the following node in the chain; nil at the end.
	next *cTag
	// fn is the validation function looked up in v.validations.
	fn FuncCtx
	// typeof classifies the node (plain validation, omitempty, dive, keys, ...).
	typeof tagType
	// hasTag is true for every node built by parseFieldTagsRecursive and false
	// for the empty placeholder used when a field has no tag.
	hasTag bool
	// hasAlias is true when the node came from expanding an alias.
	hasAlias bool
	hasParam   bool // true if parameter used eg. eq= where the equal sign has been set
	isBlockEnd bool // indicates the current tag represents the last validation in the block
	// runValidationWhenNil is copied from the registered validation's runValidatinOnNil flag.
	runValidationWhenNil bool
}
    100 
    101 func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
    102 	v.structCache.lock.Lock()
    103 	defer v.structCache.lock.Unlock() // leave as defer! because if inner panics, it will never get unlocked otherwise!
    104 
    105 	typ := current.Type()
    106 
    107 	// could have been multiple trying to access, but once first is done this ensures struct
    108 	// isn't parsed again.
    109 	cs, ok := v.structCache.Get(typ)
    110 	if ok {
    111 		return cs
    112 	}
    113 
    114 	cs = &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
    115 
    116 	numFields := current.NumField()
    117 	rules := v.rules[typ]
    118 
    119 	var ctag *cTag
    120 	var fld reflect.StructField
    121 	var tag string
    122 	var customName string
    123 
    124 	for i := 0; i < numFields; i++ {
    125 
    126 		fld = typ.Field(i)
    127 
    128 		if !fld.Anonymous && len(fld.PkgPath) > 0 {
    129 			continue
    130 		}
    131 
    132 		if rtag, ok := rules[fld.Name]; ok {
    133 			tag = rtag
    134 		} else {
    135 			tag = fld.Tag.Get(v.tagName)
    136 		}
    137 
    138 		if tag == skipValidationTag {
    139 			continue
    140 		}
    141 
    142 		customName = fld.Name
    143 
    144 		if v.hasTagNameFunc {
    145 			name := v.tagNameFunc(fld)
    146 			if len(name) > 0 {
    147 				customName = name
    148 			}
    149 		}
    150 
    151 		// NOTE: cannot use shared tag cache, because tags may be equal, but things like alias may be different
    152 		// and so only struct level caching can be used instead of combined with Field tag caching
    153 
    154 		if len(tag) > 0 {
    155 			ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
    156 		} else {
    157 			// even if field doesn't have validations need cTag for traversing to potential inner/nested
    158 			// elements of the field.
    159 			ctag = new(cTag)
    160 		}
    161 
    162 		cs.fields = append(cs.fields, &cField{
    163 			idx:        i,
    164 			name:       fld.Name,
    165 			altName:    customName,
    166 			cTags:      ctag,
    167 			namesEqual: fld.Name == customName,
    168 		})
    169 	}
    170 	v.structCache.Set(typ, cs)
    171 	return cs
    172 }
    173 
    174 func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
    175 	var t string
    176 	noAlias := len(alias) == 0
    177 	tags := strings.Split(tag, tagSeparator)
    178 
    179 	for i := 0; i < len(tags); i++ {
    180 		t = tags[i]
    181 		if noAlias {
    182 			alias = t
    183 		}
    184 
    185 		// check map for alias and process new tags, otherwise process as usual
    186 		if tagsVal, found := v.aliases[t]; found {
    187 			if i == 0 {
    188 				firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
    189 			} else {
    190 				next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
    191 				current.next, current = next, curr
    192 
    193 			}
    194 			continue
    195 		}
    196 
    197 		var prevTag tagType
    198 
    199 		if i == 0 {
    200 			current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true, typeof: typeDefault}
    201 			firstCtag = current
    202 		} else {
    203 			prevTag = current.typeof
    204 			current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
    205 			current = current.next
    206 		}
    207 
    208 		switch t {
    209 		case diveTag:
    210 			current.typeof = typeDive
    211 			continue
    212 
    213 		case keysTag:
    214 			current.typeof = typeKeys
    215 
    216 			if i == 0 || prevTag != typeDive {
    217 				panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
    218 			}
    219 
    220 			current.typeof = typeKeys
    221 
    222 			// need to pass along only keys tag
    223 			// need to increment i to skip over the keys tags
    224 			b := make([]byte, 0, 64)
    225 
    226 			i++
    227 
    228 			for ; i < len(tags); i++ {
    229 
    230 				b = append(b, tags[i]...)
    231 				b = append(b, ',')
    232 
    233 				if tags[i] == endKeysTag {
    234 					break
    235 				}
    236 			}
    237 
    238 			current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
    239 			continue
    240 
    241 		case endKeysTag:
    242 			current.typeof = typeEndKeys
    243 
    244 			// if there are more in tags then there was no keysTag defined
    245 			// and an error should be thrown
    246 			if i != len(tags)-1 {
    247 				panic(keysTagNotDefined)
    248 			}
    249 			return
    250 
    251 		case omitempty:
    252 			current.typeof = typeOmitEmpty
    253 			continue
    254 
    255 		case structOnlyTag:
    256 			current.typeof = typeStructOnly
    257 			continue
    258 
    259 		case noStructLevelTag:
    260 			current.typeof = typeNoStructLevel
    261 			continue
    262 
    263 		default:
    264 			if t == isdefault {
    265 				current.typeof = typeIsDefault
    266 			}
    267 			// if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
    268 			orVals := strings.Split(t, orSeparator)
    269 
    270 			for j := 0; j < len(orVals); j++ {
    271 				vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
    272 				if noAlias {
    273 					alias = vals[0]
    274 					current.aliasTag = alias
    275 				} else {
    276 					current.actualAliasTag = t
    277 				}
    278 
    279 				if j > 0 {
    280 					current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
    281 					current = current.next
    282 				}
    283 				current.hasParam = len(vals) > 1
    284 
    285 				current.tag = vals[0]
    286 				if len(current.tag) == 0 {
    287 					panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
    288 				}
    289 
    290 				if wrapper, ok := v.validations[current.tag]; ok {
    291 					current.fn = wrapper.fn
    292 					current.runValidationWhenNil = wrapper.runValidatinOnNil
    293 				} else {
    294 					panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
    295 				}
    296 
    297 				if len(orVals) > 1 {
    298 					current.typeof = typeOr
    299 				}
    300 
    301 				if len(vals) > 1 {
    302 					current.param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
    303 				}
    304 			}
    305 			current.isBlockEnd = true
    306 		}
    307 	}
    308 	return
    309 }
    310 
    311 func (v *Validate) fetchCacheTag(tag string) *cTag {
    312 	// find cached tag
    313 	ctag, found := v.tagCache.Get(tag)
    314 	if !found {
    315 		v.tagCache.lock.Lock()
    316 		defer v.tagCache.lock.Unlock()
    317 
    318 		// could have been multiple trying to access, but once first is done this ensures tag
    319 		// isn't parsed again.
    320 		ctag, found = v.tagCache.Get(tag)
    321 		if !found {
    322 			ctag, _ = v.parseFieldTagsRecursive(tag, "", "", false)
    323 			v.tagCache.Set(tag, ctag)
    324 		}
    325 	}
    326 	return ctag
    327 }