stun: check high bits in Is, add tests

Also use new stun.TxID type in stunner.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/106/head
Brad Fitzpatrick 5 years ago
parent 2489ea4268
commit 14abc82033

@ -218,9 +218,7 @@ func mappedAddress(b []byte) (addr []byte, port uint16, err error) {
// Is reports whether b is a STUN message. // Is reports whether b is a STUN message.
func Is(b []byte) bool { func Is(b []byte) bool {
if len(b) < headerLen { return len(b) >= headerLen &&
return false // every STUN message must have a 20-byte header b[0]&0b11000000 == 0 && // top two bits must be zero
} string(b[4:8]) == magicCookie
// TODO RFC5389 suggests checking the first 2 bits of the header are zero.
return string(b[4:8]) == magicCookie
} }

@ -166,3 +166,29 @@ func TestParseResponse(t *testing.T) {
}) })
} }
} }
func TestIs(t *testing.T) {
const magicCookie = "\x21\x12\xa4\x42"
tests := []struct {
in string
want bool
}{
{"", false},
{"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", true},
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00foo", true},
// high bits set:
{"\xf0\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
{"\x40\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
// first byte non-zero, but not high bits:
{"\x20\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", true},
}
for i, tt := range tests {
pkt := []byte(tt.in)
got := stun.Is(pkt)
if got != tt.want {
t.Errorf("%d. In(%q (%v)) = %v; want %v", i, pkt, pkt, got, tt.want)
}
}
}

@ -40,7 +40,7 @@ type Stunner struct {
type session struct { type session struct {
replied chan struct{} // closed when server responds replied chan struct{} // closed when server responds
tIDs [][12]byte // transaction IDs sent to a server tIDs []stun.TxID // transaction IDs sent to a server
} }
// Receive delivers a STUN packet to the stunner. // Receive delivers a STUN packet to the stunner.
@ -90,7 +90,7 @@ func (s *Stunner) Run(ctx context.Context) error {
} }
for _, server := range s.Servers { for _, server := range s.Servers {
// Generate the transaction IDs for this session. // Generate the transaction IDs for this session.
tIDs := make([][12]byte, len(retryDurations)) tIDs := make([]stun.TxID, len(retryDurations))
for i := range tIDs { for i := range tIDs {
if _, err := rand.Read(tIDs[i][:]); err != nil { if _, err := rand.Read(tIDs[i][:]); err != nil {
return fmt.Errorf("stunner: rand failed: %v", err) return fmt.Errorf("stunner: rand failed: %v", err)
@ -147,7 +147,7 @@ func (s *Stunner) runServer(ctx context.Context, server string) {
} }
} }
func (s *Stunner) sendSTUN(ctx context.Context, tID [12]byte, server string) error { func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) error {
host, port, err := net.SplitHostPort(server) host, port, err := net.SplitHostPort(server)
if err != nil { if err != nil {
return err return err

Loading…
Cancel
Save