wgengine, magicsock, derp: misc cleanups, docs

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/75/head
Brad Fitzpatrick 5 years ago
parent a23a0d9c9f
commit e06ca40650

@ -13,6 +13,7 @@ package derphttp
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -31,7 +32,8 @@ import (
// //
// It automatically reconnects on error retry. That is, a failed Send or // It automatically reconnects on error retry. That is, a failed Send or
// Recv will report the error and not retry, but subsequent calls to // Recv will report the error and not retry, but subsequent calls to
// Send/Recv will completely re-establish the connection. // Send/Recv will completely re-establish the connection (unless Close
// has been called).
type Client struct { type Client struct {
privateKey key.Private privateKey key.Private
logf logger.Logf logf logger.Logf
@ -46,6 +48,8 @@ type Client struct {
client *derp.Client client *derp.Client
} }
// NewClient returns a new DERP-over-HTTP client. It connects lazily.
// To trigger a connection use Connect.
func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Client, error) { func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Client, error) {
u, err := url.Parse(serverURL) u, err := url.Parse(serverURL)
if err != nil { if err != nil {
@ -58,13 +62,18 @@ func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Cli
url: u, url: u,
closed: make(chan struct{}), closed: make(chan struct{}),
} }
if _, err := c.connect("derphttp.NewClient"); err != nil {
c.logf("%v", err)
}
return c, nil return c, nil
} }
func (c *Client) connect(caller string) (client *derp.Client, err error) { // Connect connects or reconnects to the server, unless already connected.
// It returns nil if there was already a good connection, or if one was made.
func (c *Client) Connect(ctx context.Context) error {
_, err := c.connect(ctx, "derphttp.Client.Connect")
return err
}
func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
// TODO: use ctx for TCP+TLS+HTTP below
select { select {
case <-c.closed: case <-c.closed:
return nil, ErrClientClosed return nil, ErrClientClosed
@ -84,7 +93,7 @@ func (c *Client) connect(caller string) (client *derp.Client, err error) {
defer func() { defer func() {
if err != nil { if err != nil {
err = fmt.Errorf("%s connect: %v", caller, err) err = fmt.Errorf("%s connect: %v", caller, err)
if netConn := netConn; netConn != nil { if netConn != nil {
netConn.Close() netConn.Close()
} }
} }
@ -148,7 +157,7 @@ func (c *Client) connect(caller string) (client *derp.Client, err error) {
} }
func (c *Client) Send(dstKey key.Public, b []byte) error { func (c *Client) Send(dstKey key.Public, b []byte) error {
client, err := c.connect("derphttp.Client.Send") client, err := c.connect(context.TODO(), "derphttp.Client.Send")
if err != nil { if err != nil {
return err return err
} }
@ -159,7 +168,7 @@ func (c *Client) Send(dstKey key.Public, b []byte) error {
} }
func (c *Client) Recv(b []byte) (int, error) { func (c *Client) Recv(b []byte) (int, error) {
client, err := c.connect("derphttp.Client.Recv") client, err := c.connect(context.TODO(), "derphttp.Client.Recv")
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -170,6 +179,8 @@ func (c *Client) Recv(b []byte) (int, error) {
return n, err return n, err
} }
// Close closes the client. It will not automatically reconnect after
// being closed.
func (c *Client) Close() error { func (c *Client) Close() error {
select { select {
case <-c.closed: case <-c.closed:

@ -14,6 +14,7 @@ import (
"net" "net"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"syscall" "syscall"
"time" "time"
@ -51,8 +52,9 @@ type Conn struct {
indexedAddrsMu sync.Mutex indexedAddrsMu sync.Mutex
indexedAddrs map[udpAddr]indexedAddrSet indexedAddrs map[udpAddr]indexedAddrSet
stunReceiveMu sync.Mutex // stunReceiveFunc holds the current STUN packet processing func.
stunReceive func(p []byte, fromAddr *net.UDPAddr) // Its Loaded value is always non-nil.
stunReceiveFunc atomic.Value // of func(p []byte, fromAddr *net.UDPAddr)
derpMu sync.Mutex derpMu sync.Mutex
derp *derphttp.Client derp *derphttp.Client
@ -140,12 +142,21 @@ func Listen(opts Options) (*Conn, error) {
logf: log.Printf, logf: log.Printf,
indexedAddrs: make(map[udpAddr]indexedAddrSet), indexedAddrs: make(map[udpAddr]indexedAddrSet),
} }
c.ignoreSTUNPackets()
c.pconn.Reset(packetConn.(*net.UDPConn)) c.pconn.Reset(packetConn.(*net.UDPConn))
c.startEpUpdate <- struct{}{} // STUN immediately on start c.startEpUpdate <- struct{}{} // STUN immediately on start
go c.epUpdate(epUpdateCtx) go c.epUpdate(epUpdateCtx)
return c, nil return c, nil
} }
// ignoreSTUNPackets sets a STUN packet processing func that does nothing.
func (c *Conn) ignoreSTUNPackets() {
c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {})
}
// epUpdate runs in its own goroutine until ctx is shut down.
// Whenever c.startEpUpdate receives a value, it starts an
// STUN endpoint lookup.
func (c *Conn) epUpdate(ctx context.Context) { func (c *Conn) epUpdate(ctx context.Context) {
var lastEndpoints []string var lastEndpoints []string
var lastCancel func() var lastCancel func()
@ -186,18 +197,22 @@ func (c *Conn) epUpdate(ctx context.Context) {
} }
} }
// determineEndpoints returns the machine's endpoint addresses. It
// does a STUN lookup to determine its public address.
func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) { func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) {
var alreadyMu sync.Mutex var (
already := make(map[string]struct{}) alreadyMu sync.Mutex
var eps []string already = make(map[string]bool) // endpoint -> true
)
var eps []string // unique endpoints
addAddr := func(s, reason string) { addAddr := func(s, reason string) {
log.Printf("magicsock: found local %s (%s)\n", s, reason) log.Printf("magicsock: found local %s (%s)\n", s, reason)
alreadyMu.Lock() alreadyMu.Lock()
defer alreadyMu.Unlock() defer alreadyMu.Unlock()
if _, ok := already[s]; !ok { if !already[s] {
already[s] = struct{}{} already[s] = true
eps = append(eps, s) eps = append(eps, s)
} }
} }
@ -209,17 +224,13 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) {
Logf: c.logf, Logf: c.logf,
} }
c.stunReceiveMu.Lock() c.stunReceiveFunc.Store(s.Receive)
c.stunReceive = s.Receive
c.stunReceiveMu.Unlock()
if err := s.Run(ctx); err != nil { if err := s.Run(ctx); err != nil {
return nil, err return nil, err
} }
c.stunReceiveMu.Lock() c.ignoreSTUNPackets()
c.stunReceive = nil
c.stunReceiveMu.Unlock()
if localAddr := c.pconn.LocalAddr(); localAddr.IP.IsUnspecified() { if localAddr := c.pconn.LocalAddr(); localAddr.IP.IsUnspecified() {
localPort := fmt.Sprintf("%d", localAddr.Port) localPort := fmt.Sprintf("%d", localAddr.Port)
@ -421,16 +432,9 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep device.Endpoint, addr *net.UDPAd
if !stun.Is(b[:n]) { if !stun.Is(b[:n]) {
break break
} }
c.stunReceiveMu.Lock() c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b, addr)
fn := c.stunReceive
c.stunReceiveMu.Unlock()
if fn != nil {
fn(b, addr)
}
} }
// TODO(crawshaw): remove all the indexed-addr logic
addrSet, _ := c.findIndexedAddrSet(addr) addrSet, _ := c.findIndexedAddrSet(addr)
if addrSet == nil { if addrSet == nil {
// The peer that sent this packet has roamed beyond the // The peer that sent this packet has roamed beyond the
@ -457,14 +461,14 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error {
return err return err
} }
go func() { go func() {
var b [1 << 16]byte var b [64 << 10]byte
for { for {
n, err := derp.Recv(b[:]) n, err := derp.Recv(b[:])
if err != nil { if err != nil {
if err == derphttp.ErrClientClosed { if err == derphttp.ErrClientClosed {
return return
} }
log.Printf("%v", err) log.Printf("derp.Recv: %v", err)
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
} }
@ -696,16 +700,19 @@ func (a *AddrSet) Addrs() []wgcfg.Endpoint {
return eps return eps
} }
func (c *Conn) CreateEndpoint(key [32]byte, s string) (device.Endpoint, error) { // CreateEndpoint is called by WireGuard to connect to an endpoint.
// The key is the public key of the peer and addrs is a
// comma-separated list of UDP ip:ports.
func (c *Conn) CreateEndpoint(key [32]byte, addrs string) (device.Endpoint, error) {
pk := wgcfg.Key(key) pk := wgcfg.Key(key)
log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), s) log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), addrs)
a := &AddrSet{ a := &AddrSet{
publicKey: key, publicKey: key,
curAddr: -1, curAddr: -1,
} }
if s != "" { if addrs != "" {
for _, ep := range strings.Split(s, ",") { for _, ep := range strings.Split(addrs, ",") {
addr, err := net.ResolveUDPAddr("udp", ep) addr, err := net.ResolveUDPAddr("udp", ep)
if err != nil { if err != nil {
return nil, err return nil, err

@ -36,7 +36,7 @@ func TestListen(t *testing.T) {
defer conn.Close() defer conn.Close()
go func() { go func() {
var pkt [1 << 16]byte var pkt [64 << 10]byte
for { for {
_, _, _, err := conn.ReceiveIPv4(pkt[:]) _, _, _, err := conn.ReceiveIPv4(pkt[:])
if err != nil { if err != nil {

@ -107,18 +107,14 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev tun.Device, routerGen R
endpointsFn := func(endpoints []string) { endpointsFn := func(endpoints []string) {
e.mu.Lock() e.mu.Lock()
if e.endpoints != nil { e.endpoints = append(e.endpoints[:0], endpoints...)
e.endpoints = e.endpoints[:0]
}
e.endpoints = append(e.endpoints, endpoints...)
e.mu.Unlock() e.mu.Unlock()
e.RequestStatus() e.RequestStatus()
} }
magicsockOpts := magicsock.Options{ magicsockOpts := magicsock.Options{
Port: listenPort, Port: listenPort,
STUN: magicsock.DefaultSTUN, STUN: magicsock.DefaultSTUN,
// TODO(crawshaw): DERP: magicsock.DefaultDERP,
EndpointsFunc: endpointsFn, EndpointsFunc: endpointsFn,
} }
e.magicConn, err = magicsock.Listen(magicsockOpts) e.magicConn, err = magicsock.Listen(magicsockOpts)

Loading…
Cancel
Save