@ -6,9 +6,12 @@ package stunner
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
@ -37,6 +40,10 @@ type Stunner struct {
// took on the wire (not including DNS lookup time.
Endpoint func ( server , endpoint string , d time . Duration )
// onPacket is the internal version of Endpoint that does de-dup.
// It's set by Run.
onPacket func ( server , endpoint string , d time . Duration )
Servers [ ] string // STUN servers to contact
// DNSCache optionally specifies a DNSCache to use.
@ -50,10 +57,6 @@ type Stunner struct {
// If false, only IPv4 is used. There is currently no mixed mode.
OnlyIPv6 bool
// sessions tracks the state of each server.
// It's keyed by the STUN server (from the Servers field).
sessions map [ string ] * session
mu sync . Mutex
inFlight map [ stun . TxID ] request
}
@ -61,8 +64,8 @@ type Stunner struct {
func ( s * Stunner ) addTX ( tx stun . TxID , server string ) {
s . mu . Lock ( )
defer s . mu . Unlock ( )
if s. inFlight == nil {
s . inFlight = make ( map [ stun . TxID ] request )
if _, dup := s . inFlight [ tx ] ; dup {
panic ( "unexpected duplicate STUN TransactionID" )
}
s . inFlight [ tx ] = request { sent : time . Now ( ) , server : server }
}
@ -70,8 +73,15 @@ func (s *Stunner) addTX(tx stun.TxID, server string) {
func ( s * Stunner ) removeTX ( tx stun . TxID ) ( request , bool ) {
s . mu . Lock ( )
defer s . mu . Unlock ( )
if s . inFlight == nil {
return request { } , false
}
r , ok := s . inFlight [ tx ]
if ok {
delete ( s . inFlight , tx )
} else {
s . logf ( "stunner: got STUN packet for unknown TxID %x" , tx )
}
return r , ok
}
@ -80,11 +90,6 @@ type request struct {
server string
}
type session struct {
ctx context . Context // closed via call to done when reply received
cancel context . CancelFunc
}
func ( s * Stunner ) logf ( format string , args ... interface { } ) {
if s . Logf != nil {
s . Logf ( format , args ... )
@ -105,95 +110,113 @@ func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
}
r , ok := s . removeTX ( tx )
if ! ok {
s . logf ( "stunner: got STUN packet for unknown TxID %x" , tx )
return
}
d := now . Sub ( r . sent )
session := s . sessions [ r . server ]
if session != nil {
host := net . JoinHostPort ( net . IP ( addr ) . String ( ) , fmt . Sprint ( port ) )
s . Endpoint ( r . server , host , d )
session . cancel ( )
}
s . onPacket ( r . server , host , d )
}
func ( s * Stunner ) resolver ( ) * net . Resolver {
return net . DefaultResolver
}
// cleanUpPostRun zeros out some fields, mostly for debugging (so
// things crash or race+fail if there's a sender still running.)
func ( s * Stunner ) cleanUpPostRun ( ) {
s . mu . Lock ( )
s . inFlight = nil
s . mu . Unlock ( )
}
// Run starts a Stunner and blocks until all servers either respond
// or are tried multiple times and timeout.
//
// TODO: this always returns success now. It should return errors
// if certain servers are unavailable probably. Or if all are.
// Or some configured threshold are.
// It can not be called concurrently with itself.
func ( s * Stunner ) Run ( ctx context . Context ) error {
s . sessions = map [ string ] * session { }
for _ , server := range s . Servers {
if _ , _ , err := net . SplitHostPort ( server ) ; err != nil {
return fmt . Errorf ( "Stunner.Run: invalid server %q (in Server list %q)" , server , s . Servers )
}
sctx , cancel := context . WithCancel ( ctx )
s . sessions [ server ] = & session {
ctx : sctx ,
cancel : cancel ,
}
if len ( s . Servers ) == 0 {
return errors . New ( "stunner: no Servers" )
}
// after this point, the s.sessions map is read-only
var wg sync . WaitGroup
for _ , server := range s . Servers {
wg . Add ( 1 )
go func ( server string ) {
defer wg . Done ( )
s . runServer ( ctx , server )
} ( server )
}
wg . Wait ( )
return nil
}
s . inFlight = make ( map [ stun . TxID ] request )
defer s . cleanUpPostRun ( )
func ( s * Stunner ) runServer ( ctx context . Context , server string ) {
session := s . sessions [ server ]
ctx , cancel := context . WithCancel ( ctx )
defer cancel ( )
// If we're using a DNS cache, prime the cache before doing
// any quick timeouts (100ms, etc) so the timeout doesn't
// apply to the first DNS lookup.
if s . DNSCache != nil {
_ , _ = s . DNSCache . LookupIP ( ctx , server )
type sender struct {
ctx context . Context
cancel context . CancelFunc
}
var (
needMu sync . Mutex
need = make ( map [ string ] sender ) // keyed by server; deleted when done
allDone = make ( chan struct { } ) // closed when need is empty
)
s . onPacket = func ( server , endpoint string , d time . Duration ) {
needMu . Lock ( )
defer needMu . Unlock ( )
sender , ok := need [ server ]
if ! ok {
return
}
sender . cancel ( )
delete ( need , server )
s . Endpoint ( server , endpoint , d )
if len ( need ) == 0 {
close ( allDone )
}
for i , d := range retryDurations {
ctx , cancel := context . WithTimeout ( ctx , d )
err := s . sendSTUN ( ctx , server )
if err != nil {
s . logf ( "stunner: sendSTUN(%q): %v" , server , err )
}
var wg sync . WaitGroup
for _ , server := range s . Servers {
ctx , cancel := context . WithCancel ( ctx )
defer cancel ( )
need [ server ] = sender { ctx , cancel }
}
for server , sender := range need {
wg . Add ( 1 )
server , ctx := server , sender . ctx
go func ( ) {
defer wg . Done ( )
s . sendPackets ( ctx , server )
} ( )
}
var err error
select {
case <- ctx . Done ( ) :
err = ctx . Err ( )
case <- allDone :
cancel ( )
case <- session . ctx . Done ( ) :
cancel ( )
if i > 0 {
s . logf ( "stunner: slow STUN response from %s: %d retries" , server , i )
}
return
wg . Wait ( )
var missing [ ] string
needMu . Lock ( )
for server := range need {
missing = append ( missing , server )
}
needMu . Unlock ( )
if len ( missing ) == 0 || err == nil {
return nil
}
s . logf ( "stunner: no STUN response from %s" , server )
return fmt . Errorf ( "got STUN error: %v; missing replies from: %v" , err , strings . Join ( missing , ", " ) )
}
func ( s * Stunner ) sendSTUN ( ctx context . Context , server string ) error {
host , port , err := net . SplitHostPort ( server )
func ( s * Stunner ) se rverAddr ( ctx context . Context , server string ) ( * net . UDPAddr , error ) {
host Str , port Str , err := net . SplitHostPort ( server )
if err != nil {
return err
return nil , err
}
addrPort , err := strconv . Atoi ( port )
addrPort , err := strconv . Atoi ( port Str )
if err != nil {
return fmt . Errorf ( "port: %v" , err )
return nil , fmt . Errorf ( "port: %v" , err )
}
if addrPort == 0 {
addrPort = 3478
@ -202,17 +225,18 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
var ipAddrs [ ] net . IPAddr
if s . DNSCache != nil {
ip , err := s . DNSCache . LookupIP ( ctx , host )
ip , err := s . DNSCache . LookupIP ( ctx , host Str )
if err != nil {
return fmt . Errorf ( "lookup ip addr from cache (%q): %v" , host , err )
return nil , err
}
ipAddrs = [ ] net . IPAddr { { IP : ip } }
} else {
ipAddrs , err = s . resolver ( ) . LookupIPAddr ( ctx , host )
ipAddrs , err = s . resolver ( ) . LookupIPAddr ( ctx , host Str )
if err != nil {
return fmt . Errorf ( "lookup ip addr (%q): %v" , host , err )
return nil , fmt . Errorf ( "lookup ip addr (%q): %v" , host Str , err )
}
}
for _ , ipAddr := range ipAddrs {
ip4 := ipAddr . IP . To4 ( )
if ip4 != nil {
@ -228,11 +252,21 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
}
if addr . IP == nil {
if s . OnlyIPv6 {
return fmt . Errorf ( "cannot resolve any ipv6 addresses for %s, got: %v" , server , ipAddrs )
return nil , fmt . Errorf ( "cannot resolve any ipv6 addresses for %s, got: %v" , server , ipAddrs )
}
return fmt . Errorf ( "cannot resolve any ipv4 addresses for %s, got: %v" , server , ipAddrs )
return nil , fmt . Errorf ( "cannot resolve any ipv4 addresses for %s, got: %v" , server , ipAddrs )
}
return addr , nil
}
func ( s * Stunner ) sendPackets ( ctx context . Context , server string ) error {
addr , err := s . serverAddr ( ctx , server )
if err != nil {
return err
}
const maxSend = 2
for i := 0 ; i < maxSend ; i ++ {
txID := stun . NewTxID ( )
req := stun . Request ( txID )
s . addTX ( txID , server )
@ -240,17 +274,14 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
if err != nil {
return fmt . Errorf ( "send: %v" , err )
}
return nil
}
var retryDurations = [ ] time . Duration {
100 * time . Millisecond ,
100 * time . Millisecond ,
100 * time . Millisecond ,
200 * time . Millisecond ,
200 * time . Millisecond ,
400 * time . Millisecond ,
800 * time . Millisecond ,
1600 * time . Millisecond ,
3200 * time . Millisecond ,
select {
case <- ctx . Done ( ) :
// Ignore error. The caller deals with handling contexts.
// We only use it to dermine when to stop spraying STUN packets.
return nil
case <- time . After ( time . Millisecond * time . Duration ( 50 + rand . Intn ( 200 ) ) ) :
}
}
return nil
}