diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index fb63de133..0fa8a5f0e 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -7,8 +7,8 @@ package ipnserver import ( "bufio" "context" + "errors" "fmt" - "html" "io" "log" "net" @@ -82,29 +82,112 @@ type Options struct { // talking to an IPN backend. type server struct { resetOnZero bool // call bs.Reset on transition from 1->0 connections + b *ipn.LocalBackend bsMu sync.Mutex // lock order: bsMu, then mu bs *ipn.BackendServer - mu sync.Mutex - clients map[net.Conn]bool + mu sync.Mutex + allClients map[net.Conn]connIdentity // HTTP or IPN + clients map[net.Conn]bool // subset of allClients; only IPN protocol } -func (s *server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) { - br := bufio.NewReader(c) +// connIdentity represents the owner of a localhost TCP connection. +type connIdentity struct { + Unknown bool + Pid int + UserID string + User *user.User +} + +// getConnIdentity returns the localhost TCP connection's identity information +// (pid, userid, user). If it's not Windows (for now), it returns a nil error +// and a ConnIdentity with Unknown set true. It's only an error if we expected +// to be able to map it and couldn't. +func getConnIdentity(c net.Conn) (ci connIdentity, err error) { + if runtime.GOOS != "windows" { // for now; TODO: expand to other OSes + return connIdentity{Unknown: true}, nil + } + la, err := netaddr.ParseIPPort(c.LocalAddr().String()) + if err != nil { + return ci, fmt.Errorf("parsing local address: %w", err) + } + ra, err := netaddr.ParseIPPort(c.RemoteAddr().String()) + if err != nil { + return ci, fmt.Errorf("parsing local remote: %w", err) + } + if !la.IP.IsLoopback() || !ra.IP.IsLoopback() { + return ci, errors.New("non-loopback connection") + } + tab, err := netstat.Get() + if err != nil { + return ci, fmt.Errorf("failed to get local connection table: %w", err) + } + pid := peerPid(tab.Entries, la, ra) + if pid == 0 { + return ci, errors.New("no local process found matching localhost connection") + } + ci.Pid = pid + uid, err := pidowner.OwnerOfPID(pid) + if err != nil { + var hint string + if runtime.GOOS == "windows" { + hint = " (WSL?)" + } + return ci, fmt.Errorf("failed to map connection's pid to a user%s: %w", hint, err) + } + ci.UserID = uid + u, err := user.LookupId(uid) + if err != nil { + return ci, fmt.Errorf("failed to look up user from userid: %w", err) + } + ci.User = u + return ci, nil +} +func (s *server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) { // First see if it's an HTTP request. + br := bufio.NewReader(c) c.SetReadDeadline(time.Now().Add(time.Second)) peek, _ := br.Peek(4) c.SetReadDeadline(time.Time{}) - if string(peek) == "GET " { - http.Serve(&oneConnListener{altReaderNetConn{br, c}}, localhostHandler(c)) + isHTTPReq := string(peek) == "GET " + + ci, err := s.addConn(c, isHTTPReq) + if err != nil { + if isHTTPReq { + fmt.Fprintf(c, "HTTP/1.0 500 Nope\r\nContent-Type: text/plain\r\nX-Content-Type-Options: nosniff\r\n\r\n%s\n", err.Error()) + c.Close() + return + } + defer c.Close() + serverToClient := func(b []byte) { ipn.WriteMsg(c, b) } + bs := ipn.NewBackendServer(logf, nil, serverToClient) + bs.SendErrorMessage(err.Error()) + time.Sleep(time.Second) + return + } + + if isHTTPReq { + httpServer := http.Server{ + // Localhost connections are cheap; so only do + // keep-alives for a short period of time, as these + // active connections lock the server into only serving + // that user. If the user has this page open, we don't + // want another switching user to be locked out for + // minutes. 5 seconds is enough to let browser hit + // favicon.ico and such. + IdleTimeout: 5 * time.Second, + ErrorLog: logger.StdLogger(logf), + Handler: s.localhostHandler(ci), + } + httpServer.Serve(&oneConnListener{&protoSwitchConn{s: s, br: br, Conn: c}}) return } - s.addConn(c) - logf("incoming control connection") defer s.removeAndCloseConn(c) + logf("incoming control connection") + for ctx.Err() == nil { msg, err := ipn.ReadMsg(br) if err != nil { @@ -125,19 +208,48 @@ func (s *server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) { } } -func (s *server) addConn(c net.Conn) { +func (s *server) addConn(c net.Conn, isHTTP bool) (ci connIdentity, err error) { + ci, err = getConnIdentity(c) + if err != nil { + return + } + s.mu.Lock() defer s.mu.Unlock() + if s.clients == nil { s.clients = map[net.Conn]bool{} } - s.clients[c] = true + if s.allClients == nil { + s.allClients = map[net.Conn]connIdentity{} + } + + // If clients are already connected, verify they're the same user. + // This mostly matters on Windows at the moment. + if len(s.allClients) > 0 { + var active connIdentity + for _, active = range s.allClients { + break + } + if ci.UserID != active.UserID { + //lint:ignore ST1005 we want to capitalize Tailscale here + return ci, fmt.Errorf("Tailscale already in use by %s, pid %d", active.User.Username, active.Pid) + } + } + + if !isHTTP { + s.clients[c] = true + } + s.allClients[c] = ci + + return ci, nil } func (s *server) removeAndCloseConn(c net.Conn) { s.mu.Lock() delete(s.clients, c) - remain := len(s.clients) + delete(s.allClients, c) + remain := len(s.allClients) s.mu.Unlock() if remain == 0 && s.resetOnZero { @@ -250,13 +362,11 @@ func Run(ctx context.Context, logf logger.Logf, logid string, getEngine func() ( if opts.DebugMux != nil { opts.DebugMux.HandleFunc("/debug/ipn", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - st := b.Status() - // TODO(bradfitz): add LogID and opts to st? - st.WriteHTML(w) + serveHTMLStatus(w, b) }) } + server.b = b server.bs = ipn.NewBackendServer(logf, b, server.writeToClients) if opts.AutostartStateKey != "" { @@ -436,54 +546,40 @@ func (l *oneConnListener) Addr() net.Addr { return dummyAddr("unused-address") } func (a dummyAddr) Network() string { return string(a) } func (a dummyAddr) String() string { return string(a) } -type altReaderNetConn struct { - r io.Reader +// protoSwitchConn is a net.Conn that's we want to speak HTTP to but +// it's already had a few bytes read from it to determine that it's +// HTTP. So we Read from its bufio.Reader. On Close, we we tell the +// server it's closed, so the server can account the who's connected. +type protoSwitchConn struct { + s *server net.Conn + br *bufio.Reader + closeOnce sync.Once } -func (a altReaderNetConn) Read(p []byte) (int, error) { return a.r.Read(p) } +func (psc *protoSwitchConn) Read(p []byte) (int, error) { return psc.br.Read(p) } +func (psc *protoSwitchConn) Close() error { + psc.closeOnce.Do(func() { psc.s.removeAndCloseConn(psc.Conn) }) + return nil +} -func localhostHandler(c net.Conn) http.Handler { - la, lerr := netaddr.ParseIPPort(c.LocalAddr().String()) - ra, rerr := netaddr.ParseIPPort(c.RemoteAddr().String()) +func (s *server) localhostHandler(ci connIdentity) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "