gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

transport_unix.go (5390B)


      1 //+build !windows,!solaris
      2 
      3 package dbus
      4 
      5 import (
      6 	"bytes"
      7 	"encoding/binary"
      8 	"errors"
      9 	"io"
     10 	"net"
     11 	"syscall"
     12 )
     13 
     14 type oobReader struct {
     15 	conn *net.UnixConn
     16 	oob  []byte
     17 	buf  [4096]byte
     18 }
     19 
     20 func (o *oobReader) Read(b []byte) (n int, err error) {
     21 	n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:])
     22 	if err != nil {
     23 		return n, err
     24 	}
     25 	if flags&syscall.MSG_CTRUNC != 0 {
     26 		return n, errors.New("dbus: control data truncated (too many fds received)")
     27 	}
     28 	o.oob = append(o.oob, o.buf[:oobn]...)
     29 	return n, nil
     30 }
     31 
     32 type unixTransport struct {
     33 	*net.UnixConn
     34 	rdr        *oobReader
     35 	hasUnixFDs bool
     36 }
     37 
     38 func newUnixTransport(keys string) (transport, error) {
     39 	var err error
     40 
     41 	t := new(unixTransport)
     42 	abstract := getKey(keys, "abstract")
     43 	path := getKey(keys, "path")
     44 	switch {
     45 	case abstract == "" && path == "":
     46 		return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
     47 	case abstract != "" && path == "":
     48 		t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
     49 		if err != nil {
     50 			return nil, err
     51 		}
     52 		return t, nil
     53 	case abstract == "" && path != "":
     54 		t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
     55 		if err != nil {
     56 			return nil, err
     57 		}
     58 		return t, nil
     59 	default:
     60 		return nil, errors.New("dbus: invalid address (both path and abstract set)")
     61 	}
     62 }
     63 
     64 func init() {
     65 	transports["unix"] = newUnixTransport
     66 }
     67 
     68 func (t *unixTransport) EnableUnixFDs() {
     69 	t.hasUnixFDs = true
     70 }
     71 
     72 func (t *unixTransport) ReadMessage() (*Message, error) {
     73 	var (
     74 		blen, hlen uint32
     75 		csheader   [16]byte
     76 		headers    []header
     77 		order      binary.ByteOrder
     78 		unixfds    uint32
     79 	)
     80 	// To be sure that all bytes of out-of-band data are read, we use a special
     81 	// reader that uses ReadUnix on the underlying connection instead of Read
     82 	// and gathers the out-of-band data in a buffer.
     83 	if t.rdr == nil {
     84 		t.rdr = &oobReader{conn: t.UnixConn}
     85 	} else {
     86 		t.rdr.oob = nil
     87 	}
     88 
     89 	// read the first 16 bytes (the part of the header that has a constant size),
     90 	// from which we can figure out the length of the rest of the message
     91 	if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil {
     92 		return nil, err
     93 	}
     94 	switch csheader[0] {
     95 	case 'l':
     96 		order = binary.LittleEndian
     97 	case 'B':
     98 		order = binary.BigEndian
     99 	default:
    100 		return nil, InvalidMessageError("invalid byte order")
    101 	}
    102 	// csheader[4:8] -> length of message body, csheader[12:16] -> length of
    103 	// header fields (without alignment)
    104 	binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
    105 	binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
    106 	if hlen%8 != 0 {
    107 		hlen += 8 - (hlen % 8)
    108 	}
    109 
    110 	// decode headers and look for unix fds
    111 	headerdata := make([]byte, hlen+4)
    112 	copy(headerdata, csheader[12:])
    113 	if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil {
    114 		return nil, err
    115 	}
    116 	dec := newDecoder(bytes.NewBuffer(headerdata), order)
    117 	dec.pos = 12
    118 	vs, err := dec.Decode(Signature{"a(yv)"})
    119 	if err != nil {
    120 		return nil, err
    121 	}
    122 	Store(vs, &headers)
    123 	for _, v := range headers {
    124 		if v.Field == byte(FieldUnixFDs) {
    125 			unixfds, _ = v.Variant.value.(uint32)
    126 		}
    127 	}
    128 	all := make([]byte, 16+hlen+blen)
    129 	copy(all, csheader[:])
    130 	copy(all[16:], headerdata[4:])
    131 	if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil {
    132 		return nil, err
    133 	}
    134 	if unixfds != 0 {
    135 		if !t.hasUnixFDs {
    136 			return nil, errors.New("dbus: got unix fds on unsupported transport")
    137 		}
    138 		// read the fds from the OOB data
    139 		scms, err := syscall.ParseSocketControlMessage(t.rdr.oob)
    140 		if err != nil {
    141 			return nil, err
    142 		}
    143 		if len(scms) != 1 {
    144 			return nil, errors.New("dbus: received more than one socket control message")
    145 		}
    146 		fds, err := syscall.ParseUnixRights(&scms[0])
    147 		if err != nil {
    148 			return nil, err
    149 		}
    150 		msg, err := DecodeMessage(bytes.NewBuffer(all))
    151 		if err != nil {
    152 			return nil, err
    153 		}
    154 		// substitute the values in the message body (which are indices for the
    155 		// array receiver via OOB) with the actual values
    156 		for i, v := range msg.Body {
    157 			switch v.(type) {
    158 			case UnixFDIndex:
    159 				j := v.(UnixFDIndex)
    160 				if uint32(j) >= unixfds {
    161 					return nil, InvalidMessageError("invalid index for unix fd")
    162 				}
    163 				msg.Body[i] = UnixFD(fds[j])
    164 			case []UnixFDIndex:
    165 				idxArray := v.([]UnixFDIndex)
    166 				fdArray := make([]UnixFD, len(idxArray))
    167 				for k, j := range idxArray {
    168 					if uint32(j) >= unixfds {
    169 						return nil, InvalidMessageError("invalid index for unix fd")
    170 					}
    171 					fdArray[k] = UnixFD(fds[j])
    172 				}
    173 				msg.Body[i] = fdArray
    174 			}
    175 		}
    176 		return msg, nil
    177 	}
    178 	return DecodeMessage(bytes.NewBuffer(all))
    179 }
    180 
    181 func (t *unixTransport) SendMessage(msg *Message) error {
    182 	fds := make([]int, 0)
    183 	for i, v := range msg.Body {
    184 		if fd, ok := v.(UnixFD); ok {
    185 			msg.Body[i] = UnixFDIndex(len(fds))
    186 			fds = append(fds, int(fd))
    187 		}
    188 	}
    189 	if len(fds) != 0 {
    190 		if !t.hasUnixFDs {
    191 			return errors.New("dbus: unix fd passing not enabled")
    192 		}
    193 		msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds)))
    194 		oob := syscall.UnixRights(fds...)
    195 		buf := new(bytes.Buffer)
    196 		msg.EncodeTo(buf, nativeEndian)
    197 		n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
    198 		if err != nil {
    199 			return err
    200 		}
    201 		if n != buf.Len() || oobn != len(oob) {
    202 			return io.ErrShortWrite
    203 		}
    204 	} else {
    205 		if err := msg.EncodeTo(t, nativeEndian); err != nil {
    206 			return err
    207 		}
    208 	}
    209 	return nil
    210 }
    211 
    212 func (t *unixTransport) SupportsUnixFDs() bool {
    213 	return true
    214 }