decode_query.go (2686B)
1 package msgpack 2 3 import ( 4 "fmt" 5 "strconv" 6 "strings" 7 8 "github.com/vmihailenco/msgpack/v5/msgpcode" 9 ) 10 11 type queryResult struct { 12 query string 13 key string 14 hasAsterisk bool 15 16 values []interface{} 17 } 18 19 func (q *queryResult) nextKey() { 20 ind := strings.IndexByte(q.query, '.') 21 if ind == -1 { 22 q.key = q.query 23 q.query = "" 24 return 25 } 26 q.key = q.query[:ind] 27 q.query = q.query[ind+1:] 28 } 29 30 // Query extracts data specified by the query from the msgpack stream skipping 31 // any other data. Query consists of map keys and array indexes separated with dot, 32 // e.g. key1.0.key2. 33 func (d *Decoder) Query(query string) ([]interface{}, error) { 34 res := queryResult{ 35 query: query, 36 } 37 if err := d.query(&res); err != nil { 38 return nil, err 39 } 40 return res.values, nil 41 } 42 43 func (d *Decoder) query(q *queryResult) error { 44 q.nextKey() 45 if q.key == "" { 46 v, err := d.decodeInterfaceCond() 47 if err != nil { 48 return err 49 } 50 q.values = append(q.values, v) 51 return nil 52 } 53 54 code, err := d.PeekCode() 55 if err != nil { 56 return err 57 } 58 59 switch { 60 case code == msgpcode.Map16 || code == msgpcode.Map32 || msgpcode.IsFixedMap(code): 61 err = d.queryMapKey(q) 62 case code == msgpcode.Array16 || code == msgpcode.Array32 || msgpcode.IsFixedArray(code): 63 err = d.queryArrayIndex(q) 64 default: 65 err = fmt.Errorf("msgpack: unsupported code=%x decoding key=%q", code, q.key) 66 } 67 return err 68 } 69 70 func (d *Decoder) queryMapKey(q *queryResult) error { 71 n, err := d.DecodeMapLen() 72 if err != nil { 73 return err 74 } 75 if n == -1 { 76 return nil 77 } 78 79 for i := 0; i < n; i++ { 80 key, err := d.decodeStringTemp() 81 if err != nil { 82 return err 83 } 84 85 if key == q.key { 86 if err := d.query(q); err != nil { 87 return err 88 } 89 if q.hasAsterisk { 90 return d.skipNext((n - i - 1) * 2) 91 } 92 return nil 93 } 94 95 if err := d.Skip(); err != nil { 96 return err 97 } 98 } 99 100 return nil 101 } 102 103 func (d *Decoder) queryArrayIndex(q *queryResult) error { 104 n, err := d.DecodeArrayLen() 105 if err != nil { 106 return err 107 } 108 if n == -1 { 109 return nil 110 } 111 112 if q.key == "*" { 113 q.hasAsterisk = true 114 115 query := q.query 116 for i := 0; i < n; i++ { 117 q.query = query 118 if err := d.query(q); err != nil { 119 return err 120 } 121 } 122 123 q.hasAsterisk = false 124 return nil 125 } 126 127 ind, err := strconv.Atoi(q.key) 128 if err != nil { 129 return err 130 } 131 132 for i := 0; i < n; i++ { 133 if i == ind { 134 if err := d.query(q); err != nil { 135 return err 136 } 137 if q.hasAsterisk { 138 return d.skipNext(n - i - 1) 139 } 140 return nil 141 } 142 143 if err := d.Skip(); err != nil { 144 return err 145 } 146 } 147 148 return nil 149 } 150 151 func (d *Decoder) skipNext(n int) error { 152 for i := 0; i < n; i++ { 153 if err := d.Skip(); err != nil { 154 return err 155 } 156 } 157 return nil 158 }