ipn/ipnlocal, net/tsdial: make SOCKS/HTTP dials use ExitDNS

And simplify, unexport some tsdial/netstack stuff in the the process.

Fixes #3475

Change-Id: I186a5a5cbd8958e25c075b4676f7f6e70f3ff76e
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/3497/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent 9f6249b26d
commit 9c5c9d0a50

@ -333,7 +333,7 @@ func run() error {
return ok return ok
} }
dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) { dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) {
return ns.DialContextTCP(ctx, dst.String()) return ns.DialContextTCP(ctx, dst)
} }
} }

@ -1898,6 +1898,15 @@ func (b *LocalBackend) authReconfig() {
} }
} }
// Keep the dialer updated about whether we're supposed to use
// an exit node's DNS server (so SOCKS5/HTTP outgoing dials
// can use it for name resolution)
if dohURL, ok := exitNodeCanProxyDNS(nm, prefs.ExitNodeID); ok {
b.dialer.SetExitDNSDoH(dohURL)
} else {
b.dialer.SetExitDNSDoH("")
}
cfg, err := nmcfg.WGCfg(nm, b.logf, flags, prefs.ExitNodeID) cfg, err := nmcfg.WGCfg(nm, b.logf, flags, prefs.ExitNodeID)
if err != nil { if err != nil {
b.logf("wgcfg: %v", err) b.logf("wgcfg: %v", err)

@ -0,0 +1,86 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tsdial
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"time"
)
// dohConn is a net.PacketConn suitable for returning from
// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes'
// ExitDNS DoH proxy service.
type dohConn struct {
ctx context.Context
baseURL string
hc *http.Client // if nil, default is used
rbuf bytes.Buffer
}
var (
_ net.Conn = (*dohConn)(nil)
_ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics
)
func (*dohConn) Close() error { return nil }
func (*dohConn) LocalAddr() net.Addr { return todoAddr{} }
func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} }
func (*dohConn) SetDeadline(t time.Time) error { return nil }
func (*dohConn) SetReadDeadline(t time.Time) error { return nil }
func (*dohConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return c.Write(p)
}
func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(p)
return n, todoAddr{}, err
}
func (c *dohConn) Read(p []byte) (n int, err error) {
return c.rbuf.Read(p)
}
func (c *dohConn) Write(packet []byte) (n int, err error) {
req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet))
if err != nil {
return 0, err
}
const dohType = "application/dns-message"
req.Header.Set("Content-Type", dohType)
hc := c.hc
if hc == nil {
hc = http.DefaultClient
}
hres, err := hc.Do(req)
if err != nil {
return 0, err
}
defer hres.Body.Close()
if hres.StatusCode != 200 {
return 0, errors.New(hres.Status)
}
if ct := hres.Header.Get("Content-Type"); ct != dohType {
return 0, fmt.Errorf("unexpected response Content-Type %q", ct)
}
_, err = io.Copy(&c.rbuf, hres.Body)
if err != nil {
return 0, err
}
return len(packet), nil
}
type todoAddr struct{}
func (todoAddr) Network() string { return "unused" }
func (todoAddr) String() string { return "unused-todoAddr" }

@ -0,0 +1,32 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tsdial
import (
"context"
"flag"
"net"
"testing"
"time"
)
var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"")
func TestDoHResolve(t *testing.T) {
if *dohBase == "" {
t.Skip("skipping manual test without --doh-base= set")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
var r net.Resolver
r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
return &dohConn{ctx: ctx, baseURL: *dohBase}, nil
}
addrs, err := r.LookupIP(ctx, "ip4", "google.com.")
if err != nil {
t.Fatal(err)
}
t.Logf("Got: %q", addrs)
}

