ext.go (6276B)
1 package msgpack 2 3 import ( 4 "fmt" 5 "math" 6 "reflect" 7 8 "github.com/vmihailenco/msgpack/v5/msgpcode" 9 ) 10 11 type extInfo struct { 12 Type reflect.Type 13 Decoder func(d *Decoder, v reflect.Value, extLen int) error 14 } 15 16 var extTypes = make(map[int8]*extInfo) 17 18 type MarshalerUnmarshaler interface { 19 Marshaler 20 Unmarshaler 21 } 22 23 func RegisterExt(extID int8, value MarshalerUnmarshaler) { 24 RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) { 25 marshaler := v.Interface().(Marshaler) 26 return marshaler.MarshalMsgpack() 27 }) 28 RegisterExtDecoder(extID, value, func(d *Decoder, v reflect.Value, extLen int) error { 29 b, err := d.readN(extLen) 30 if err != nil { 31 return err 32 } 33 return v.Interface().(Unmarshaler).UnmarshalMsgpack(b) 34 }) 35 } 36 37 func UnregisterExt(extID int8) { 38 unregisterExtEncoder(extID) 39 unregisterExtDecoder(extID) 40 } 41 42 func RegisterExtEncoder( 43 extID int8, 44 value interface{}, 45 encoder func(enc *Encoder, v reflect.Value) ([]byte, error), 46 ) { 47 unregisterExtEncoder(extID) 48 49 typ := reflect.TypeOf(value) 50 extEncoder := makeExtEncoder(extID, typ, encoder) 51 typeEncMap.Store(extID, typ) 52 typeEncMap.Store(typ, extEncoder) 53 if typ.Kind() == reflect.Ptr { 54 typeEncMap.Store(typ.Elem(), makeExtEncoderAddr(extEncoder)) 55 } 56 } 57 58 func unregisterExtEncoder(extID int8) { 59 t, ok := typeEncMap.Load(extID) 60 if !ok { 61 return 62 } 63 typeEncMap.Delete(extID) 64 typ := t.(reflect.Type) 65 typeEncMap.Delete(typ) 66 if typ.Kind() == reflect.Ptr { 67 typeEncMap.Delete(typ.Elem()) 68 } 69 } 70 71 func makeExtEncoder( 72 extID int8, 73 typ reflect.Type, 74 encoder func(enc *Encoder, v reflect.Value) ([]byte, error), 75 ) encoderFunc { 76 nilable := typ.Kind() == reflect.Ptr 77 78 return func(e *Encoder, v reflect.Value) error { 79 if nilable && v.IsNil() { 80 return e.EncodeNil() 81 } 82 83 b, err := encoder(e, v) 84 if err != nil { 85 return err 86 } 87 88 if err := e.EncodeExtHeader(extID, len(b)); err != nil { 89 return err 90 } 91 92 return e.write(b) 93 } 94 } 95 96 func makeExtEncoderAddr(extEncoder encoderFunc) encoderFunc { 97 return func(e *Encoder, v reflect.Value) error { 98 if !v.CanAddr() { 99 return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface()) 100 } 101 return extEncoder(e, v.Addr()) 102 } 103 } 104 105 func RegisterExtDecoder( 106 extID int8, 107 value interface{}, 108 decoder func(dec *Decoder, v reflect.Value, extLen int) error, 109 ) { 110 unregisterExtDecoder(extID) 111 112 typ := reflect.TypeOf(value) 113 extDecoder := makeExtDecoder(extID, typ, decoder) 114 extTypes[extID] = &extInfo{ 115 Type: typ, 116 Decoder: decoder, 117 } 118 119 typeDecMap.Store(extID, typ) 120 typeDecMap.Store(typ, extDecoder) 121 if typ.Kind() == reflect.Ptr { 122 typeDecMap.Store(typ.Elem(), makeExtDecoderAddr(extDecoder)) 123 } 124 } 125 126 func unregisterExtDecoder(extID int8) { 127 t, ok := typeDecMap.Load(extID) 128 if !ok { 129 return 130 } 131 typeDecMap.Delete(extID) 132 delete(extTypes, extID) 133 typ := t.(reflect.Type) 134 typeDecMap.Delete(typ) 135 if typ.Kind() == reflect.Ptr { 136 typeDecMap.Delete(typ.Elem()) 137 } 138 } 139 140 func makeExtDecoder( 141 wantedExtID int8, 142 typ reflect.Type, 143 decoder func(d *Decoder, v reflect.Value, extLen int) error, 144 ) decoderFunc { 145 return nilAwareDecoder(typ, func(d *Decoder, v reflect.Value) error { 146 extID, extLen, err := d.DecodeExtHeader() 147 if err != nil { 148 return err 149 } 150 if extID != wantedExtID { 151 return fmt.Errorf("msgpack: got ext type=%d, wanted %d", extID, wantedExtID) 152 } 153 return decoder(d, v, extLen) 154 }) 155 } 156 157 func makeExtDecoderAddr(extDecoder decoderFunc) decoderFunc { 158 return func(d *Decoder, v reflect.Value) error { 159 if !v.CanAddr() { 160 return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface()) 161 } 162 return extDecoder(d, v.Addr()) 163 } 164 } 165 166 func (e *Encoder) EncodeExtHeader(extID int8, extLen int) error { 167 if err := e.encodeExtLen(extLen); err != nil { 168 return err 169 } 170 if err := e.w.WriteByte(byte(extID)); err != nil { 171 return err 172 } 173 return nil 174 } 175 176 func (e *Encoder) encodeExtLen(l int) error { 177 switch l { 178 case 1: 179 return e.writeCode(msgpcode.FixExt1) 180 case 2: 181 return e.writeCode(msgpcode.FixExt2) 182 case 4: 183 return e.writeCode(msgpcode.FixExt4) 184 case 8: 185 return e.writeCode(msgpcode.FixExt8) 186 case 16: 187 return e.writeCode(msgpcode.FixExt16) 188 } 189 if l <= math.MaxUint8 { 190 return e.write1(msgpcode.Ext8, uint8(l)) 191 } 192 if l <= math.MaxUint16 { 193 return e.write2(msgpcode.Ext16, uint16(l)) 194 } 195 return e.write4(msgpcode.Ext32, uint32(l)) 196 } 197 198 func (d *Decoder) DecodeExtHeader() (extID int8, extLen int, err error) { 199 c, err := d.readCode() 200 if err != nil { 201 return 202 } 203 return d.extHeader(c) 204 } 205 206 func (d *Decoder) extHeader(c byte) (int8, int, error) { 207 extLen, err := d.parseExtLen(c) 208 if err != nil { 209 return 0, 0, err 210 } 211 212 extID, err := d.readCode() 213 if err != nil { 214 return 0, 0, err 215 } 216 217 return int8(extID), extLen, nil 218 } 219 220 func (d *Decoder) parseExtLen(c byte) (int, error) { 221 switch c { 222 case msgpcode.FixExt1: 223 return 1, nil 224 case msgpcode.FixExt2: 225 return 2, nil 226 case msgpcode.FixExt4: 227 return 4, nil 228 case msgpcode.FixExt8: 229 return 8, nil 230 case msgpcode.FixExt16: 231 return 16, nil 232 case msgpcode.Ext8: 233 n, err := d.uint8() 234 return int(n), err 235 case msgpcode.Ext16: 236 n, err := d.uint16() 237 return int(n), err 238 case msgpcode.Ext32: 239 n, err := d.uint32() 240 return int(n), err 241 default: 242 return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext len", c) 243 } 244 } 245 246 func (d *Decoder) decodeInterfaceExt(c byte) (interface{}, error) { 247 extID, extLen, err := d.extHeader(c) 248 if err != nil { 249 return nil, err 250 } 251 252 info, ok := extTypes[extID] 253 if !ok { 254 return nil, fmt.Errorf("msgpack: unknown ext id=%d", extID) 255 } 256 257 v := reflect.New(info.Type).Elem() 258 if nilable(v.Kind()) && v.IsNil() { 259 v.Set(reflect.New(info.Type.Elem())) 260 } 261 262 if err := info.Decoder(d, v, extLen); err != nil { 263 return nil, err 264 } 265 266 return v.Interface(), nil 267 } 268 269 func (d *Decoder) skipExt(c byte) error { 270 n, err := d.parseExtLen(c) 271 if err != nil { 272 return err 273 } 274 return d.skipN(n + 1) 275 } 276 277 func (d *Decoder) skipExtHeader(c byte) error { 278 // Read ext type. 279 _, err := d.readCode() 280 if err != nil { 281 return err 282 } 283 // Read ext body len. 284 for i := 0; i < extHeaderLen(c); i++ { 285 _, err := d.readCode() 286 if err != nil { 287 return err 288 } 289 } 290 return nil 291 } 292 293 func extHeaderLen(c byte) int { 294 switch c { 295 case msgpcode.Ext8: 296 return 1 297 case msgpcode.Ext16: 298 return 2 299 case msgpcode.Ext32: 300 return 4 301 } 302 return 0 303 }