startup_message.go (2333B)
1 package pgproto3 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "encoding/json" 7 "errors" 8 "fmt" 9 10 "github.com/jackc/pgio" 11 ) 12 13 const ProtocolVersionNumber = 196608 // 3.0 14 15 type StartupMessage struct { 16 ProtocolVersion uint32 17 Parameters map[string]string 18 } 19 20 // Frontend identifies this message as sendable by a PostgreSQL frontend. 21 func (*StartupMessage) Frontend() {} 22 23 // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 24 // type identifier and 4 byte message length. 25 func (dst *StartupMessage) Decode(src []byte) error { 26 if len(src) < 4 { 27 return errors.New("startup message too short") 28 } 29 30 dst.ProtocolVersion = binary.BigEndian.Uint32(src) 31 rp := 4 32 33 if dst.ProtocolVersion != ProtocolVersionNumber { 34 return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) 35 } 36 37 dst.Parameters = make(map[string]string) 38 for { 39 idx := bytes.IndexByte(src[rp:], 0) 40 if idx < 0 { 41 return &invalidMessageFormatErr{messageType: "StartupMesage"} 42 } 43 key := string(src[rp : rp+idx]) 44 rp += idx + 1 45 46 idx = bytes.IndexByte(src[rp:], 0) 47 if idx < 0 { 48 return &invalidMessageFormatErr{messageType: "StartupMesage"} 49 } 50 value := string(src[rp : rp+idx]) 51 rp += idx + 1 52 53 dst.Parameters[key] = value 54 55 if len(src[rp:]) == 1 { 56 if src[rp] != 0 { 57 return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) 58 } 59 break 60 } 61 } 62 63 return nil 64 } 65 66 // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 67 func (src *StartupMessage) Encode(dst []byte) []byte { 68 sp := len(dst) 69 dst = pgio.AppendInt32(dst, -1) 70 71 dst = pgio.AppendUint32(dst, src.ProtocolVersion) 72 for k, v := range src.Parameters { 73 dst = append(dst, k...) 74 dst = append(dst, 0) 75 dst = append(dst, v...) 76 dst = append(dst, 0) 77 } 78 dst = append(dst, 0) 79 80 pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 81 82 return dst 83 } 84 85 // MarshalJSON implements encoding/json.Marshaler. 86 func (src StartupMessage) MarshalJSON() ([]byte, error) { 87 return json.Marshal(struct { 88 Type string 89 ProtocolVersion uint32 90 Parameters map[string]string 91 }{ 92 Type: "StartupMessage", 93 ProtocolVersion: src.ProtocolVersion, 94 Parameters: src.Parameters, 95 }) 96 }