diff --git a/net/netns/netns.go b/net/netns/netns.go index fcb130a64..d9490bf8d 100644 --- a/net/netns/netns.go +++ b/net/netns/netns.go @@ -17,6 +17,8 @@ package netns import ( "context" "net" + + "inet.af/netaddr" ) // Listener returns a new net.Listener with its Control hook func @@ -66,3 +68,19 @@ type Dialer interface { Dial(network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error) } + +func isLocalhost(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + // error means the string didn't contain a port number, so use the string directly + host = addr + } + + // localhost6 == RedHat /etc/hosts for ::1, ip6-loopback & ip6-localhost == Debian /etc/hosts for ::1 + if host == "localhost" || host == "localhost6" || host == "ip6-loopback" || host == "ip6-localhost" { + return true + } + + ip, _ := netaddr.ParseIP(host) + return ip.IsLoopback() +} diff --git a/net/netns/netns_linux.go b/net/netns/netns_linux.go index 7292cc376..e3d2c03da 100644 --- a/net/netns/netns_linux.go +++ b/net/netns/netns_linux.go @@ -64,6 +64,14 @@ func ignoreErrors() bool { // It's intentionally the same signature as net.Dialer.Control // and net.ListenConfig.Control. func control(network, address string, c syscall.RawConn) error { + if hostinfo.GetEnvType() == hostinfo.TestCase { + return nil + } + if IsLocalhost(address) { + // Don't bind to an interface for localhost connections. + return nil + } + var sockErr error err := c.Control(func(fd uintptr) { if ipRuleAvailable() { diff --git a/net/netns/netns_test.go b/net/netns/netns_test.go index 0e3eb963f..caaf78079 100644 --- a/net/netns/netns_test.go +++ b/net/netns/netns_test.go @@ -40,3 +40,40 @@ func TestDial(t *testing.T) { defer c.Close() t.Logf("got addr %v", c.RemoteAddr()) } + +func TestIsLocalhost(t *testing.T) { + tests := []struct { + name string + host string + want bool + }{ + {"IPv4 loopback", "127.0.0.1", true}, + {"IPv4 !loopback", "192.168.0.1", false}, + {"IPv4 loopback with port", "127.0.0.1:1", true}, + {"IPv4 !loopback with port", "192.168.0.1:1", false}, + {"IPv4 unspecified", "0.0.0.0", false}, + {"IPv4 unspecified with port", "0.0.0.0:1", false}, + {"IPv6 loopback", "::1", true}, + {"IPv6 !loopback", "2001:4860:4860::8888", false}, + {"IPv6 loopback with port", "[::1]:1", true}, + {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false}, + {"IPv6 unspecified", "::", false}, + {"IPv6 unspecified with port", "[::]:1", false}, + {"empty", "", false}, + {"hostname", "example.com", false}, + {"localhost", "localhost", true}, + {"localhost6", "localhost6", true}, + {"localhost with port", "localhost:1", true}, + {"localhost6 with port", "localhost6:1", true}, + {"ip6-localhost", "ip6-localhost", true}, + {"ip6-localhost with port", "ip6-localhost:1", true}, + {"ip6-loopback", "ip6-loopback", true}, + {"ip6-loopback with port", "ip6-loopback:1", true}, + } + + for _, test := range tests { + if got := isLocalhost(test.host); got != test.want { + t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want) + } + } +}