@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
@ -43,10 +44,11 @@ type Dialer struct {
peerDialerOnce sync.Once peerDialerOnce sync.Once
peerDialer *net.Dialer peerDialer *net.Dialer
mu sync.Mutex mu sync.Mutex
dns dnsMap dns dnsMap
tunName string // tun device name tunName string // tun device name
linkMon *monitor.Mon linkMon *monitor.Mon
exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?')
} }
// SetTUNName sets the name of the tun device in use ("tailscale0", "utun6", // SetTUNName sets the name of the tun device in use ("tailscale0", "utun6",
@ -66,6 +68,17 @@ func (d *Dialer) TUNName() string {
return d.tunName return d.tunName
} }
// SetExitDNSDoH sets (or clears) the exit node DNS DoH server base URL to use.
// The doh URL should contain the scheme, authority, and path, but without
// a '?' and/or query parameters.
//
// For example, "http://100.68.82.120:47830/dns-query".
func (d *Dialer) SetExitDNSDoH(doh string) {
d.mu.Lock()
defer d.mu.Unlock()
d.exitDNSDoHBase = doh
}
func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) { func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
@ -113,21 +126,20 @@ func (d *Dialer) SetNetMap(nm *netmap.NetworkMap) {
d.dns = m d.dns = m
} }
func (d *Dialer) Resolve(ctx context.Context, network, addr string) (netaddr.IPPort, error) { func (d *Dialer) userDialResolve(ctx context.Context, network, addr string) (netaddr.IPPort, error) {
d.mu.Lock() d.mu.Lock()
dns := d.dns dns := d.dns
exitDNSDoH := d.exitDNSDoHBase
d.mu.Unlock() d.mu.Unlock()
// MagicDNS or otherwise baked in to the NetworkMap? Try that first. // MagicDNS or otherwise baked in to the NetworkMap? Try that first.
ipp, err := dns.resolveMemory(ctx, network, addr) ipp, err := dns.resolveMemory(ctx, network, addr)
if err != errUnresolved { if err != errUnresolved {
return ipp, err return ipp, err
} }
// Otherwise, hit the network. // Otherwise, hit the network.
// TODO(bradfitz): use ExitDNS (Issue 3475)
// TODO(bradfitz): wire up net/dnscache too. // TODO(bradfitz): wire up net/dnscache too.
host, port, err := splitHostPort(addr) host, port, err := splitHostPort(addr)
@ -137,7 +149,17 @@ func (d *Dialer) Resolve(ctx context.Context, network, addr string) (netaddr.IPP
} }
var r net.Resolver var r net.Resolver
ips, err := r.LookupIP(ctx, network, host) if exitDNSDoH != "" {
r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
return &dohConn{
ctx: ctx,
baseURL: exitDNSDoH,
hc: d.PeerAPIHTTPClient(),
}, nil
}
}
ips, err := r.LookupIP(ctx, ipNetOfNetwork(network), host)
if err != nil { if err != nil {
return netaddr.IPPort{}, err return netaddr.IPPort{}, err
} }
@ -148,10 +170,23 @@ func (d *Dialer) Resolve(ctx context.Context, network, addr string) (netaddr.IPP
return netaddr.IPPortFrom(ip, port), nil return netaddr.IPPortFrom(ip, port), nil
} }
// ipNetOfNetwork returns "ip", "ip4", or "ip6" corresponding
// to the input value of "tcp", "tcp4", "udp6" etc network
// names.
func ipNetOfNetwork(n string) string {
if strings.HasSuffix(n, "4") {
return "ip4"
}
if strings.HasSuffix(n, "6") {
return "ip6"
}
return "ip"
}
// UserDial connects to the provided network address as if a user were initiating the dial. // UserDial connects to the provided network address as if a user were initiating the dial.
// (e.g. from a SOCKS or HTTP outbound proxy) // (e.g. from a SOCKS or HTTP outbound proxy)
func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, error) { func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, error) {
ipp, err := d.Resolve(ctx, network, addr) ipp, err := d.userDialResolve(ctx, network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -146,7 +146,7 @@ func (s *Server) start() error {
return ok return ok
} }
dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) { dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) {
return ns.DialContextTCP(ctx, dst.String()) return ns.DialContextTCP(ctx, dst)
} }
statePath := filepath.Join(s.dir, "tailscaled.state") statePath := filepath.Join(s.dir, "tailscaled.state")

@ -295,18 +295,14 @@ func (ns *Impl) updateIPs(nm *netmap.NetworkMap) {
} }
} }
func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) { func (ns *Impl) DialContextTCP(ctx context.Context, ipp netaddr.IPPort) (*gonet.TCPConn, error) {
remoteIPPort, err := ns.dialer.Resolve(ctx, "tcp", addr)
if err != nil {
return nil, err
}
remoteAddress := tcpip.FullAddress{ remoteAddress := tcpip.FullAddress{
NIC: nicID, NIC: nicID,
Addr: tcpip.Address(remoteIPPort.IP().IPAddr().IP), Addr: tcpip.Address(ipp.IP().IPAddr().IP),
Port: remoteIPPort.Port(), Port: ipp.Port(),
} }
var ipType tcpip.NetworkProtocolNumber var ipType tcpip.NetworkProtocolNumber
if remoteIPPort.IP().Is4() { if ipp.IP().Is4() {
ipType = ipv4.ProtocolNumber ipType = ipv4.ProtocolNumber
} else { } else {
ipType = ipv6.ProtocolNumber ipType = ipv6.ProtocolNumber
@ -315,18 +311,14 @@ func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn
return gonet.DialContextTCP(ctx, ns.ipstack, remoteAddress, ipType) return gonet.DialContextTCP(ctx, ns.ipstack, remoteAddress, ipType)
} }
func (ns *Impl) DialContextUDP(ctx context.Context, addr string) (*gonet.UDPConn, error) { func (ns *Impl) DialContextUDP(ctx context.Context, ipp netaddr.IPPort) (*gonet.UDPConn, error) {
remoteIPPort, err := ns.dialer.Resolve(ctx, "udp", addr)
if err != nil {
return nil, err
}
remoteAddress := &tcpip.FullAddress{ remoteAddress := &tcpip.FullAddress{
NIC: nicID, NIC: nicID,
Addr: tcpip.Address(remoteIPPort.IP().IPAddr().IP), Addr: tcpip.Address(ipp.IP().IPAddr().IP),
Port: remoteIPPort.Port(), Port: ipp.Port(),
} }
var ipType tcpip.NetworkProtocolNumber var ipType tcpip.NetworkProtocolNumber
if remoteIPPort.IP().Is4() { if ipp.IP().Is4() {
ipType = ipv4.ProtocolNumber ipType = ipv4.ProtocolNumber
} else { } else {
ipType = ipv6.ProtocolNumber ipType = ipv6.ProtocolNumber

Loading…
Cancel
Save