mirror of https://github.com/tailscale/tailscale/
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
358 lines
7.9 KiB
Go
358 lines
7.9 KiB
Go
1 month ago
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package ssh
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"log"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
)
|
||
|
|
||
|
// debugMux, if set, causes messages in the connection protocol to be
|
||
|
// logged.
|
||
|
// Flip to true to enable the log.Printf/log.Println tracing that this file
// guards with `if debugMux`, and to give each new mux a distinct ID offset.
const debugMux = false
|
||
|
|
||
|
// chanList is a thread safe channel list.
|
||
|
type chanList struct {
|
||
|
// protects concurrent access to chans
|
||
|
sync.Mutex
|
||
|
|
||
|
// chans are indexed by the local id of the channel, which the
|
||
|
// other side should send in the PeersId field.
|
||
|
chans []*channel
|
||
|
|
||
|
// This is a debugging aid: it offsets all IDs by this
|
||
|
// amount. This helps distinguish otherwise identical
|
||
|
// server/client muxes
|
||
|
offset uint32
|
||
|
}
|
||
|
|
||
|
// Assigns a channel ID to the given channel.
|
||
|
func (c *chanList) add(ch *channel) uint32 {
|
||
|
c.Lock()
|
||
|
defer c.Unlock()
|
||
|
for i := range c.chans {
|
||
|
if c.chans[i] == nil {
|
||
|
c.chans[i] = ch
|
||
|
return uint32(i) + c.offset
|
||
|
}
|
||
|
}
|
||
|
c.chans = append(c.chans, ch)
|
||
|
return uint32(len(c.chans)-1) + c.offset
|
||
|
}
|
||
|
|
||
|
// getChan returns the channel for the given ID.
|
||
|
func (c *chanList) getChan(id uint32) *channel {
|
||
|
id -= c.offset
|
||
|
|
||
|
c.Lock()
|
||
|
defer c.Unlock()
|
||
|
if id < uint32(len(c.chans)) {
|
||
|
return c.chans[id]
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *chanList) remove(id uint32) {
|
||
|
id -= c.offset
|
||
|
c.Lock()
|
||
|
if id < uint32(len(c.chans)) {
|
||
|
c.chans[id] = nil
|
||
|
}
|
||
|
c.Unlock()
|
||
|
}
|
||
|
|
||
|
// dropAll forgets all channels it knows, returning them in a slice.
|
||
|
func (c *chanList) dropAll() []*channel {
|
||
|
c.Lock()
|
||
|
defer c.Unlock()
|
||
|
var r []*channel
|
||
|
|
||
|
for _, ch := range c.chans {
|
||
|
if ch == nil {
|
||
|
continue
|
||
|
}
|
||
|
r = append(r, ch)
|
||
|
}
|
||
|
c.chans = nil
|
||
|
return r
|
||
|
}
|
||
|
|
||
|
// mux represents the state for the SSH connection protocol, which
|
||
|
// multiplexes many channels onto a single packet transport.
|
||
|
type mux struct {
|
||
|
conn packetConn
|
||
|
chanList chanList
|
||
|
|
||
|
incomingChannels chan NewChannel
|
||
|
|
||
|
globalSentMu sync.Mutex
|
||
|
globalResponses chan interface{}
|
||
|
incomingRequests chan *Request
|
||
|
|
||
|
errCond *sync.Cond
|
||
|
err error
|
||
|
}
|
||
|
|
||
|
// When debugging, each new chanList instantiation has a different
|
||
|
// offset.
|
||
|
// globalOff is bumped with atomic.AddUint32 in newMux (only when debugMux
// is true); the incremented value becomes that mux's chanList.offset.
var globalOff uint32
|
||
|
|
||
|
func (m *mux) Wait() error {
|
||
|
m.errCond.L.Lock()
|
||
|
defer m.errCond.L.Unlock()
|
||
|
for m.err == nil {
|
||
|
m.errCond.Wait()
|
||
|
}
|
||
|
return m.err
|
||
|
}
|
||
|
|
||
|
// newMux returns a mux that runs over the given connection.
|
||
|
func newMux(p packetConn) *mux {
|
||
|
m := &mux{
|
||
|
conn: p,
|
||
|
incomingChannels: make(chan NewChannel, chanSize),
|
||
|
globalResponses: make(chan interface{}, 1),
|
||
|
incomingRequests: make(chan *Request, chanSize),
|
||
|
errCond: newCond(),
|
||
|
}
|
||
|
if debugMux {
|
||
|
m.chanList.offset = atomic.AddUint32(&globalOff, 1)
|
||
|
}
|
||
|
|
||
|
go m.loop()
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
func (m *mux) sendMessage(msg interface{}) error {
|
||
|
p := Marshal(msg)
|
||
|
if debugMux {
|
||
|
log.Printf("send global(%d): %#v", m.chanList.offset, msg)
|
||
|
}
|
||
|
return m.conn.writePacket(p)
|
||
|
}
|
||
|
|
||
|
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
|
||
|
if wantReply {
|
||
|
m.globalSentMu.Lock()
|
||
|
defer m.globalSentMu.Unlock()
|
||
|
}
|
||
|
|
||
|
if err := m.sendMessage(globalRequestMsg{
|
||
|
Type: name,
|
||
|
WantReply: wantReply,
|
||
|
Data: payload,
|
||
|
}); err != nil {
|
||
|
return false, nil, err
|
||
|
}
|
||
|
|
||
|
if !wantReply {
|
||
|
return false, nil, nil
|
||
|
}
|
||
|
|
||
|
msg, ok := <-m.globalResponses
|
||
|
if !ok {
|
||
|
return false, nil, io.EOF
|
||
|
}
|
||
|
switch msg := msg.(type) {
|
||
|
case *globalRequestFailureMsg:
|
||
|
return false, msg.Data, nil
|
||
|
case *globalRequestSuccessMsg:
|
||
|
return true, msg.Data, nil
|
||
|
default:
|
||
|
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ackRequest must be called after processing a global request that
|
||
|
// has WantReply set.
|
||
|
func (m *mux) ackRequest(ok bool, data []byte) error {
|
||
|
if ok {
|
||
|
return m.sendMessage(globalRequestSuccessMsg{Data: data})
|
||
|
}
|
||
|
return m.sendMessage(globalRequestFailureMsg{Data: data})
|
||
|
}
|
||
|
|
||
|
func (m *mux) Close() error {
|
||
|
return m.conn.Close()
|
||
|
}
|
||
|
|
||
|
// loop runs the connection machine. It will process packets until an
|
||
|
// error is encountered. To synchronize on loop exit, use mux.Wait.
|
||
|
func (m *mux) loop() {
|
||
|
var err error
|
||
|
for err == nil {
|
||
|
err = m.onePacket()
|
||
|
}
|
||
|
|
||
|
for _, ch := range m.chanList.dropAll() {
|
||
|
ch.close()
|
||
|
}
|
||
|
|
||
|
close(m.incomingChannels)
|
||
|
close(m.incomingRequests)
|
||
|
close(m.globalResponses)
|
||
|
|
||
|
m.conn.Close()
|
||
|
|
||
|
m.errCond.L.Lock()
|
||
|
m.err = err
|
||
|
m.errCond.Broadcast()
|
||
|
m.errCond.L.Unlock()
|
||
|
|
||
|
if debugMux {
|
||
|
log.Println("loop exit", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// onePacket reads and processes one packet.
|
||
|
func (m *mux) onePacket() error {
|
||
|
packet, err := m.conn.readPacket()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if debugMux {
|
||
|
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
|
||
|
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
|
||
|
} else {
|
||
|
p, _ := decode(packet)
|
||
|
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
switch packet[0] {
|
||
|
case msgChannelOpen:
|
||
|
return m.handleChannelOpen(packet)
|
||
|
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
|
||
|
return m.handleGlobalPacket(packet)
|
||
|
case msgPing:
|
||
|
var msg pingMsg
|
||
|
if err := Unmarshal(packet, &msg); err != nil {
|
||
|
return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
|
||
|
}
|
||
|
return m.sendMessage(pongMsg(msg))
|
||
|
}
|
||
|
|
||
|
// assume a channel packet.
|
||
|
if len(packet) < 5 {
|
||
|
return parseError(packet[0])
|
||
|
}
|
||
|
id := binary.BigEndian.Uint32(packet[1:])
|
||
|
ch := m.chanList.getChan(id)
|
||
|
if ch == nil {
|
||
|
return m.handleUnknownChannelPacket(id, packet)
|
||
|
}
|
||
|
|
||
|
return ch.handlePacket(packet)
|
||
|
}
|
||
|
|
||
|
func (m *mux) handleGlobalPacket(packet []byte) error {
|
||
|
msg, err := decode(packet)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch msg := msg.(type) {
|
||
|
case *globalRequestMsg:
|
||
|
m.incomingRequests <- &Request{
|
||
|
Type: msg.Type,
|
||
|
WantReply: msg.WantReply,
|
||
|
Payload: msg.Data,
|
||
|
mux: m,
|
||
|
}
|
||
|
case *globalRequestSuccessMsg, *globalRequestFailureMsg:
|
||
|
m.globalResponses <- msg
|
||
|
default:
|
||
|
panic(fmt.Sprintf("not a global message %#v", msg))
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// handleChannelOpen schedules a channel to be Accept()ed.
|
||
|
func (m *mux) handleChannelOpen(packet []byte) error {
|
||
|
var msg channelOpenMsg
|
||
|
if err := Unmarshal(packet, &msg); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
|
||
|
failMsg := channelOpenFailureMsg{
|
||
|
PeersID: msg.PeersID,
|
||
|
Reason: ConnectionFailed,
|
||
|
Message: "invalid request",
|
||
|
Language: "en_US.UTF-8",
|
||
|
}
|
||
|
return m.sendMessage(failMsg)
|
||
|
}
|
||
|
|
||
|
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
|
||
|
c.remoteId = msg.PeersID
|
||
|
c.maxRemotePayload = msg.MaxPacketSize
|
||
|
c.remoteWin.add(msg.PeersWindow)
|
||
|
m.incomingChannels <- c
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
|
||
|
ch, err := m.openChannel(chanType, extra)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
|
||
|
return ch, ch.incomingRequests, nil
|
||
|
}
|
||
|
|
||
|
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
|
||
|
ch := m.newChannel(chanType, channelOutbound, extra)
|
||
|
|
||
|
ch.maxIncomingPayload = channelMaxPacket
|
||
|
|
||
|
open := channelOpenMsg{
|
||
|
ChanType: chanType,
|
||
|
PeersWindow: ch.myWindow,
|
||
|
MaxPacketSize: ch.maxIncomingPayload,
|
||
|
TypeSpecificData: extra,
|
||
|
PeersID: ch.localId,
|
||
|
}
|
||
|
if err := m.sendMessage(open); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch msg := (<-ch.msg).(type) {
|
||
|
case *channelOpenConfirmMsg:
|
||
|
return ch, nil
|
||
|
case *channelOpenFailureMsg:
|
||
|
return nil, &OpenChannelError{msg.Reason, msg.Message}
|
||
|
default:
|
||
|
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
|
||
|
msg, err := decode(packet)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch msg := msg.(type) {
|
||
|
// RFC 4254 section 5.4 says unrecognized channel requests should
|
||
|
// receive a failure response.
|
||
|
case *channelRequestMsg:
|
||
|
if msg.WantReply {
|
||
|
return m.sendMessage(channelRequestFailureMsg{
|
||
|
PeersID: msg.PeersID,
|
||
|
})
|
||
|
}
|
||
|
return nil
|
||
|
default:
|
||
|
return fmt.Errorf("ssh: invalid channel %d", id)
|
||
|
}
|
||
|
}
|