transport_unix.go (5390B)
1 //+build !windows,!solaris 2 3 package dbus 4 5 import ( 6 "bytes" 7 "encoding/binary" 8 "errors" 9 "io" 10 "net" 11 "syscall" 12 ) 13 14 type oobReader struct { 15 conn *net.UnixConn 16 oob []byte 17 buf [4096]byte 18 } 19 20 func (o *oobReader) Read(b []byte) (n int, err error) { 21 n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:]) 22 if err != nil { 23 return n, err 24 } 25 if flags&syscall.MSG_CTRUNC != 0 { 26 return n, errors.New("dbus: control data truncated (too many fds received)") 27 } 28 o.oob = append(o.oob, o.buf[:oobn]...) 29 return n, nil 30 } 31 32 type unixTransport struct { 33 *net.UnixConn 34 rdr *oobReader 35 hasUnixFDs bool 36 } 37 38 func newUnixTransport(keys string) (transport, error) { 39 var err error 40 41 t := new(unixTransport) 42 abstract := getKey(keys, "abstract") 43 path := getKey(keys, "path") 44 switch { 45 case abstract == "" && path == "": 46 return nil, errors.New("dbus: invalid address (neither path nor abstract set)") 47 case abstract != "" && path == "": 48 t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"}) 49 if err != nil { 50 return nil, err 51 } 52 return t, nil 53 case abstract == "" && path != "": 54 t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"}) 55 if err != nil { 56 return nil, err 57 } 58 return t, nil 59 default: 60 return nil, errors.New("dbus: invalid address (both path and abstract set)") 61 } 62 } 63 64 func init() { 65 transports["unix"] = newUnixTransport 66 } 67 68 func (t *unixTransport) EnableUnixFDs() { 69 t.hasUnixFDs = true 70 } 71 72 func (t *unixTransport) ReadMessage() (*Message, error) { 73 var ( 74 blen, hlen uint32 75 csheader [16]byte 76 headers []header 77 order binary.ByteOrder 78 unixfds uint32 79 ) 80 // To be sure that all bytes of out-of-band data are read, we use a special 81 // reader that uses ReadUnix on the underlying connection instead of Read 82 // and gathers the out-of-band data in a buffer. 83 if t.rdr == nil { 84 t.rdr = &oobReader{conn: t.UnixConn} 85 } else { 86 t.rdr.oob = nil 87 } 88 89 // read the first 16 bytes (the part of the header that has a constant size), 90 // from which we can figure out the length of the rest of the message 91 if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil { 92 return nil, err 93 } 94 switch csheader[0] { 95 case 'l': 96 order = binary.LittleEndian 97 case 'B': 98 order = binary.BigEndian 99 default: 100 return nil, InvalidMessageError("invalid byte order") 101 } 102 // csheader[4:8] -> length of message body, csheader[12:16] -> length of 103 // header fields (without alignment) 104 binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen) 105 binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen) 106 if hlen%8 != 0 { 107 hlen += 8 - (hlen % 8) 108 } 109 110 // decode headers and look for unix fds 111 headerdata := make([]byte, hlen+4) 112 copy(headerdata, csheader[12:]) 113 if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil { 114 return nil, err 115 } 116 dec := newDecoder(bytes.NewBuffer(headerdata), order) 117 dec.pos = 12 118 vs, err := dec.Decode(Signature{"a(yv)"}) 119 if err != nil { 120 return nil, err 121 } 122 Store(vs, &headers) 123 for _, v := range headers { 124 if v.Field == byte(FieldUnixFDs) { 125 unixfds, _ = v.Variant.value.(uint32) 126 } 127 } 128 all := make([]byte, 16+hlen+blen) 129 copy(all, csheader[:]) 130 copy(all[16:], headerdata[4:]) 131 if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil { 132 return nil, err 133 } 134 if unixfds != 0 { 135 if !t.hasUnixFDs { 136 return nil, errors.New("dbus: got unix fds on unsupported transport") 137 } 138 // read the fds from the OOB data 139 scms, err := syscall.ParseSocketControlMessage(t.rdr.oob) 140 if err != nil { 141 return nil, err 142 } 143 if len(scms) != 1 { 144 return nil, errors.New("dbus: received more than one socket control message") 145 } 146 fds, err := syscall.ParseUnixRights(&scms[0]) 147 if err != nil { 148 return nil, err 149 } 150 msg, err := DecodeMessage(bytes.NewBuffer(all)) 151 if err != nil { 152 return nil, err 153 } 154 // substitute the values in the message body (which are indices for the 155 // array receiver via OOB) with the actual values 156 for i, v := range msg.Body { 157 switch v.(type) { 158 case UnixFDIndex: 159 j := v.(UnixFDIndex) 160 if uint32(j) >= unixfds { 161 return nil, InvalidMessageError("invalid index for unix fd") 162 } 163 msg.Body[i] = UnixFD(fds[j]) 164 case []UnixFDIndex: 165 idxArray := v.([]UnixFDIndex) 166 fdArray := make([]UnixFD, len(idxArray)) 167 for k, j := range idxArray { 168 if uint32(j) >= unixfds { 169 return nil, InvalidMessageError("invalid index for unix fd") 170 } 171 fdArray[k] = UnixFD(fds[j]) 172 } 173 msg.Body[i] = fdArray 174 } 175 } 176 return msg, nil 177 } 178 return DecodeMessage(bytes.NewBuffer(all)) 179 } 180 181 func (t *unixTransport) SendMessage(msg *Message) error { 182 fds := make([]int, 0) 183 for i, v := range msg.Body { 184 if fd, ok := v.(UnixFD); ok { 185 msg.Body[i] = UnixFDIndex(len(fds)) 186 fds = append(fds, int(fd)) 187 } 188 } 189 if len(fds) != 0 { 190 if !t.hasUnixFDs { 191 return errors.New("dbus: unix fd passing not enabled") 192 } 193 msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds))) 194 oob := syscall.UnixRights(fds...) 195 buf := new(bytes.Buffer) 196 msg.EncodeTo(buf, nativeEndian) 197 n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil) 198 if err != nil { 199 return err 200 } 201 if n != buf.Len() || oobn != len(oob) { 202 return io.ErrShortWrite 203 } 204 } else { 205 if err := msg.EncodeTo(t, nativeEndian); err != nil { 206 return err 207 } 208 } 209 return nil 210 } 211 212 func (t *unixTransport) SupportsUnixFDs() bool { 213 return true 214 }