gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

mux.go (7870B)


      1 // Copyright 2013 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package ssh
      6 
      7 import (
      8 	"encoding/binary"
      9 	"fmt"
     10 	"io"
     11 	"log"
     12 	"sync"
     13 	"sync/atomic"
     14 )
     15 
     16 // debugMux, if set, causes messages in the connection protocol to be
     17 // logged.
     18 const debugMux = false
     19 
     20 // chanList is a thread safe channel list.
     21 type chanList struct {
     22 	// protects concurrent access to chans
     23 	sync.Mutex
     24 
     25 	// chans are indexed by the local id of the channel, which the
     26 	// other side should send in the PeersId field.
     27 	chans []*channel
     28 
     29 	// This is a debugging aid: it offsets all IDs by this
     30 	// amount. This helps distinguish otherwise identical
     31 	// server/client muxes
     32 	offset uint32
     33 }
     34 
     35 // Assigns a channel ID to the given channel.
     36 func (c *chanList) add(ch *channel) uint32 {
     37 	c.Lock()
     38 	defer c.Unlock()
     39 	for i := range c.chans {
     40 		if c.chans[i] == nil {
     41 			c.chans[i] = ch
     42 			return uint32(i) + c.offset
     43 		}
     44 	}
     45 	c.chans = append(c.chans, ch)
     46 	return uint32(len(c.chans)-1) + c.offset
     47 }
     48 
     49 // getChan returns the channel for the given ID.
     50 func (c *chanList) getChan(id uint32) *channel {
     51 	id -= c.offset
     52 
     53 	c.Lock()
     54 	defer c.Unlock()
     55 	if id < uint32(len(c.chans)) {
     56 		return c.chans[id]
     57 	}
     58 	return nil
     59 }
     60 
     61 func (c *chanList) remove(id uint32) {
     62 	id -= c.offset
     63 	c.Lock()
     64 	if id < uint32(len(c.chans)) {
     65 		c.chans[id] = nil
     66 	}
     67 	c.Unlock()
     68 }
     69 
     70 // dropAll forgets all channels it knows, returning them in a slice.
     71 func (c *chanList) dropAll() []*channel {
     72 	c.Lock()
     73 	defer c.Unlock()
     74 	var r []*channel
     75 
     76 	for _, ch := range c.chans {
     77 		if ch == nil {
     78 			continue
     79 		}
     80 		r = append(r, ch)
     81 	}
     82 	c.chans = nil
     83 	return r
     84 }
     85 
     86 // mux represents the state for the SSH connection protocol, which
     87 // multiplexes many channels onto a single packet transport.
     88 type mux struct {
     89 	conn     packetConn
     90 	chanList chanList
     91 
     92 	incomingChannels chan NewChannel
     93 
     94 	globalSentMu     sync.Mutex
     95 	globalResponses  chan interface{}
     96 	incomingRequests chan *Request
     97 
     98 	errCond *sync.Cond
     99 	err     error
    100 }
    101 
    102 // When debugging, each new chanList instantiation has a different
    103 // offset.
    104 var globalOff uint32
    105 
    106 func (m *mux) Wait() error {
    107 	m.errCond.L.Lock()
    108 	defer m.errCond.L.Unlock()
    109 	for m.err == nil {
    110 		m.errCond.Wait()
    111 	}
    112 	return m.err
    113 }
    114 
    115 // newMux returns a mux that runs over the given connection.
    116 func newMux(p packetConn) *mux {
    117 	m := &mux{
    118 		conn:             p,
    119 		incomingChannels: make(chan NewChannel, chanSize),
    120 		globalResponses:  make(chan interface{}, 1),
    121 		incomingRequests: make(chan *Request, chanSize),
    122 		errCond:          newCond(),
    123 	}
    124 	if debugMux {
    125 		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
    126 	}
    127 
    128 	go m.loop()
    129 	return m
    130 }
    131 
    132 func (m *mux) sendMessage(msg interface{}) error {
    133 	p := Marshal(msg)
    134 	if debugMux {
    135 		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
    136 	}
    137 	return m.conn.writePacket(p)
    138 }
    139 
    140 func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
    141 	if wantReply {
    142 		m.globalSentMu.Lock()
    143 		defer m.globalSentMu.Unlock()
    144 	}
    145 
    146 	if err := m.sendMessage(globalRequestMsg{
    147 		Type:      name,
    148 		WantReply: wantReply,
    149 		Data:      payload,
    150 	}); err != nil {
    151 		return false, nil, err
    152 	}
    153 
    154 	if !wantReply {
    155 		return false, nil, nil
    156 	}
    157 
    158 	msg, ok := <-m.globalResponses
    159 	if !ok {
    160 		return false, nil, io.EOF
    161 	}
    162 	switch msg := msg.(type) {
    163 	case *globalRequestFailureMsg:
    164 		return false, msg.Data, nil
    165 	case *globalRequestSuccessMsg:
    166 		return true, msg.Data, nil
    167 	default:
    168 		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
    169 	}
    170 }
    171 
    172 // ackRequest must be called after processing a global request that
    173 // has WantReply set.
    174 func (m *mux) ackRequest(ok bool, data []byte) error {
    175 	if ok {
    176 		return m.sendMessage(globalRequestSuccessMsg{Data: data})
    177 	}
    178 	return m.sendMessage(globalRequestFailureMsg{Data: data})
    179 }
    180 
    181 func (m *mux) Close() error {
    182 	return m.conn.Close()
    183 }
    184 
    185 // loop runs the connection machine. It will process packets until an
    186 // error is encountered. To synchronize on loop exit, use mux.Wait.
    187 func (m *mux) loop() {
    188 	var err error
    189 	for err == nil {
    190 		err = m.onePacket()
    191 	}
    192 
    193 	for _, ch := range m.chanList.dropAll() {
    194 		ch.close()
    195 	}
    196 
    197 	close(m.incomingChannels)
    198 	close(m.incomingRequests)
    199 	close(m.globalResponses)
    200 
    201 	m.conn.Close()
    202 
    203 	m.errCond.L.Lock()
    204 	m.err = err
    205 	m.errCond.Broadcast()
    206 	m.errCond.L.Unlock()
    207 
    208 	if debugMux {
    209 		log.Println("loop exit", err)
    210 	}
    211 }
    212 
    213 // onePacket reads and processes one packet.
    214 func (m *mux) onePacket() error {
    215 	packet, err := m.conn.readPacket()
    216 	if err != nil {
    217 		return err
    218 	}
    219 
    220 	if debugMux {
    221 		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
    222 			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
    223 		} else {
    224 			p, _ := decode(packet)
    225 			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
    226 		}
    227 	}
    228 
    229 	switch packet[0] {
    230 	case msgChannelOpen:
    231 		return m.handleChannelOpen(packet)
    232 	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
    233 		return m.handleGlobalPacket(packet)
    234 	}
    235 
    236 	// assume a channel packet.
    237 	if len(packet) < 5 {
    238 		return parseError(packet[0])
    239 	}
    240 	id := binary.BigEndian.Uint32(packet[1:])
    241 	ch := m.chanList.getChan(id)
    242 	if ch == nil {
    243 		return m.handleUnknownChannelPacket(id, packet)
    244 	}
    245 
    246 	return ch.handlePacket(packet)
    247 }
    248 
    249 func (m *mux) handleGlobalPacket(packet []byte) error {
    250 	msg, err := decode(packet)
    251 	if err != nil {
    252 		return err
    253 	}
    254 
    255 	switch msg := msg.(type) {
    256 	case *globalRequestMsg:
    257 		m.incomingRequests <- &Request{
    258 			Type:      msg.Type,
    259 			WantReply: msg.WantReply,
    260 			Payload:   msg.Data,
    261 			mux:       m,
    262 		}
    263 	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
    264 		m.globalResponses <- msg
    265 	default:
    266 		panic(fmt.Sprintf("not a global message %#v", msg))
    267 	}
    268 
    269 	return nil
    270 }
    271 
    272 // handleChannelOpen schedules a channel to be Accept()ed.
    273 func (m *mux) handleChannelOpen(packet []byte) error {
    274 	var msg channelOpenMsg
    275 	if err := Unmarshal(packet, &msg); err != nil {
    276 		return err
    277 	}
    278 
    279 	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
    280 		failMsg := channelOpenFailureMsg{
    281 			PeersID:  msg.PeersID,
    282 			Reason:   ConnectionFailed,
    283 			Message:  "invalid request",
    284 			Language: "en_US.UTF-8",
    285 		}
    286 		return m.sendMessage(failMsg)
    287 	}
    288 
    289 	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
    290 	c.remoteId = msg.PeersID
    291 	c.maxRemotePayload = msg.MaxPacketSize
    292 	c.remoteWin.add(msg.PeersWindow)
    293 	m.incomingChannels <- c
    294 	return nil
    295 }
    296 
    297 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
    298 	ch, err := m.openChannel(chanType, extra)
    299 	if err != nil {
    300 		return nil, nil, err
    301 	}
    302 
    303 	return ch, ch.incomingRequests, nil
    304 }
    305 
    306 func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
    307 	ch := m.newChannel(chanType, channelOutbound, extra)
    308 
    309 	ch.maxIncomingPayload = channelMaxPacket
    310 
    311 	open := channelOpenMsg{
    312 		ChanType:         chanType,
    313 		PeersWindow:      ch.myWindow,
    314 		MaxPacketSize:    ch.maxIncomingPayload,
    315 		TypeSpecificData: extra,
    316 		PeersID:          ch.localId,
    317 	}
    318 	if err := m.sendMessage(open); err != nil {
    319 		return nil, err
    320 	}
    321 
    322 	switch msg := (<-ch.msg).(type) {
    323 	case *channelOpenConfirmMsg:
    324 		return ch, nil
    325 	case *channelOpenFailureMsg:
    326 		return nil, &OpenChannelError{msg.Reason, msg.Message}
    327 	default:
    328 		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
    329 	}
    330 }
    331 
    332 func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
    333 	msg, err := decode(packet)
    334 	if err != nil {
    335 		return err
    336 	}
    337 
    338 	switch msg := msg.(type) {
    339 	// RFC 4254 section 5.4 says unrecognized channel requests should
    340 	// receive a failure response.
    341 	case *channelRequestMsg:
    342 		if msg.WantReply {
    343 			return m.sendMessage(channelRequestFailureMsg{
    344 				PeersID: msg.PeersID,
    345 			})
    346 		}
    347 		return nil
    348 	default:
    349 		return fmt.Errorf("ssh: invalid channel %d", id)
    350 	}
    351 }