From 828aa6dcb0711450f165235bd69055d24738f316 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 4 May 2020 22:08:08 -0700 Subject: [PATCH] stunner: add Stunner.MaxTries option --- stunner/stunner.go | 20 +++++++++++++++++--- stunner/stunner_test.go | 4 ++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/stunner/stunner.go b/stunner/stunner.go index f407bf400..a99ad3c49 100644 --- a/stunner/stunner.go +++ b/stunner/stunner.go @@ -58,6 +58,12 @@ type Stunner struct { // If false, only IPv4 is used. There is currently no mixed mode. OnlyIPv6 bool + // MaxTries optionally provides a mapping from server name to the maximum + // number of tries that should be made for a given server. + // If nil or a server is not present in the map, the default is 1. + // Values less than 1 are ignored. + MaxTries map[string]int + mu sync.Mutex inFlight map[stun.TxID]request } @@ -268,14 +274,22 @@ func (s *Stunner) serverAddr(ctx context.Context, server string) (*net.UDPAddr, return addr, nil } +// maxTriesForServer returns the maximum number of STUN queries that +// will be sent to server (for one call to Run). The default is 1. +func (s *Stunner) maxTriesForServer(server string) int { + if v, ok := s.MaxTries[server]; ok && v > 0 { + return v + } + return 1 +} + func (s *Stunner) sendPackets(ctx context.Context, server string) error { addr, err := s.serverAddr(ctx, server) if err != nil { return err } - - const maxSend = 2 - for i := 0; i < maxSend; i++ { + maxTries := s.maxTriesForServer(server) + for i := 0; i < maxTries; i++ { txID := stun.NewTxID() req := stun.Request(txID) s.addTX(txID, server) diff --git a/stunner/stunner_test.go b/stunner/stunner_test.go index 9f8e4d5a7..a3555f1be 100644 --- a/stunner/stunner_test.go +++ b/stunner/stunner_test.go @@ -42,6 +42,10 @@ func TestStun(t *testing.T) { Send: localConn.WriteTo, Endpoint: func(server, ep string, d time.Duration) { epCh <- ep }, Servers: stunServers, + MaxTries: map[string]int{ + stunServers[0]: 2, + stunServers[1]: 2, + }, } stun1Err := make(chan error)