marshaller.go (13800B)
1 package mp4 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 "math" 9 "reflect" 10 11 "github.com/abema/go-mp4/bitio" 12 ) 13 14 const ( 15 anyVersion = math.MaxUint8 16 ) 17 18 var ErrUnsupportedBoxVersion = errors.New("unsupported box version") 19 20 type marshaller struct { 21 writer bitio.Writer 22 wbits uint64 23 src IImmutableBox 24 ctx Context 25 } 26 27 func Marshal(w io.Writer, src IImmutableBox, ctx Context) (n uint64, err error) { 28 boxDef := src.GetType().getBoxDef(ctx) 29 if boxDef == nil { 30 return 0, ErrBoxInfoNotFound 31 } 32 33 v := reflect.ValueOf(src).Elem() 34 35 m := &marshaller{ 36 writer: bitio.NewWriter(w), 37 src: src, 38 ctx: ctx, 39 } 40 41 if err := m.marshalStruct(v, boxDef.fields); err != nil { 42 return 0, err 43 } 44 45 if m.wbits%8 != 0 { 46 return 0, fmt.Errorf("box size is not multiple of 8 bits: type=%s, bits=%d", src.GetType().String(), m.wbits) 47 } 48 49 return m.wbits / 8, nil 50 } 51 52 func (m *marshaller) marshal(v reflect.Value, fi *fieldInstance) error { 53 switch v.Type().Kind() { 54 case reflect.Ptr: 55 return m.marshalPtr(v, fi) 56 case reflect.Struct: 57 return m.marshalStruct(v, fi.children) 58 case reflect.Array: 59 return m.marshalArray(v, fi) 60 case reflect.Slice: 61 return m.marshalSlice(v, fi) 62 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 63 return m.marshalInt(v, fi) 64 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 65 return m.marshalUint(v, fi) 66 case reflect.Bool: 67 return m.marshalBool(v, fi) 68 case reflect.String: 69 return m.marshalString(v) 70 default: 71 return fmt.Errorf("unsupported type: %s", v.Type().Kind()) 72 } 73 } 74 75 func (m *marshaller) marshalPtr(v reflect.Value, fi *fieldInstance) error { 76 return m.marshal(v.Elem(), fi) 77 } 78 79 func (m *marshaller) marshalStruct(v reflect.Value, fs []*field) error { 80 for _, f := range fs { 81 fi := resolveFieldInstance(f, m.src, v, m.ctx) 82 83 if !isTargetField(m.src, fi, m.ctx) { 84 continue 85 } 86 87 wbits, override, err := fi.cfo.OnWriteField(f.name, m.writer, m.ctx) 88 if err != nil { 89 return err 90 } 91 m.wbits += wbits 92 if override { 93 continue 94 } 95 96 err = m.marshal(v.FieldByName(f.name), fi) 97 if err != nil { 98 return err 99 } 100 } 101 102 return nil 103 } 104 105 func (m *marshaller) marshalArray(v reflect.Value, fi *fieldInstance) error { 106 size := v.Type().Size() 107 for i := 0; i < int(size)/int(v.Type().Elem().Size()); i++ { 108 var err error 109 err = m.marshal(v.Index(i), fi) 110 if err != nil { 111 return err 112 } 113 } 114 return nil 115 } 116 117 func (m *marshaller) marshalSlice(v reflect.Value, fi *fieldInstance) error { 118 length := uint64(v.Len()) 119 if fi.length != LengthUnlimited { 120 if length < uint64(fi.length) { 121 return fmt.Errorf("the slice has too few elements: required=%d actual=%d", fi.length, length) 122 } 123 length = uint64(fi.length) 124 } 125 126 elemType := v.Type().Elem() 127 if elemType.Kind() == reflect.Uint8 && fi.size == 8 && m.wbits%8 == 0 { 128 if _, err := io.CopyN(m.writer, bytes.NewBuffer(v.Bytes()), int64(length)); err != nil { 129 return err 130 } 131 m.wbits += length * 8 132 return nil 133 } 134 135 for i := 0; i < int(length); i++ { 136 m.marshal(v.Index(i), fi) 137 } 138 return nil 139 } 140 141 func (m *marshaller) marshalInt(v reflect.Value, fi *fieldInstance) error { 142 signed := v.Int() 143 144 if fi.is(fieldVarint) { 145 return errors.New("signed varint is unsupported") 146 } 147 148 signBit := signed < 0 149 val := uint64(signed) 150 for i := uint(0); i < fi.size; i += 8 { 151 v := val 152 size := uint(8) 153 if fi.size > i+8 { 154 v = v >> (fi.size - (i + 8)) 155 } else if fi.size < i+8 { 156 size = fi.size - i 157 } 158 159 // set sign bit 160 if i == 0 { 161 if signBit { 162 v |= 0x1 << (size - 1) 163 } else { 164 v &= 0x1<<(size-1) - 1 165 } 166 } 167 168 if err := m.writer.WriteBits([]byte{byte(v)}, size); err != nil { 169 return err 170 } 171 m.wbits += uint64(size) 172 } 173 174 return nil 175 } 176 177 func (m *marshaller) marshalUint(v reflect.Value, fi *fieldInstance) error { 178 val := v.Uint() 179 180 if fi.is(fieldVarint) { 181 m.writeUvarint(val) 182 return nil 183 } 184 185 for i := uint(0); i < fi.size; i += 8 { 186 v := val 187 size := uint(8) 188 if fi.size > i+8 { 189 v = v >> (fi.size - (i + 8)) 190 } else if fi.size < i+8 { 191 size = fi.size - i 192 } 193 if err := m.writer.WriteBits([]byte{byte(v)}, size); err != nil { 194 return err 195 } 196 m.wbits += uint64(size) 197 } 198 199 return nil 200 } 201 202 func (m *marshaller) marshalBool(v reflect.Value, fi *fieldInstance) error { 203 var val byte 204 if v.Bool() { 205 val = 0xff 206 } else { 207 val = 0x00 208 } 209 if err := m.writer.WriteBits([]byte{val}, fi.size); err != nil { 210 return err 211 } 212 m.wbits += uint64(fi.size) 213 return nil 214 } 215 216 func (m *marshaller) marshalString(v reflect.Value) error { 217 data := []byte(v.String()) 218 for _, b := range data { 219 if err := m.writer.WriteBits([]byte{b}, 8); err != nil { 220 return err 221 } 222 m.wbits += 8 223 } 224 // null character 225 if err := m.writer.WriteBits([]byte{0x00}, 8); err != nil { 226 return err 227 } 228 m.wbits += 8 229 return nil 230 } 231 232 func (m *marshaller) writeUvarint(u uint64) error { 233 for i := 21; i > 0; i -= 7 { 234 if err := m.writer.WriteBits([]byte{(byte(u >> uint(i))) | 0x80}, 8); err != nil { 235 return err 236 } 237 m.wbits += 8 238 } 239 240 if err := m.writer.WriteBits([]byte{byte(u) & 0x7f}, 8); err != nil { 241 return err 242 } 243 m.wbits += 8 244 245 return nil 246 } 247 248 type unmarshaller struct { 249 reader bitio.ReadSeeker 250 dst IBox 251 size uint64 252 rbits uint64 253 ctx Context 254 } 255 256 func UnmarshalAny(r io.ReadSeeker, boxType BoxType, payloadSize uint64, ctx Context) (box IBox, n uint64, err error) { 257 dst, err := boxType.New(ctx) 258 if err != nil { 259 return nil, 0, err 260 } 261 n, err = Unmarshal(r, payloadSize, dst, ctx) 262 return dst, n, err 263 } 264 265 func Unmarshal(r io.ReadSeeker, payloadSize uint64, dst IBox, ctx Context) (n uint64, err error) { 266 boxDef := dst.GetType().getBoxDef(ctx) 267 if boxDef == nil { 268 return 0, ErrBoxInfoNotFound 269 } 270 271 v := reflect.ValueOf(dst).Elem() 272 273 dst.SetVersion(anyVersion) 274 275 u := &unmarshaller{ 276 reader: bitio.NewReadSeeker(r), 277 dst: dst, 278 size: payloadSize, 279 ctx: ctx, 280 } 281 282 if n, override, err := dst.BeforeUnmarshal(r, payloadSize, u.ctx); err != nil { 283 return 0, err 284 } else if override { 285 return n, nil 286 } else { 287 u.rbits = n * 8 288 } 289 290 sn, err := r.Seek(0, io.SeekCurrent) 291 if err != nil { 292 return 0, err 293 } 294 295 if err := u.unmarshalStruct(v, boxDef.fields); err != nil { 296 if err == ErrUnsupportedBoxVersion { 297 r.Seek(sn, io.SeekStart) 298 } 299 return 0, err 300 } 301 302 if u.rbits%8 != 0 { 303 return 0, fmt.Errorf("box size is not multiple of 8 bits: type=%s, size=%d, bits=%d", dst.GetType().String(), u.size, u.rbits) 304 } 305 306 if u.rbits > u.size*8 { 307 return 0, fmt.Errorf("overrun error: type=%s, size=%d, bits=%d", dst.GetType().String(), u.size, u.rbits) 308 } 309 310 return u.rbits / 8, nil 311 } 312 313 func (u *unmarshaller) unmarshal(v reflect.Value, fi *fieldInstance) error { 314 var err error 315 switch v.Type().Kind() { 316 case reflect.Ptr: 317 err = u.unmarshalPtr(v, fi) 318 case reflect.Struct: 319 err = u.unmarshalStructInternal(v, fi) 320 case reflect.Array: 321 err = u.unmarshalArray(v, fi) 322 case reflect.Slice: 323 err = u.unmarshalSlice(v, fi) 324 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 325 err = u.unmarshalInt(v, fi) 326 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 327 err = u.unmarshalUint(v, fi) 328 case reflect.Bool: 329 err = u.unmarshalBool(v, fi) 330 case reflect.String: 331 err = u.unmarshalString(v, fi) 332 default: 333 return fmt.Errorf("unsupported type: %s", v.Type().Kind()) 334 } 335 return err 336 } 337 338 func (u *unmarshaller) unmarshalPtr(v reflect.Value, fi *fieldInstance) error { 339 v.Set(reflect.New(v.Type().Elem())) 340 return u.unmarshal(v.Elem(), fi) 341 } 342 343 func (u *unmarshaller) unmarshalStructInternal(v reflect.Value, fi *fieldInstance) error { 344 if fi.size != 0 && fi.size%8 == 0 { 345 u2 := *u 346 u2.size = uint64(fi.size / 8) 347 u2.rbits = 0 348 if err := u2.unmarshalStruct(v, fi.children); err != nil { 349 return err 350 } 351 u.rbits += u2.rbits 352 if u2.rbits != uint64(fi.size) { 353 return errors.New("invalid alignment") 354 } 355 return nil 356 } 357 358 return u.unmarshalStruct(v, fi.children) 359 } 360 361 func (u *unmarshaller) unmarshalStruct(v reflect.Value, fs []*field) error { 362 for _, f := range fs { 363 fi := resolveFieldInstance(f, u.dst, v, u.ctx) 364 365 if !isTargetField(u.dst, fi, u.ctx) { 366 continue 367 } 368 369 rbits, override, err := fi.cfo.OnReadField(f.name, u.reader, u.size*8-u.rbits, u.ctx) 370 if err != nil { 371 return err 372 } 373 u.rbits += rbits 374 if override { 375 continue 376 } 377 378 err = u.unmarshal(v.FieldByName(f.name), fi) 379 if err != nil { 380 return err 381 } 382 383 if v.FieldByName(f.name).Type() == reflect.TypeOf(FullBox{}) && !u.dst.GetType().IsSupportedVersion(u.dst.GetVersion(), u.ctx) { 384 return ErrUnsupportedBoxVersion 385 } 386 } 387 388 return nil 389 } 390 391 func (u *unmarshaller) unmarshalArray(v reflect.Value, fi *fieldInstance) error { 392 size := v.Type().Size() 393 for i := 0; i < int(size)/int(v.Type().Elem().Size()); i++ { 394 var err error 395 err = u.unmarshal(v.Index(i), fi) 396 if err != nil { 397 return err 398 } 399 } 400 return nil 401 } 402 403 func (u *unmarshaller) unmarshalSlice(v reflect.Value, fi *fieldInstance) error { 404 var slice reflect.Value 405 elemType := v.Type().Elem() 406 407 length := uint64(fi.length) 408 if fi.length == LengthUnlimited { 409 if fi.size != 0 { 410 left := (u.size)*8 - u.rbits 411 if left%uint64(fi.size) != 0 { 412 return errors.New("invalid alignment") 413 } 414 length = left / uint64(fi.size) 415 } else { 416 length = 0 417 } 418 } 419 420 if length > math.MaxInt32 { 421 return fmt.Errorf("out of memory: requestedSize=%d", length) 422 } 423 424 if fi.size != 0 && fi.size%8 == 0 && u.rbits%8 == 0 && elemType.Kind() == reflect.Uint8 && fi.size == 8 { 425 totalSize := length * uint64(fi.size) / 8 426 buf := bytes.NewBuffer(make([]byte, 0, totalSize)) 427 if _, err := io.CopyN(buf, u.reader, int64(totalSize)); err != nil { 428 return err 429 } 430 slice = reflect.ValueOf(buf.Bytes()) 431 u.rbits += uint64(totalSize) * 8 432 433 } else { 434 slice = reflect.MakeSlice(v.Type(), 0, int(length)) 435 for i := 0; ; i++ { 436 if fi.length != LengthUnlimited && uint(i) >= fi.length { 437 break 438 } 439 if fi.length == LengthUnlimited && u.rbits >= u.size*8 { 440 break 441 } 442 slice = reflect.Append(slice, reflect.Zero(elemType)) 443 if err := u.unmarshal(slice.Index(i), fi); err != nil { 444 return err 445 } 446 if u.rbits > u.size*8 { 447 return fmt.Errorf("failed to read array completely: fieldName=\"%s\"", fi.name) 448 } 449 } 450 } 451 452 v.Set(slice) 453 return nil 454 } 455 456 func (u *unmarshaller) unmarshalInt(v reflect.Value, fi *fieldInstance) error { 457 if fi.is(fieldVarint) { 458 return errors.New("signed varint is unsupported") 459 } 460 461 if fi.size == 0 { 462 return fmt.Errorf("size must not be zero: %s", fi.name) 463 } 464 465 data, err := u.reader.ReadBits(fi.size) 466 if err != nil { 467 return err 468 } 469 u.rbits += uint64(fi.size) 470 471 signBit := false 472 if len(data) > 0 { 473 signMask := byte(0x01) << ((fi.size - 1) % 8) 474 signBit = data[0]&signMask != 0 475 if signBit { 476 data[0] |= ^(signMask - 1) 477 } 478 } 479 480 var val uint64 481 if signBit { 482 val = ^uint64(0) 483 } 484 for i := range data { 485 val <<= 8 486 val |= uint64(data[i]) 487 } 488 v.SetInt(int64(val)) 489 return nil 490 } 491 492 func (u *unmarshaller) unmarshalUint(v reflect.Value, fi *fieldInstance) error { 493 if fi.is(fieldVarint) { 494 val, err := u.readUvarint() 495 if err != nil { 496 return err 497 } 498 v.SetUint(val) 499 return nil 500 } 501 502 if fi.size == 0 { 503 return fmt.Errorf("size must not be zero: %s", fi.name) 504 } 505 506 data, err := u.reader.ReadBits(fi.size) 507 if err != nil { 508 return err 509 } 510 u.rbits += uint64(fi.size) 511 512 val := uint64(0) 513 for i := range data { 514 val <<= 8 515 val |= uint64(data[i]) 516 } 517 v.SetUint(val) 518 519 return nil 520 } 521 522 func (u *unmarshaller) unmarshalBool(v reflect.Value, fi *fieldInstance) error { 523 if fi.size == 0 { 524 return fmt.Errorf("size must not be zero: %s", fi.name) 525 } 526 527 data, err := u.reader.ReadBits(fi.size) 528 if err != nil { 529 return err 530 } 531 u.rbits += uint64(fi.size) 532 533 val := false 534 for _, b := range data { 535 val = val || (b != byte(0)) 536 } 537 v.SetBool(val) 538 539 return nil 540 } 541 542 func (u *unmarshaller) unmarshalString(v reflect.Value, fi *fieldInstance) error { 543 switch fi.strType { 544 case stringType_C: 545 return u.unmarshalStringC(v) 546 case stringType_C_P: 547 return u.unmarshalStringCP(v, fi) 548 default: 549 return fmt.Errorf("unknown string type: %d", fi.strType) 550 } 551 } 552 553 func (u *unmarshaller) unmarshalStringC(v reflect.Value) error { 554 data := make([]byte, 0, 16) 555 for { 556 if u.rbits >= u.size*8 { 557 break 558 } 559 560 c, err := u.reader.ReadBits(8) 561 if err != nil { 562 return err 563 } 564 u.rbits += 8 565 566 if c[0] == 0 { 567 break // null character 568 } 569 570 data = append(data, c[0]) 571 } 572 v.SetString(string(data)) 573 574 return nil 575 } 576 577 func (u *unmarshaller) unmarshalStringCP(v reflect.Value, fi *fieldInstance) error { 578 if ok, err := u.tryReadPString(v, fi); err != nil { 579 return err 580 } else if ok { 581 return nil 582 } 583 return u.unmarshalStringC(v) 584 } 585 586 func (u *unmarshaller) tryReadPString(v reflect.Value, fi *fieldInstance) (ok bool, err error) { 587 remainingSize := (u.size*8 - u.rbits) / 8 588 if remainingSize < 2 { 589 return false, nil 590 } 591 592 offset, err := u.reader.Seek(0, io.SeekCurrent) 593 if err != nil { 594 return false, err 595 } 596 defer func() { 597 if err == nil && !ok { 598 _, err = u.reader.Seek(offset, io.SeekStart) 599 } 600 }() 601 602 buf0 := make([]byte, 1) 603 if _, err := io.ReadFull(u.reader, buf0); err != nil { 604 return false, err 605 } 606 remainingSize-- 607 plen := buf0[0] 608 if uint64(plen) > remainingSize { 609 return false, nil 610 } 611 buf := make([]byte, int(plen)) 612 if _, err := io.ReadFull(u.reader, buf); err != nil { 613 return false, err 614 } 615 remainingSize -= uint64(plen) 616 if fi.cfo.IsPString(fi.name, buf, remainingSize, u.ctx) { 617 u.rbits += uint64(len(buf)+1) * 8 618 v.SetString(string(buf)) 619 return true, nil 620 } 621 return false, nil 622 } 623 624 func (u *unmarshaller) readUvarint() (uint64, error) { 625 var val uint64 626 for { 627 octet, err := u.reader.ReadBits(8) 628 if err != nil { 629 return 0, err 630 } 631 u.rbits += 8 632 633 val = (val << 7) + uint64(octet[0]&0x7f) 634 635 if octet[0]&0x80 == 0 { 636 return val, nil 637 } 638 } 639 }