decode.go (17716B)
1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package protojson 6 7 import ( 8 "encoding/base64" 9 "fmt" 10 "math" 11 "strconv" 12 "strings" 13 14 "google.golang.org/protobuf/internal/encoding/json" 15 "google.golang.org/protobuf/internal/encoding/messageset" 16 "google.golang.org/protobuf/internal/errors" 17 "google.golang.org/protobuf/internal/flags" 18 "google.golang.org/protobuf/internal/genid" 19 "google.golang.org/protobuf/internal/pragma" 20 "google.golang.org/protobuf/internal/set" 21 "google.golang.org/protobuf/proto" 22 "google.golang.org/protobuf/reflect/protoreflect" 23 "google.golang.org/protobuf/reflect/protoregistry" 24 ) 25 26 // Unmarshal reads the given []byte into the given proto.Message. 27 // The provided message must be mutable (e.g., a non-nil pointer to a message). 28 func Unmarshal(b []byte, m proto.Message) error { 29 return UnmarshalOptions{}.Unmarshal(b, m) 30 } 31 32 // UnmarshalOptions is a configurable JSON format parser. 33 type UnmarshalOptions struct { 34 pragma.NoUnkeyedLiterals 35 36 // If AllowPartial is set, input for messages that will result in missing 37 // required fields will not return an error. 38 AllowPartial bool 39 40 // If DiscardUnknown is set, unknown fields are ignored. 41 DiscardUnknown bool 42 43 // Resolver is used for looking up types when unmarshaling 44 // google.protobuf.Any messages or extension fields. 45 // If nil, this defaults to using protoregistry.GlobalTypes. 46 Resolver interface { 47 protoregistry.MessageTypeResolver 48 protoregistry.ExtensionTypeResolver 49 } 50 } 51 52 // Unmarshal reads the given []byte and populates the given proto.Message 53 // using options in the UnmarshalOptions object. 54 // It will clear the message first before setting the fields. 55 // If it returns an error, the given message may be partially set. 56 // The provided message must be mutable (e.g., a non-nil pointer to a message). 57 func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error { 58 return o.unmarshal(b, m) 59 } 60 61 // unmarshal is a centralized function that all unmarshal operations go through. 62 // For profiling purposes, avoid changing the name of this function or 63 // introducing other code paths for unmarshal that do not go through this. 64 func (o UnmarshalOptions) unmarshal(b []byte, m proto.Message) error { 65 proto.Reset(m) 66 67 if o.Resolver == nil { 68 o.Resolver = protoregistry.GlobalTypes 69 } 70 71 dec := decoder{json.NewDecoder(b), o} 72 if err := dec.unmarshalMessage(m.ProtoReflect(), false); err != nil { 73 return err 74 } 75 76 // Check for EOF. 77 tok, err := dec.Read() 78 if err != nil { 79 return err 80 } 81 if tok.Kind() != json.EOF { 82 return dec.unexpectedTokenError(tok) 83 } 84 85 if o.AllowPartial { 86 return nil 87 } 88 return proto.CheckInitialized(m) 89 } 90 91 type decoder struct { 92 *json.Decoder 93 opts UnmarshalOptions 94 } 95 96 // newError returns an error object with position info. 97 func (d decoder) newError(pos int, f string, x ...interface{}) error { 98 line, column := d.Position(pos) 99 head := fmt.Sprintf("(line %d:%d): ", line, column) 100 return errors.New(head+f, x...) 101 } 102 103 // unexpectedTokenError returns a syntax error for the given unexpected token. 104 func (d decoder) unexpectedTokenError(tok json.Token) error { 105 return d.syntaxError(tok.Pos(), "unexpected token %s", tok.RawString()) 106 } 107 108 // syntaxError returns a syntax error for given position. 109 func (d decoder) syntaxError(pos int, f string, x ...interface{}) error { 110 line, column := d.Position(pos) 111 head := fmt.Sprintf("syntax error (line %d:%d): ", line, column) 112 return errors.New(head+f, x...) 113 } 114 115 // unmarshalMessage unmarshals a message into the given protoreflect.Message. 116 func (d decoder) unmarshalMessage(m protoreflect.Message, skipTypeURL bool) error { 117 if unmarshal := wellKnownTypeUnmarshaler(m.Descriptor().FullName()); unmarshal != nil { 118 return unmarshal(d, m) 119 } 120 121 tok, err := d.Read() 122 if err != nil { 123 return err 124 } 125 if tok.Kind() != json.ObjectOpen { 126 return d.unexpectedTokenError(tok) 127 } 128 129 messageDesc := m.Descriptor() 130 if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) { 131 return errors.New("no support for proto1 MessageSets") 132 } 133 134 var seenNums set.Ints 135 var seenOneofs set.Ints 136 fieldDescs := messageDesc.Fields() 137 for { 138 // Read field name. 139 tok, err := d.Read() 140 if err != nil { 141 return err 142 } 143 switch tok.Kind() { 144 default: 145 return d.unexpectedTokenError(tok) 146 case json.ObjectClose: 147 return nil 148 case json.Name: 149 // Continue below. 150 } 151 152 name := tok.Name() 153 // Unmarshaling a non-custom embedded message in Any will contain the 154 // JSON field "@type" which should be skipped because it is not a field 155 // of the embedded message, but simply an artifact of the Any format. 156 if skipTypeURL && name == "@type" { 157 d.Read() 158 continue 159 } 160 161 // Get the FieldDescriptor. 162 var fd protoreflect.FieldDescriptor 163 if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") { 164 // Only extension names are in [name] format. 165 extName := protoreflect.FullName(name[1 : len(name)-1]) 166 extType, err := d.opts.Resolver.FindExtensionByName(extName) 167 if err != nil && err != protoregistry.NotFound { 168 return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err) 169 } 170 if extType != nil { 171 fd = extType.TypeDescriptor() 172 if !messageDesc.ExtensionRanges().Has(fd.Number()) || fd.ContainingMessage().FullName() != messageDesc.FullName() { 173 return d.newError(tok.Pos(), "message %v cannot be extended by %v", messageDesc.FullName(), fd.FullName()) 174 } 175 } 176 } else { 177 // The name can either be the JSON name or the proto field name. 178 fd = fieldDescs.ByJSONName(name) 179 if fd == nil { 180 fd = fieldDescs.ByTextName(name) 181 } 182 } 183 if flags.ProtoLegacy { 184 if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() { 185 fd = nil // reset since the weak reference is not linked in 186 } 187 } 188 189 if fd == nil { 190 // Field is unknown. 191 if d.opts.DiscardUnknown { 192 if err := d.skipJSONValue(); err != nil { 193 return err 194 } 195 continue 196 } 197 return d.newError(tok.Pos(), "unknown field %v", tok.RawString()) 198 } 199 200 // Do not allow duplicate fields. 201 num := uint64(fd.Number()) 202 if seenNums.Has(num) { 203 return d.newError(tok.Pos(), "duplicate field %v", tok.RawString()) 204 } 205 seenNums.Set(num) 206 207 // No need to set values for JSON null unless the field type is 208 // google.protobuf.Value or google.protobuf.NullValue. 209 if tok, _ := d.Peek(); tok.Kind() == json.Null && !isKnownValue(fd) && !isNullValue(fd) { 210 d.Read() 211 continue 212 } 213 214 switch { 215 case fd.IsList(): 216 list := m.Mutable(fd).List() 217 if err := d.unmarshalList(list, fd); err != nil { 218 return err 219 } 220 case fd.IsMap(): 221 mmap := m.Mutable(fd).Map() 222 if err := d.unmarshalMap(mmap, fd); err != nil { 223 return err 224 } 225 default: 226 // If field is a oneof, check if it has already been set. 227 if od := fd.ContainingOneof(); od != nil { 228 idx := uint64(od.Index()) 229 if seenOneofs.Has(idx) { 230 return d.newError(tok.Pos(), "error parsing %s, oneof %v is already set", tok.RawString(), od.FullName()) 231 } 232 seenOneofs.Set(idx) 233 } 234 235 // Required or optional fields. 236 if err := d.unmarshalSingular(m, fd); err != nil { 237 return err 238 } 239 } 240 } 241 } 242 243 func isKnownValue(fd protoreflect.FieldDescriptor) bool { 244 md := fd.Message() 245 return md != nil && md.FullName() == genid.Value_message_fullname 246 } 247 248 func isNullValue(fd protoreflect.FieldDescriptor) bool { 249 ed := fd.Enum() 250 return ed != nil && ed.FullName() == genid.NullValue_enum_fullname 251 } 252 253 // unmarshalSingular unmarshals to the non-repeated field specified 254 // by the given FieldDescriptor. 255 func (d decoder) unmarshalSingular(m protoreflect.Message, fd protoreflect.FieldDescriptor) error { 256 var val protoreflect.Value 257 var err error 258 switch fd.Kind() { 259 case protoreflect.MessageKind, protoreflect.GroupKind: 260 val = m.NewField(fd) 261 err = d.unmarshalMessage(val.Message(), false) 262 default: 263 val, err = d.unmarshalScalar(fd) 264 } 265 266 if err != nil { 267 return err 268 } 269 m.Set(fd, val) 270 return nil 271 } 272 273 // unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by 274 // the given FieldDescriptor. 275 func (d decoder) unmarshalScalar(fd protoreflect.FieldDescriptor) (protoreflect.Value, error) { 276 const b32 int = 32 277 const b64 int = 64 278 279 tok, err := d.Read() 280 if err != nil { 281 return protoreflect.Value{}, err 282 } 283 284 kind := fd.Kind() 285 switch kind { 286 case protoreflect.BoolKind: 287 if tok.Kind() == json.Bool { 288 return protoreflect.ValueOfBool(tok.Bool()), nil 289 } 290 291 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 292 if v, ok := unmarshalInt(tok, b32); ok { 293 return v, nil 294 } 295 296 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 297 if v, ok := unmarshalInt(tok, b64); ok { 298 return v, nil 299 } 300 301 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 302 if v, ok := unmarshalUint(tok, b32); ok { 303 return v, nil 304 } 305 306 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 307 if v, ok := unmarshalUint(tok, b64); ok { 308 return v, nil 309 } 310 311 case protoreflect.FloatKind: 312 if v, ok := unmarshalFloat(tok, b32); ok { 313 return v, nil 314 } 315 316 case protoreflect.DoubleKind: 317 if v, ok := unmarshalFloat(tok, b64); ok { 318 return v, nil 319 } 320 321 case protoreflect.StringKind: 322 if tok.Kind() == json.String { 323 return protoreflect.ValueOfString(tok.ParsedString()), nil 324 } 325 326 case protoreflect.BytesKind: 327 if v, ok := unmarshalBytes(tok); ok { 328 return v, nil 329 } 330 331 case protoreflect.EnumKind: 332 if v, ok := unmarshalEnum(tok, fd); ok { 333 return v, nil 334 } 335 336 default: 337 panic(fmt.Sprintf("unmarshalScalar: invalid scalar kind %v", kind)) 338 } 339 340 return protoreflect.Value{}, d.newError(tok.Pos(), "invalid value for %v type: %v", kind, tok.RawString()) 341 } 342 343 func unmarshalInt(tok json.Token, bitSize int) (protoreflect.Value, bool) { 344 switch tok.Kind() { 345 case json.Number: 346 return getInt(tok, bitSize) 347 348 case json.String: 349 // Decode number from string. 350 s := strings.TrimSpace(tok.ParsedString()) 351 if len(s) != len(tok.ParsedString()) { 352 return protoreflect.Value{}, false 353 } 354 dec := json.NewDecoder([]byte(s)) 355 tok, err := dec.Read() 356 if err != nil { 357 return protoreflect.Value{}, false 358 } 359 return getInt(tok, bitSize) 360 } 361 return protoreflect.Value{}, false 362 } 363 364 func getInt(tok json.Token, bitSize int) (protoreflect.Value, bool) { 365 n, ok := tok.Int(bitSize) 366 if !ok { 367 return protoreflect.Value{}, false 368 } 369 if bitSize == 32 { 370 return protoreflect.ValueOfInt32(int32(n)), true 371 } 372 return protoreflect.ValueOfInt64(n), true 373 } 374 375 func unmarshalUint(tok json.Token, bitSize int) (protoreflect.Value, bool) { 376 switch tok.Kind() { 377 case json.Number: 378 return getUint(tok, bitSize) 379 380 case json.String: 381 // Decode number from string. 382 s := strings.TrimSpace(tok.ParsedString()) 383 if len(s) != len(tok.ParsedString()) { 384 return protoreflect.Value{}, false 385 } 386 dec := json.NewDecoder([]byte(s)) 387 tok, err := dec.Read() 388 if err != nil { 389 return protoreflect.Value{}, false 390 } 391 return getUint(tok, bitSize) 392 } 393 return protoreflect.Value{}, false 394 } 395 396 func getUint(tok json.Token, bitSize int) (protoreflect.Value, bool) { 397 n, ok := tok.Uint(bitSize) 398 if !ok { 399 return protoreflect.Value{}, false 400 } 401 if bitSize == 32 { 402 return protoreflect.ValueOfUint32(uint32(n)), true 403 } 404 return protoreflect.ValueOfUint64(n), true 405 } 406 407 func unmarshalFloat(tok json.Token, bitSize int) (protoreflect.Value, bool) { 408 switch tok.Kind() { 409 case json.Number: 410 return getFloat(tok, bitSize) 411 412 case json.String: 413 s := tok.ParsedString() 414 switch s { 415 case "NaN": 416 if bitSize == 32 { 417 return protoreflect.ValueOfFloat32(float32(math.NaN())), true 418 } 419 return protoreflect.ValueOfFloat64(math.NaN()), true 420 case "Infinity": 421 if bitSize == 32 { 422 return protoreflect.ValueOfFloat32(float32(math.Inf(+1))), true 423 } 424 return protoreflect.ValueOfFloat64(math.Inf(+1)), true 425 case "-Infinity": 426 if bitSize == 32 { 427 return protoreflect.ValueOfFloat32(float32(math.Inf(-1))), true 428 } 429 return protoreflect.ValueOfFloat64(math.Inf(-1)), true 430 } 431 432 // Decode number from string. 433 if len(s) != len(strings.TrimSpace(s)) { 434 return protoreflect.Value{}, false 435 } 436 dec := json.NewDecoder([]byte(s)) 437 tok, err := dec.Read() 438 if err != nil { 439 return protoreflect.Value{}, false 440 } 441 return getFloat(tok, bitSize) 442 } 443 return protoreflect.Value{}, false 444 } 445 446 func getFloat(tok json.Token, bitSize int) (protoreflect.Value, bool) { 447 n, ok := tok.Float(bitSize) 448 if !ok { 449 return protoreflect.Value{}, false 450 } 451 if bitSize == 32 { 452 return protoreflect.ValueOfFloat32(float32(n)), true 453 } 454 return protoreflect.ValueOfFloat64(n), true 455 } 456 457 func unmarshalBytes(tok json.Token) (protoreflect.Value, bool) { 458 if tok.Kind() != json.String { 459 return protoreflect.Value{}, false 460 } 461 462 s := tok.ParsedString() 463 enc := base64.StdEncoding 464 if strings.ContainsAny(s, "-_") { 465 enc = base64.URLEncoding 466 } 467 if len(s)%4 != 0 { 468 enc = enc.WithPadding(base64.NoPadding) 469 } 470 b, err := enc.DecodeString(s) 471 if err != nil { 472 return protoreflect.Value{}, false 473 } 474 return protoreflect.ValueOfBytes(b), true 475 } 476 477 func unmarshalEnum(tok json.Token, fd protoreflect.FieldDescriptor) (protoreflect.Value, bool) { 478 switch tok.Kind() { 479 case json.String: 480 // Lookup EnumNumber based on name. 481 s := tok.ParsedString() 482 if enumVal := fd.Enum().Values().ByName(protoreflect.Name(s)); enumVal != nil { 483 return protoreflect.ValueOfEnum(enumVal.Number()), true 484 } 485 486 case json.Number: 487 if n, ok := tok.Int(32); ok { 488 return protoreflect.ValueOfEnum(protoreflect.EnumNumber(n)), true 489 } 490 491 case json.Null: 492 // This is only valid for google.protobuf.NullValue. 493 if isNullValue(fd) { 494 return protoreflect.ValueOfEnum(0), true 495 } 496 } 497 498 return protoreflect.Value{}, false 499 } 500 501 func (d decoder) unmarshalList(list protoreflect.List, fd protoreflect.FieldDescriptor) error { 502 tok, err := d.Read() 503 if err != nil { 504 return err 505 } 506 if tok.Kind() != json.ArrayOpen { 507 return d.unexpectedTokenError(tok) 508 } 509 510 switch fd.Kind() { 511 case protoreflect.MessageKind, protoreflect.GroupKind: 512 for { 513 tok, err := d.Peek() 514 if err != nil { 515 return err 516 } 517 518 if tok.Kind() == json.ArrayClose { 519 d.Read() 520 return nil 521 } 522 523 val := list.NewElement() 524 if err := d.unmarshalMessage(val.Message(), false); err != nil { 525 return err 526 } 527 list.Append(val) 528 } 529 default: 530 for { 531 tok, err := d.Peek() 532 if err != nil { 533 return err 534 } 535 536 if tok.Kind() == json.ArrayClose { 537 d.Read() 538 return nil 539 } 540 541 val, err := d.unmarshalScalar(fd) 542 if err != nil { 543 return err 544 } 545 list.Append(val) 546 } 547 } 548 549 return nil 550 } 551 552 func (d decoder) unmarshalMap(mmap protoreflect.Map, fd protoreflect.FieldDescriptor) error { 553 tok, err := d.Read() 554 if err != nil { 555 return err 556 } 557 if tok.Kind() != json.ObjectOpen { 558 return d.unexpectedTokenError(tok) 559 } 560 561 // Determine ahead whether map entry is a scalar type or a message type in 562 // order to call the appropriate unmarshalMapValue func inside the for loop 563 // below. 564 var unmarshalMapValue func() (protoreflect.Value, error) 565 switch fd.MapValue().Kind() { 566 case protoreflect.MessageKind, protoreflect.GroupKind: 567 unmarshalMapValue = func() (protoreflect.Value, error) { 568 val := mmap.NewValue() 569 if err := d.unmarshalMessage(val.Message(), false); err != nil { 570 return protoreflect.Value{}, err 571 } 572 return val, nil 573 } 574 default: 575 unmarshalMapValue = func() (protoreflect.Value, error) { 576 return d.unmarshalScalar(fd.MapValue()) 577 } 578 } 579 580 Loop: 581 for { 582 // Read field name. 583 tok, err := d.Read() 584 if err != nil { 585 return err 586 } 587 switch tok.Kind() { 588 default: 589 return d.unexpectedTokenError(tok) 590 case json.ObjectClose: 591 break Loop 592 case json.Name: 593 // Continue. 594 } 595 596 // Unmarshal field name. 597 pkey, err := d.unmarshalMapKey(tok, fd.MapKey()) 598 if err != nil { 599 return err 600 } 601 602 // Check for duplicate field name. 603 if mmap.Has(pkey) { 604 return d.newError(tok.Pos(), "duplicate map key %v", tok.RawString()) 605 } 606 607 // Read and unmarshal field value. 608 pval, err := unmarshalMapValue() 609 if err != nil { 610 return err 611 } 612 613 mmap.Set(pkey, pval) 614 } 615 616 return nil 617 } 618 619 // unmarshalMapKey converts given token of Name kind into a protoreflect.MapKey. 620 // A map key type is any integral or string type. 621 func (d decoder) unmarshalMapKey(tok json.Token, fd protoreflect.FieldDescriptor) (protoreflect.MapKey, error) { 622 const b32 = 32 623 const b64 = 64 624 const base10 = 10 625 626 name := tok.Name() 627 kind := fd.Kind() 628 switch kind { 629 case protoreflect.StringKind: 630 return protoreflect.ValueOfString(name).MapKey(), nil 631 632 case protoreflect.BoolKind: 633 switch name { 634 case "true": 635 return protoreflect.ValueOfBool(true).MapKey(), nil 636 case "false": 637 return protoreflect.ValueOfBool(false).MapKey(), nil 638 } 639 640 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 641 if n, err := strconv.ParseInt(name, base10, b32); err == nil { 642 return protoreflect.ValueOfInt32(int32(n)).MapKey(), nil 643 } 644 645 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 646 if n, err := strconv.ParseInt(name, base10, b64); err == nil { 647 return protoreflect.ValueOfInt64(int64(n)).MapKey(), nil 648 } 649 650 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 651 if n, err := strconv.ParseUint(name, base10, b32); err == nil { 652 return protoreflect.ValueOfUint32(uint32(n)).MapKey(), nil 653 } 654 655 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 656 if n, err := strconv.ParseUint(name, base10, b64); err == nil { 657 return protoreflect.ValueOfUint64(uint64(n)).MapKey(), nil 658 } 659 660 default: 661 panic(fmt.Sprintf("invalid kind for map key: %v", kind)) 662 } 663 664 return protoreflect.MapKey{}, d.newError(tok.Pos(), "invalid value for %v key: %s", kind, tok.RawString()) 665 }