diff --git a/cmd/sniproxy/snipproxy.go b/cmd/sniproxy/snipproxy.go index d465c690a..8e7b4f85d 100644 --- a/cmd/sniproxy/snipproxy.go +++ b/cmd/sniproxy/snipproxy.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "golang.org/x/net/dns/dnsmessage" "inet.af/tcpproxy" "tailscale.com/client/tailscale" "tailscale.com/net/netutil" @@ -23,6 +24,8 @@ import ( var ports = flag.String("ports", "443", "comma-separated list of ports to proxy") +var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") + func main() { flag.Parse() if *ports == "" { @@ -86,8 +89,29 @@ func (s *server) serveDNSConn(c nettype.ConnPacketConn) { c.SetReadDeadline(time.Now().Add(5 * time.Second)) buf := make([]byte, 1500) n, err := c.Read(buf) - log.Printf("got DNS packet: %q, %v", buf[:n], err) - // TODO: rest of the owl + if err != nil { + log.Printf("c.Read failed: %v\n ", err) + return + } + + var msg dnsmessage.Message + err = msg.Unpack(buf[:n]) + if err != nil { + log.Printf("dnsmessage unpack failed: %v\n ", err) + return + } + + buf, err = s.dnsResponse(&msg) + if err != nil { + log.Printf("s.dnsResponse failed: %v\n", err) + return + } + + _, err = c.Write(buf) + if err != nil { + log.Printf("c.Write failed: %v\n", err) + return + } } func (s *server) serveConn(c net.Conn) { @@ -107,7 +131,6 @@ func (s *server) serveConn(c net.Conn) { return netutil.NewOneConnListener(c, nil), nil } p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) { - log.Printf("got req for %q from %v", sniName, c.RemoteAddr()) return &tcpproxy.DialProxy{ Addr: net.JoinHostPort(sniName, port), DialContext: dialer.DialContext, @@ -115,3 +138,62 @@ func (s *server) serveConn(c net.Conn) { }) p.Start() } + +func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) { + resp := dnsmessage.NewBuilder(buf, + dnsmessage.Header{ + ID: req.Header.ID, + Response: true, + Authoritative: true, + }) + resp.EnableCompression() + + if len(req.Questions) == 0 { + buf, _ = resp.Finish() + return + } + + q := req.Questions[0] + err = resp.StartQuestions() + if err != nil { + return + } + resp.Question(q) + + ip4, ip6 := s.ts.TailscaleIPs() + err = resp.StartAnswers() + if err != nil { + return + } + + switch q.Type { + case dnsmessage.TypeAAAA: + err = resp.AAAAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: ip6.As16()}, + ) + + case dnsmessage.TypeA: + err = resp.AResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: ip4.As4()}, + ) + case dnsmessage.TypeSOA: + err = resp.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ) + case dnsmessage.TypeNS: + err = resp.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ) + } + + if err != nil { + return + } + + return resp.Finish() +} diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index a348285c7..e021a55fa 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -351,6 +351,24 @@ func (s *Server) doInit() { } } +// TailscaleIPs returns IPv4 and IPv6 addresses for this node. If the node +// has not yet joined a tailnet or is otherwise unaware of its own IP addresses, +// the returned ip4, ip6 will be !netip.IsValid(). +func (s *Server) TailscaleIPs() (ip4, ip6 netip.Addr) { + nm := s.lb.NetMap() + for _, addr := range nm.Addresses { + ip := addr.Addr() + if ip.Is6() { + ip6 = ip + } + if ip.Is4() { + ip4 = ip + } + } + + return ip4, ip6 +} + func (s *Server) getAuthKey() string { if v := s.AuthKey; v != "" { return v diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 488c518d4..bc5adc2e2 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -299,3 +299,43 @@ func TestLoopbackSOCKS5(t *testing.T) { t.Errorf("got %q, want %q", got, want) } } + +func TestTailscaleIPs(t *testing.T) { + controlURL := startControl(t) + + tmp := t.TempDir() + tmps1 := filepath.Join(tmp, "s1") + os.MkdirAll(tmps1, 0755) + s1 := &Server{ + Dir: tmps1, + ControlURL: controlURL, + Hostname: "s1", + Store: new(mem.Store), + Ephemeral: true, + } + defer s1.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + s1status, err := s1.Up(ctx) + if err != nil { + t.Fatal(err) + } + + var upIp4, upIp6 netip.Addr + for _, ip := range s1status.TailscaleIPs { + if ip.Is6() { + upIp6 = ip + } + if ip.Is4() { + upIp4 = ip + } + } + + sIp4, sIp6 := s1.TailscaleIPs() + if !(upIp4 == sIp4 && upIp6 == sIp6) { + t.Errorf("s1.TailscaleIPs returned a different result than S1.Up, (%s, %s) != (%s, %s)", + sIp4, upIp4, sIp6, upIp6) + } +}