mux.go (7870B)
1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "encoding/binary" 9 "fmt" 10 "io" 11 "log" 12 "sync" 13 "sync/atomic" 14 ) 15 16 // debugMux, if set, causes messages in the connection protocol to be 17 // logged. 18 const debugMux = false 19 20 // chanList is a thread safe channel list. 21 type chanList struct { 22 // protects concurrent access to chans 23 sync.Mutex 24 25 // chans are indexed by the local id of the channel, which the 26 // other side should send in the PeersId field. 27 chans []*channel 28 29 // This is a debugging aid: it offsets all IDs by this 30 // amount. This helps distinguish otherwise identical 31 // server/client muxes 32 offset uint32 33 } 34 35 // Assigns a channel ID to the given channel. 36 func (c *chanList) add(ch *channel) uint32 { 37 c.Lock() 38 defer c.Unlock() 39 for i := range c.chans { 40 if c.chans[i] == nil { 41 c.chans[i] = ch 42 return uint32(i) + c.offset 43 } 44 } 45 c.chans = append(c.chans, ch) 46 return uint32(len(c.chans)-1) + c.offset 47 } 48 49 // getChan returns the channel for the given ID. 50 func (c *chanList) getChan(id uint32) *channel { 51 id -= c.offset 52 53 c.Lock() 54 defer c.Unlock() 55 if id < uint32(len(c.chans)) { 56 return c.chans[id] 57 } 58 return nil 59 } 60 61 func (c *chanList) remove(id uint32) { 62 id -= c.offset 63 c.Lock() 64 if id < uint32(len(c.chans)) { 65 c.chans[id] = nil 66 } 67 c.Unlock() 68 } 69 70 // dropAll forgets all channels it knows, returning them in a slice. 71 func (c *chanList) dropAll() []*channel { 72 c.Lock() 73 defer c.Unlock() 74 var r []*channel 75 76 for _, ch := range c.chans { 77 if ch == nil { 78 continue 79 } 80 r = append(r, ch) 81 } 82 c.chans = nil 83 return r 84 } 85 86 // mux represents the state for the SSH connection protocol, which 87 // multiplexes many channels onto a single packet transport. 88 type mux struct { 89 conn packetConn 90 chanList chanList 91 92 incomingChannels chan NewChannel 93 94 globalSentMu sync.Mutex 95 globalResponses chan interface{} 96 incomingRequests chan *Request 97 98 errCond *sync.Cond 99 err error 100 } 101 102 // When debugging, each new chanList instantiation has a different 103 // offset. 104 var globalOff uint32 105 106 func (m *mux) Wait() error { 107 m.errCond.L.Lock() 108 defer m.errCond.L.Unlock() 109 for m.err == nil { 110 m.errCond.Wait() 111 } 112 return m.err 113 } 114 115 // newMux returns a mux that runs over the given connection. 116 func newMux(p packetConn) *mux { 117 m := &mux{ 118 conn: p, 119 incomingChannels: make(chan NewChannel, chanSize), 120 globalResponses: make(chan interface{}, 1), 121 incomingRequests: make(chan *Request, chanSize), 122 errCond: newCond(), 123 } 124 if debugMux { 125 m.chanList.offset = atomic.AddUint32(&globalOff, 1) 126 } 127 128 go m.loop() 129 return m 130 } 131 132 func (m *mux) sendMessage(msg interface{}) error { 133 p := Marshal(msg) 134 if debugMux { 135 log.Printf("send global(%d): %#v", m.chanList.offset, msg) 136 } 137 return m.conn.writePacket(p) 138 } 139 140 func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { 141 if wantReply { 142 m.globalSentMu.Lock() 143 defer m.globalSentMu.Unlock() 144 } 145 146 if err := m.sendMessage(globalRequestMsg{ 147 Type: name, 148 WantReply: wantReply, 149 Data: payload, 150 }); err != nil { 151 return false, nil, err 152 } 153 154 if !wantReply { 155 return false, nil, nil 156 } 157 158 msg, ok := <-m.globalResponses 159 if !ok { 160 return false, nil, io.EOF 161 } 162 switch msg := msg.(type) { 163 case *globalRequestFailureMsg: 164 return false, msg.Data, nil 165 case *globalRequestSuccessMsg: 166 return true, msg.Data, nil 167 default: 168 return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) 169 } 170 } 171 172 // ackRequest must be called after processing a global request that 173 // has WantReply set. 174 func (m *mux) ackRequest(ok bool, data []byte) error { 175 if ok { 176 return m.sendMessage(globalRequestSuccessMsg{Data: data}) 177 } 178 return m.sendMessage(globalRequestFailureMsg{Data: data}) 179 } 180 181 func (m *mux) Close() error { 182 return m.conn.Close() 183 } 184 185 // loop runs the connection machine. It will process packets until an 186 // error is encountered. To synchronize on loop exit, use mux.Wait. 187 func (m *mux) loop() { 188 var err error 189 for err == nil { 190 err = m.onePacket() 191 } 192 193 for _, ch := range m.chanList.dropAll() { 194 ch.close() 195 } 196 197 close(m.incomingChannels) 198 close(m.incomingRequests) 199 close(m.globalResponses) 200 201 m.conn.Close() 202 203 m.errCond.L.Lock() 204 m.err = err 205 m.errCond.Broadcast() 206 m.errCond.L.Unlock() 207 208 if debugMux { 209 log.Println("loop exit", err) 210 } 211 } 212 213 // onePacket reads and processes one packet. 214 func (m *mux) onePacket() error { 215 packet, err := m.conn.readPacket() 216 if err != nil { 217 return err 218 } 219 220 if debugMux { 221 if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { 222 log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) 223 } else { 224 p, _ := decode(packet) 225 log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) 226 } 227 } 228 229 switch packet[0] { 230 case msgChannelOpen: 231 return m.handleChannelOpen(packet) 232 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: 233 return m.handleGlobalPacket(packet) 234 } 235 236 // assume a channel packet. 237 if len(packet) < 5 { 238 return parseError(packet[0]) 239 } 240 id := binary.BigEndian.Uint32(packet[1:]) 241 ch := m.chanList.getChan(id) 242 if ch == nil { 243 return m.handleUnknownChannelPacket(id, packet) 244 } 245 246 return ch.handlePacket(packet) 247 } 248 249 func (m *mux) handleGlobalPacket(packet []byte) error { 250 msg, err := decode(packet) 251 if err != nil { 252 return err 253 } 254 255 switch msg := msg.(type) { 256 case *globalRequestMsg: 257 m.incomingRequests <- &Request{ 258 Type: msg.Type, 259 WantReply: msg.WantReply, 260 Payload: msg.Data, 261 mux: m, 262 } 263 case *globalRequestSuccessMsg, *globalRequestFailureMsg: 264 m.globalResponses <- msg 265 default: 266 panic(fmt.Sprintf("not a global message %#v", msg)) 267 } 268 269 return nil 270 } 271 272 // handleChannelOpen schedules a channel to be Accept()ed. 273 func (m *mux) handleChannelOpen(packet []byte) error { 274 var msg channelOpenMsg 275 if err := Unmarshal(packet, &msg); err != nil { 276 return err 277 } 278 279 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { 280 failMsg := channelOpenFailureMsg{ 281 PeersID: msg.PeersID, 282 Reason: ConnectionFailed, 283 Message: "invalid request", 284 Language: "en_US.UTF-8", 285 } 286 return m.sendMessage(failMsg) 287 } 288 289 c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) 290 c.remoteId = msg.PeersID 291 c.maxRemotePayload = msg.MaxPacketSize 292 c.remoteWin.add(msg.PeersWindow) 293 m.incomingChannels <- c 294 return nil 295 } 296 297 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { 298 ch, err := m.openChannel(chanType, extra) 299 if err != nil { 300 return nil, nil, err 301 } 302 303 return ch, ch.incomingRequests, nil 304 } 305 306 func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { 307 ch := m.newChannel(chanType, channelOutbound, extra) 308 309 ch.maxIncomingPayload = channelMaxPacket 310 311 open := channelOpenMsg{ 312 ChanType: chanType, 313 PeersWindow: ch.myWindow, 314 MaxPacketSize: ch.maxIncomingPayload, 315 TypeSpecificData: extra, 316 PeersID: ch.localId, 317 } 318 if err := m.sendMessage(open); err != nil { 319 return nil, err 320 } 321 322 switch msg := (<-ch.msg).(type) { 323 case *channelOpenConfirmMsg: 324 return ch, nil 325 case *channelOpenFailureMsg: 326 return nil, &OpenChannelError{msg.Reason, msg.Message} 327 default: 328 return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) 329 } 330 } 331 332 func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error { 333 msg, err := decode(packet) 334 if err != nil { 335 return err 336 } 337 338 switch msg := msg.(type) { 339 // RFC 4254 section 5.4 says unrecognized channel requests should 340 // receive a failure response. 341 case *channelRequestMsg: 342 if msg.WantReply { 343 return m.sendMessage(channelRequestFailureMsg{ 344 PeersID: msg.PeersID, 345 }) 346 } 347 return nil 348 default: 349 return fmt.Errorf("ssh: invalid channel %d", id) 350 } 351 }