gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

marshaller.go (13800B)


      1 package mp4
      2 
      3 import (
      4 	"bytes"
      5 	"errors"
      6 	"fmt"
      7 	"io"
      8 	"math"
      9 	"reflect"
     10 
     11 	"github.com/abema/go-mp4/bitio"
     12 )
     13 
     14 const (
     15 	anyVersion = math.MaxUint8
     16 )
     17 
     18 var ErrUnsupportedBoxVersion = errors.New("unsupported box version")
     19 
     20 type marshaller struct {
     21 	writer bitio.Writer
     22 	wbits  uint64
     23 	src    IImmutableBox
     24 	ctx    Context
     25 }
     26 
     27 func Marshal(w io.Writer, src IImmutableBox, ctx Context) (n uint64, err error) {
     28 	boxDef := src.GetType().getBoxDef(ctx)
     29 	if boxDef == nil {
     30 		return 0, ErrBoxInfoNotFound
     31 	}
     32 
     33 	v := reflect.ValueOf(src).Elem()
     34 
     35 	m := &marshaller{
     36 		writer: bitio.NewWriter(w),
     37 		src:    src,
     38 		ctx:    ctx,
     39 	}
     40 
     41 	if err := m.marshalStruct(v, boxDef.fields); err != nil {
     42 		return 0, err
     43 	}
     44 
     45 	if m.wbits%8 != 0 {
     46 		return 0, fmt.Errorf("box size is not multiple of 8 bits: type=%s, bits=%d", src.GetType().String(), m.wbits)
     47 	}
     48 
     49 	return m.wbits / 8, nil
     50 }
     51 
     52 func (m *marshaller) marshal(v reflect.Value, fi *fieldInstance) error {
     53 	switch v.Type().Kind() {
     54 	case reflect.Ptr:
     55 		return m.marshalPtr(v, fi)
     56 	case reflect.Struct:
     57 		return m.marshalStruct(v, fi.children)
     58 	case reflect.Array:
     59 		return m.marshalArray(v, fi)
     60 	case reflect.Slice:
     61 		return m.marshalSlice(v, fi)
     62 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
     63 		return m.marshalInt(v, fi)
     64 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
     65 		return m.marshalUint(v, fi)
     66 	case reflect.Bool:
     67 		return m.marshalBool(v, fi)
     68 	case reflect.String:
     69 		return m.marshalString(v)
     70 	default:
     71 		return fmt.Errorf("unsupported type: %s", v.Type().Kind())
     72 	}
     73 }
     74 
     75 func (m *marshaller) marshalPtr(v reflect.Value, fi *fieldInstance) error {
     76 	return m.marshal(v.Elem(), fi)
     77 }
     78 
     79 func (m *marshaller) marshalStruct(v reflect.Value, fs []*field) error {
     80 	for _, f := range fs {
     81 		fi := resolveFieldInstance(f, m.src, v, m.ctx)
     82 
     83 		if !isTargetField(m.src, fi, m.ctx) {
     84 			continue
     85 		}
     86 
     87 		wbits, override, err := fi.cfo.OnWriteField(f.name, m.writer, m.ctx)
     88 		if err != nil {
     89 			return err
     90 		}
     91 		m.wbits += wbits
     92 		if override {
     93 			continue
     94 		}
     95 
     96 		err = m.marshal(v.FieldByName(f.name), fi)
     97 		if err != nil {
     98 			return err
     99 		}
    100 	}
    101 
    102 	return nil
    103 }
    104 
    105 func (m *marshaller) marshalArray(v reflect.Value, fi *fieldInstance) error {
    106 	size := v.Type().Size()
    107 	for i := 0; i < int(size)/int(v.Type().Elem().Size()); i++ {
    108 		var err error
    109 		err = m.marshal(v.Index(i), fi)
    110 		if err != nil {
    111 			return err
    112 		}
    113 	}
    114 	return nil
    115 }
    116 
    117 func (m *marshaller) marshalSlice(v reflect.Value, fi *fieldInstance) error {
    118 	length := uint64(v.Len())
    119 	if fi.length != LengthUnlimited {
    120 		if length < uint64(fi.length) {
    121 			return fmt.Errorf("the slice has too few elements: required=%d actual=%d", fi.length, length)
    122 		}
    123 		length = uint64(fi.length)
    124 	}
    125 
    126 	elemType := v.Type().Elem()
    127 	if elemType.Kind() == reflect.Uint8 && fi.size == 8 && m.wbits%8 == 0 {
    128 		if _, err := io.CopyN(m.writer, bytes.NewBuffer(v.Bytes()), int64(length)); err != nil {
    129 			return err
    130 		}
    131 		m.wbits += length * 8
    132 		return nil
    133 	}
    134 
    135 	for i := 0; i < int(length); i++ {
    136 		m.marshal(v.Index(i), fi)
    137 	}
    138 	return nil
    139 }
    140 
    141 func (m *marshaller) marshalInt(v reflect.Value, fi *fieldInstance) error {
    142 	signed := v.Int()
    143 
    144 	if fi.is(fieldVarint) {
    145 		return errors.New("signed varint is unsupported")
    146 	}
    147 
    148 	signBit := signed < 0
    149 	val := uint64(signed)
    150 	for i := uint(0); i < fi.size; i += 8 {
    151 		v := val
    152 		size := uint(8)
    153 		if fi.size > i+8 {
    154 			v = v >> (fi.size - (i + 8))
    155 		} else if fi.size < i+8 {
    156 			size = fi.size - i
    157 		}
    158 
    159 		// set sign bit
    160 		if i == 0 {
    161 			if signBit {
    162 				v |= 0x1 << (size - 1)
    163 			} else {
    164 				v &= 0x1<<(size-1) - 1
    165 			}
    166 		}
    167 
    168 		if err := m.writer.WriteBits([]byte{byte(v)}, size); err != nil {
    169 			return err
    170 		}
    171 		m.wbits += uint64(size)
    172 	}
    173 
    174 	return nil
    175 }
    176 
    177 func (m *marshaller) marshalUint(v reflect.Value, fi *fieldInstance) error {
    178 	val := v.Uint()
    179 
    180 	if fi.is(fieldVarint) {
    181 		m.writeUvarint(val)
    182 		return nil
    183 	}
    184 
    185 	for i := uint(0); i < fi.size; i += 8 {
    186 		v := val
    187 		size := uint(8)
    188 		if fi.size > i+8 {
    189 			v = v >> (fi.size - (i + 8))
    190 		} else if fi.size < i+8 {
    191 			size = fi.size - i
    192 		}
    193 		if err := m.writer.WriteBits([]byte{byte(v)}, size); err != nil {
    194 			return err
    195 		}
    196 		m.wbits += uint64(size)
    197 	}
    198 
    199 	return nil
    200 }
    201 
    202 func (m *marshaller) marshalBool(v reflect.Value, fi *fieldInstance) error {
    203 	var val byte
    204 	if v.Bool() {
    205 		val = 0xff
    206 	} else {
    207 		val = 0x00
    208 	}
    209 	if err := m.writer.WriteBits([]byte{val}, fi.size); err != nil {
    210 		return err
    211 	}
    212 	m.wbits += uint64(fi.size)
    213 	return nil
    214 }
    215 
    216 func (m *marshaller) marshalString(v reflect.Value) error {
    217 	data := []byte(v.String())
    218 	for _, b := range data {
    219 		if err := m.writer.WriteBits([]byte{b}, 8); err != nil {
    220 			return err
    221 		}
    222 		m.wbits += 8
    223 	}
    224 	// null character
    225 	if err := m.writer.WriteBits([]byte{0x00}, 8); err != nil {
    226 		return err
    227 	}
    228 	m.wbits += 8
    229 	return nil
    230 }
    231 
    232 func (m *marshaller) writeUvarint(u uint64) error {
    233 	for i := 21; i > 0; i -= 7 {
    234 		if err := m.writer.WriteBits([]byte{(byte(u >> uint(i))) | 0x80}, 8); err != nil {
    235 			return err
    236 		}
    237 		m.wbits += 8
    238 	}
    239 
    240 	if err := m.writer.WriteBits([]byte{byte(u) & 0x7f}, 8); err != nil {
    241 		return err
    242 	}
    243 	m.wbits += 8
    244 
    245 	return nil
    246 }
    247 
    248 type unmarshaller struct {
    249 	reader bitio.ReadSeeker
    250 	dst    IBox
    251 	size   uint64
    252 	rbits  uint64
    253 	ctx    Context
    254 }
    255 
    256 func UnmarshalAny(r io.ReadSeeker, boxType BoxType, payloadSize uint64, ctx Context) (box IBox, n uint64, err error) {
    257 	dst, err := boxType.New(ctx)
    258 	if err != nil {
    259 		return nil, 0, err
    260 	}
    261 	n, err = Unmarshal(r, payloadSize, dst, ctx)
    262 	return dst, n, err
    263 }
    264 
    265 func Unmarshal(r io.ReadSeeker, payloadSize uint64, dst IBox, ctx Context) (n uint64, err error) {
    266 	boxDef := dst.GetType().getBoxDef(ctx)
    267 	if boxDef == nil {
    268 		return 0, ErrBoxInfoNotFound
    269 	}
    270 
    271 	v := reflect.ValueOf(dst).Elem()
    272 
    273 	dst.SetVersion(anyVersion)
    274 
    275 	u := &unmarshaller{
    276 		reader: bitio.NewReadSeeker(r),
    277 		dst:    dst,
    278 		size:   payloadSize,
    279 		ctx:    ctx,
    280 	}
    281 
    282 	if n, override, err := dst.BeforeUnmarshal(r, payloadSize, u.ctx); err != nil {
    283 		return 0, err
    284 	} else if override {
    285 		return n, nil
    286 	} else {
    287 		u.rbits = n * 8
    288 	}
    289 
    290 	sn, err := r.Seek(0, io.SeekCurrent)
    291 	if err != nil {
    292 		return 0, err
    293 	}
    294 
    295 	if err := u.unmarshalStruct(v, boxDef.fields); err != nil {
    296 		if err == ErrUnsupportedBoxVersion {
    297 			r.Seek(sn, io.SeekStart)
    298 		}
    299 		return 0, err
    300 	}
    301 
    302 	if u.rbits%8 != 0 {
    303 		return 0, fmt.Errorf("box size is not multiple of 8 bits: type=%s, size=%d, bits=%d", dst.GetType().String(), u.size, u.rbits)
    304 	}
    305 
    306 	if u.rbits > u.size*8 {
    307 		return 0, fmt.Errorf("overrun error: type=%s, size=%d, bits=%d", dst.GetType().String(), u.size, u.rbits)
    308 	}
    309 
    310 	return u.rbits / 8, nil
    311 }
    312 
    313 func (u *unmarshaller) unmarshal(v reflect.Value, fi *fieldInstance) error {
    314 	var err error
    315 	switch v.Type().Kind() {
    316 	case reflect.Ptr:
    317 		err = u.unmarshalPtr(v, fi)
    318 	case reflect.Struct:
    319 		err = u.unmarshalStructInternal(v, fi)
    320 	case reflect.Array:
    321 		err = u.unmarshalArray(v, fi)
    322 	case reflect.Slice:
    323 		err = u.unmarshalSlice(v, fi)
    324 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    325 		err = u.unmarshalInt(v, fi)
    326 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    327 		err = u.unmarshalUint(v, fi)
    328 	case reflect.Bool:
    329 		err = u.unmarshalBool(v, fi)
    330 	case reflect.String:
    331 		err = u.unmarshalString(v, fi)
    332 	default:
    333 		return fmt.Errorf("unsupported type: %s", v.Type().Kind())
    334 	}
    335 	return err
    336 }
    337 
    338 func (u *unmarshaller) unmarshalPtr(v reflect.Value, fi *fieldInstance) error {
    339 	v.Set(reflect.New(v.Type().Elem()))
    340 	return u.unmarshal(v.Elem(), fi)
    341 }
    342 
    343 func (u *unmarshaller) unmarshalStructInternal(v reflect.Value, fi *fieldInstance) error {
    344 	if fi.size != 0 && fi.size%8 == 0 {
    345 		u2 := *u
    346 		u2.size = uint64(fi.size / 8)
    347 		u2.rbits = 0
    348 		if err := u2.unmarshalStruct(v, fi.children); err != nil {
    349 			return err
    350 		}
    351 		u.rbits += u2.rbits
    352 		if u2.rbits != uint64(fi.size) {
    353 			return errors.New("invalid alignment")
    354 		}
    355 		return nil
    356 	}
    357 
    358 	return u.unmarshalStruct(v, fi.children)
    359 }
    360 
    361 func (u *unmarshaller) unmarshalStruct(v reflect.Value, fs []*field) error {
    362 	for _, f := range fs {
    363 		fi := resolveFieldInstance(f, u.dst, v, u.ctx)
    364 
    365 		if !isTargetField(u.dst, fi, u.ctx) {
    366 			continue
    367 		}
    368 
    369 		rbits, override, err := fi.cfo.OnReadField(f.name, u.reader, u.size*8-u.rbits, u.ctx)
    370 		if err != nil {
    371 			return err
    372 		}
    373 		u.rbits += rbits
    374 		if override {
    375 			continue
    376 		}
    377 
    378 		err = u.unmarshal(v.FieldByName(f.name), fi)
    379 		if err != nil {
    380 			return err
    381 		}
    382 
    383 		if v.FieldByName(f.name).Type() == reflect.TypeOf(FullBox{}) && !u.dst.GetType().IsSupportedVersion(u.dst.GetVersion(), u.ctx) {
    384 			return ErrUnsupportedBoxVersion
    385 		}
    386 	}
    387 
    388 	return nil
    389 }
    390 
    391 func (u *unmarshaller) unmarshalArray(v reflect.Value, fi *fieldInstance) error {
    392 	size := v.Type().Size()
    393 	for i := 0; i < int(size)/int(v.Type().Elem().Size()); i++ {
    394 		var err error
    395 		err = u.unmarshal(v.Index(i), fi)
    396 		if err != nil {
    397 			return err
    398 		}
    399 	}
    400 	return nil
    401 }
    402 
    403 func (u *unmarshaller) unmarshalSlice(v reflect.Value, fi *fieldInstance) error {
    404 	var slice reflect.Value
    405 	elemType := v.Type().Elem()
    406 
    407 	length := uint64(fi.length)
    408 	if fi.length == LengthUnlimited {
    409 		if fi.size != 0 {
    410 			left := (u.size)*8 - u.rbits
    411 			if left%uint64(fi.size) != 0 {
    412 				return errors.New("invalid alignment")
    413 			}
    414 			length = left / uint64(fi.size)
    415 		} else {
    416 			length = 0
    417 		}
    418 	}
    419 
    420 	if length > math.MaxInt32 {
    421 		return fmt.Errorf("out of memory: requestedSize=%d", length)
    422 	}
    423 
    424 	if fi.size != 0 && fi.size%8 == 0 && u.rbits%8 == 0 && elemType.Kind() == reflect.Uint8 && fi.size == 8 {
    425 		totalSize := length * uint64(fi.size) / 8
    426 		buf := bytes.NewBuffer(make([]byte, 0, totalSize))
    427 		if _, err := io.CopyN(buf, u.reader, int64(totalSize)); err != nil {
    428 			return err
    429 		}
    430 		slice = reflect.ValueOf(buf.Bytes())
    431 		u.rbits += uint64(totalSize) * 8
    432 
    433 	} else {
    434 		slice = reflect.MakeSlice(v.Type(), 0, int(length))
    435 		for i := 0; ; i++ {
    436 			if fi.length != LengthUnlimited && uint(i) >= fi.length {
    437 				break
    438 			}
    439 			if fi.length == LengthUnlimited && u.rbits >= u.size*8 {
    440 				break
    441 			}
    442 			slice = reflect.Append(slice, reflect.Zero(elemType))
    443 			if err := u.unmarshal(slice.Index(i), fi); err != nil {
    444 				return err
    445 			}
    446 			if u.rbits > u.size*8 {
    447 				return fmt.Errorf("failed to read array completely: fieldName=\"%s\"", fi.name)
    448 			}
    449 		}
    450 	}
    451 
    452 	v.Set(slice)
    453 	return nil
    454 }
    455 
    456 func (u *unmarshaller) unmarshalInt(v reflect.Value, fi *fieldInstance) error {
    457 	if fi.is(fieldVarint) {
    458 		return errors.New("signed varint is unsupported")
    459 	}
    460 
    461 	if fi.size == 0 {
    462 		return fmt.Errorf("size must not be zero: %s", fi.name)
    463 	}
    464 
    465 	data, err := u.reader.ReadBits(fi.size)
    466 	if err != nil {
    467 		return err
    468 	}
    469 	u.rbits += uint64(fi.size)
    470 
    471 	signBit := false
    472 	if len(data) > 0 {
    473 		signMask := byte(0x01) << ((fi.size - 1) % 8)
    474 		signBit = data[0]&signMask != 0
    475 		if signBit {
    476 			data[0] |= ^(signMask - 1)
    477 		}
    478 	}
    479 
    480 	var val uint64
    481 	if signBit {
    482 		val = ^uint64(0)
    483 	}
    484 	for i := range data {
    485 		val <<= 8
    486 		val |= uint64(data[i])
    487 	}
    488 	v.SetInt(int64(val))
    489 	return nil
    490 }
    491 
    492 func (u *unmarshaller) unmarshalUint(v reflect.Value, fi *fieldInstance) error {
    493 	if fi.is(fieldVarint) {
    494 		val, err := u.readUvarint()
    495 		if err != nil {
    496 			return err
    497 		}
    498 		v.SetUint(val)
    499 		return nil
    500 	}
    501 
    502 	if fi.size == 0 {
    503 		return fmt.Errorf("size must not be zero: %s", fi.name)
    504 	}
    505 
    506 	data, err := u.reader.ReadBits(fi.size)
    507 	if err != nil {
    508 		return err
    509 	}
    510 	u.rbits += uint64(fi.size)
    511 
    512 	val := uint64(0)
    513 	for i := range data {
    514 		val <<= 8
    515 		val |= uint64(data[i])
    516 	}
    517 	v.SetUint(val)
    518 
    519 	return nil
    520 }
    521 
    522 func (u *unmarshaller) unmarshalBool(v reflect.Value, fi *fieldInstance) error {
    523 	if fi.size == 0 {
    524 		return fmt.Errorf("size must not be zero: %s", fi.name)
    525 	}
    526 
    527 	data, err := u.reader.ReadBits(fi.size)
    528 	if err != nil {
    529 		return err
    530 	}
    531 	u.rbits += uint64(fi.size)
    532 
    533 	val := false
    534 	for _, b := range data {
    535 		val = val || (b != byte(0))
    536 	}
    537 	v.SetBool(val)
    538 
    539 	return nil
    540 }
    541 
    542 func (u *unmarshaller) unmarshalString(v reflect.Value, fi *fieldInstance) error {
    543 	switch fi.strType {
    544 	case stringType_C:
    545 		return u.unmarshalStringC(v)
    546 	case stringType_C_P:
    547 		return u.unmarshalStringCP(v, fi)
    548 	default:
    549 		return fmt.Errorf("unknown string type: %d", fi.strType)
    550 	}
    551 }
    552 
    553 func (u *unmarshaller) unmarshalStringC(v reflect.Value) error {
    554 	data := make([]byte, 0, 16)
    555 	for {
    556 		if u.rbits >= u.size*8 {
    557 			break
    558 		}
    559 
    560 		c, err := u.reader.ReadBits(8)
    561 		if err != nil {
    562 			return err
    563 		}
    564 		u.rbits += 8
    565 
    566 		if c[0] == 0 {
    567 			break // null character
    568 		}
    569 
    570 		data = append(data, c[0])
    571 	}
    572 	v.SetString(string(data))
    573 
    574 	return nil
    575 }
    576 
    577 func (u *unmarshaller) unmarshalStringCP(v reflect.Value, fi *fieldInstance) error {
    578 	if ok, err := u.tryReadPString(v, fi); err != nil {
    579 		return err
    580 	} else if ok {
    581 		return nil
    582 	}
    583 	return u.unmarshalStringC(v)
    584 }
    585 
    586 func (u *unmarshaller) tryReadPString(v reflect.Value, fi *fieldInstance) (ok bool, err error) {
    587 	remainingSize := (u.size*8 - u.rbits) / 8
    588 	if remainingSize < 2 {
    589 		return false, nil
    590 	}
    591 
    592 	offset, err := u.reader.Seek(0, io.SeekCurrent)
    593 	if err != nil {
    594 		return false, err
    595 	}
    596 	defer func() {
    597 		if err == nil && !ok {
    598 			_, err = u.reader.Seek(offset, io.SeekStart)
    599 		}
    600 	}()
    601 
    602 	buf0 := make([]byte, 1)
    603 	if _, err := io.ReadFull(u.reader, buf0); err != nil {
    604 		return false, err
    605 	}
    606 	remainingSize--
    607 	plen := buf0[0]
    608 	if uint64(plen) > remainingSize {
    609 		return false, nil
    610 	}
    611 	buf := make([]byte, int(plen))
    612 	if _, err := io.ReadFull(u.reader, buf); err != nil {
    613 		return false, err
    614 	}
    615 	remainingSize -= uint64(plen)
    616 	if fi.cfo.IsPString(fi.name, buf, remainingSize, u.ctx) {
    617 		u.rbits += uint64(len(buf)+1) * 8
    618 		v.SetString(string(buf))
    619 		return true, nil
    620 	}
    621 	return false, nil
    622 }
    623 
    624 func (u *unmarshaller) readUvarint() (uint64, error) {
    625 	var val uint64
    626 	for {
    627 		octet, err := u.reader.ReadBits(8)
    628 		if err != nil {
    629 			return 0, err
    630 		}
    631 		u.rbits += 8
    632 
    633 		val = (val << 7) + uint64(octet[0]&0x7f)
    634 
    635 		if octet[0]&0x80 == 0 {
    636 			return val, nil
    637 		}
    638 	}
    639 }