gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

ext.go (6276B)


      1 package msgpack
      2 
      3 import (
      4 	"fmt"
      5 	"math"
      6 	"reflect"
      7 
      8 	"github.com/vmihailenco/msgpack/v5/msgpcode"
      9 )
     10 
     11 type extInfo struct {
     12 	Type    reflect.Type
     13 	Decoder func(d *Decoder, v reflect.Value, extLen int) error
     14 }
     15 
     16 var extTypes = make(map[int8]*extInfo)
     17 
     18 type MarshalerUnmarshaler interface {
     19 	Marshaler
     20 	Unmarshaler
     21 }
     22 
     23 func RegisterExt(extID int8, value MarshalerUnmarshaler) {
     24 	RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) {
     25 		marshaler := v.Interface().(Marshaler)
     26 		return marshaler.MarshalMsgpack()
     27 	})
     28 	RegisterExtDecoder(extID, value, func(d *Decoder, v reflect.Value, extLen int) error {
     29 		b, err := d.readN(extLen)
     30 		if err != nil {
     31 			return err
     32 		}
     33 		return v.Interface().(Unmarshaler).UnmarshalMsgpack(b)
     34 	})
     35 }
     36 
     37 func UnregisterExt(extID int8) {
     38 	unregisterExtEncoder(extID)
     39 	unregisterExtDecoder(extID)
     40 }
     41 
     42 func RegisterExtEncoder(
     43 	extID int8,
     44 	value interface{},
     45 	encoder func(enc *Encoder, v reflect.Value) ([]byte, error),
     46 ) {
     47 	unregisterExtEncoder(extID)
     48 
     49 	typ := reflect.TypeOf(value)
     50 	extEncoder := makeExtEncoder(extID, typ, encoder)
     51 	typeEncMap.Store(extID, typ)
     52 	typeEncMap.Store(typ, extEncoder)
     53 	if typ.Kind() == reflect.Ptr {
     54 		typeEncMap.Store(typ.Elem(), makeExtEncoderAddr(extEncoder))
     55 	}
     56 }
     57 
     58 func unregisterExtEncoder(extID int8) {
     59 	t, ok := typeEncMap.Load(extID)
     60 	if !ok {
     61 		return
     62 	}
     63 	typeEncMap.Delete(extID)
     64 	typ := t.(reflect.Type)
     65 	typeEncMap.Delete(typ)
     66 	if typ.Kind() == reflect.Ptr {
     67 		typeEncMap.Delete(typ.Elem())
     68 	}
     69 }
     70 
     71 func makeExtEncoder(
     72 	extID int8,
     73 	typ reflect.Type,
     74 	encoder func(enc *Encoder, v reflect.Value) ([]byte, error),
     75 ) encoderFunc {
     76 	nilable := typ.Kind() == reflect.Ptr
     77 
     78 	return func(e *Encoder, v reflect.Value) error {
     79 		if nilable && v.IsNil() {
     80 			return e.EncodeNil()
     81 		}
     82 
     83 		b, err := encoder(e, v)
     84 		if err != nil {
     85 			return err
     86 		}
     87 
     88 		if err := e.EncodeExtHeader(extID, len(b)); err != nil {
     89 			return err
     90 		}
     91 
     92 		return e.write(b)
     93 	}
     94 }
     95 
     96 func makeExtEncoderAddr(extEncoder encoderFunc) encoderFunc {
     97 	return func(e *Encoder, v reflect.Value) error {
     98 		if !v.CanAddr() {
     99 			return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
    100 		}
    101 		return extEncoder(e, v.Addr())
    102 	}
    103 }
    104 
    105 func RegisterExtDecoder(
    106 	extID int8,
    107 	value interface{},
    108 	decoder func(dec *Decoder, v reflect.Value, extLen int) error,
    109 ) {
    110 	unregisterExtDecoder(extID)
    111 
    112 	typ := reflect.TypeOf(value)
    113 	extDecoder := makeExtDecoder(extID, typ, decoder)
    114 	extTypes[extID] = &extInfo{
    115 		Type:    typ,
    116 		Decoder: decoder,
    117 	}
    118 
    119 	typeDecMap.Store(extID, typ)
    120 	typeDecMap.Store(typ, extDecoder)
    121 	if typ.Kind() == reflect.Ptr {
    122 		typeDecMap.Store(typ.Elem(), makeExtDecoderAddr(extDecoder))
    123 	}
    124 }
    125 
    126 func unregisterExtDecoder(extID int8) {
    127 	t, ok := typeDecMap.Load(extID)
    128 	if !ok {
    129 		return
    130 	}
    131 	typeDecMap.Delete(extID)
    132 	delete(extTypes, extID)
    133 	typ := t.(reflect.Type)
    134 	typeDecMap.Delete(typ)
    135 	if typ.Kind() == reflect.Ptr {
    136 		typeDecMap.Delete(typ.Elem())
    137 	}
    138 }
    139 
    140 func makeExtDecoder(
    141 	wantedExtID int8,
    142 	typ reflect.Type,
    143 	decoder func(d *Decoder, v reflect.Value, extLen int) error,
    144 ) decoderFunc {
    145 	return nilAwareDecoder(typ, func(d *Decoder, v reflect.Value) error {
    146 		extID, extLen, err := d.DecodeExtHeader()
    147 		if err != nil {
    148 			return err
    149 		}
    150 		if extID != wantedExtID {
    151 			return fmt.Errorf("msgpack: got ext type=%d, wanted %d", extID, wantedExtID)
    152 		}
    153 		return decoder(d, v, extLen)
    154 	})
    155 }
    156 
    157 func makeExtDecoderAddr(extDecoder decoderFunc) decoderFunc {
    158 	return func(d *Decoder, v reflect.Value) error {
    159 		if !v.CanAddr() {
    160 			return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
    161 		}
    162 		return extDecoder(d, v.Addr())
    163 	}
    164 }
    165 
    166 func (e *Encoder) EncodeExtHeader(extID int8, extLen int) error {
    167 	if err := e.encodeExtLen(extLen); err != nil {
    168 		return err
    169 	}
    170 	if err := e.w.WriteByte(byte(extID)); err != nil {
    171 		return err
    172 	}
    173 	return nil
    174 }
    175 
    176 func (e *Encoder) encodeExtLen(l int) error {
    177 	switch l {
    178 	case 1:
    179 		return e.writeCode(msgpcode.FixExt1)
    180 	case 2:
    181 		return e.writeCode(msgpcode.FixExt2)
    182 	case 4:
    183 		return e.writeCode(msgpcode.FixExt4)
    184 	case 8:
    185 		return e.writeCode(msgpcode.FixExt8)
    186 	case 16:
    187 		return e.writeCode(msgpcode.FixExt16)
    188 	}
    189 	if l <= math.MaxUint8 {
    190 		return e.write1(msgpcode.Ext8, uint8(l))
    191 	}
    192 	if l <= math.MaxUint16 {
    193 		return e.write2(msgpcode.Ext16, uint16(l))
    194 	}
    195 	return e.write4(msgpcode.Ext32, uint32(l))
    196 }
    197 
    198 func (d *Decoder) DecodeExtHeader() (extID int8, extLen int, err error) {
    199 	c, err := d.readCode()
    200 	if err != nil {
    201 		return
    202 	}
    203 	return d.extHeader(c)
    204 }
    205 
    206 func (d *Decoder) extHeader(c byte) (int8, int, error) {
    207 	extLen, err := d.parseExtLen(c)
    208 	if err != nil {
    209 		return 0, 0, err
    210 	}
    211 
    212 	extID, err := d.readCode()
    213 	if err != nil {
    214 		return 0, 0, err
    215 	}
    216 
    217 	return int8(extID), extLen, nil
    218 }
    219 
    220 func (d *Decoder) parseExtLen(c byte) (int, error) {
    221 	switch c {
    222 	case msgpcode.FixExt1:
    223 		return 1, nil
    224 	case msgpcode.FixExt2:
    225 		return 2, nil
    226 	case msgpcode.FixExt4:
    227 		return 4, nil
    228 	case msgpcode.FixExt8:
    229 		return 8, nil
    230 	case msgpcode.FixExt16:
    231 		return 16, nil
    232 	case msgpcode.Ext8:
    233 		n, err := d.uint8()
    234 		return int(n), err
    235 	case msgpcode.Ext16:
    236 		n, err := d.uint16()
    237 		return int(n), err
    238 	case msgpcode.Ext32:
    239 		n, err := d.uint32()
    240 		return int(n), err
    241 	default:
    242 		return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext len", c)
    243 	}
    244 }
    245 
    246 func (d *Decoder) decodeInterfaceExt(c byte) (interface{}, error) {
    247 	extID, extLen, err := d.extHeader(c)
    248 	if err != nil {
    249 		return nil, err
    250 	}
    251 
    252 	info, ok := extTypes[extID]
    253 	if !ok {
    254 		return nil, fmt.Errorf("msgpack: unknown ext id=%d", extID)
    255 	}
    256 
    257 	v := reflect.New(info.Type).Elem()
    258 	if nilable(v.Kind()) && v.IsNil() {
    259 		v.Set(reflect.New(info.Type.Elem()))
    260 	}
    261 
    262 	if err := info.Decoder(d, v, extLen); err != nil {
    263 		return nil, err
    264 	}
    265 
    266 	return v.Interface(), nil
    267 }
    268 
    269 func (d *Decoder) skipExt(c byte) error {
    270 	n, err := d.parseExtLen(c)
    271 	if err != nil {
    272 		return err
    273 	}
    274 	return d.skipN(n + 1)
    275 }
    276 
    277 func (d *Decoder) skipExtHeader(c byte) error {
    278 	// Read ext type.
    279 	_, err := d.readCode()
    280 	if err != nil {
    281 		return err
    282 	}
    283 	// Read ext body len.
    284 	for i := 0; i < extHeaderLen(c); i++ {
    285 		_, err := d.readCode()
    286 		if err != nil {
    287 			return err
    288 		}
    289 	}
    290 	return nil
    291 }
    292 
    293 func extHeaderLen(c byte) int {
    294 	switch c {
    295 	case msgpcode.Ext8:
    296 		return 1
    297 	case msgpcode.Ext16:
    298 		return 2
    299 	case msgpcode.Ext32:
    300 		return 4
    301 	}
    302 	return 0
    303 }