diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 43a3ae94b..3e077f9fd 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -7,6 +7,7 @@ package main import ( + "bytes" "context" crand "crypto/rand" "crypto/rsa" @@ -16,6 +17,7 @@ import ( "encoding/binary" "encoding/json" "encoding/pem" + "errors" "flag" "fmt" "io" @@ -25,6 +27,7 @@ import ( "net/netip" "net/url" "os" + "os/signal" "strconv" "strings" "sync" @@ -35,6 +38,7 @@ import ( "tailscale.com/client/tailscale" "tailscale.com/client/tailscale/apitype" "tailscale.com/envknob" + "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" "tailscale.com/tsnet" @@ -44,13 +48,22 @@ import ( "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/util/rands" + "tailscale.com/version" ) +// ctxConn is a key to look up a net.Conn stored in an HTTP request's context. +type ctxConn struct{} + +// funnelClientsFile is the file where client IDs and secrets for OIDC clients +// accessing the IDP over Funnel are persisted. +const funnelClientsFile = "oidc-funnel-clients.json" + var ( flagVerbose = flag.Bool("verbose", false, "be verbose") flagPort = flag.Int("port", 443, "port to listen on") flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost") flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet") + flagFunnel = flag.Bool("funnel", false, "use Tailscale Funnel to make tsidp available on the public internet") ) func main() { @@ -61,9 +74,11 @@ func main() { } var ( - lc *tailscale.LocalClient - st *ipnstate.Status - err error + lc *tailscale.LocalClient + st *ipnstate.Status + err error + watcherChan chan error + cleanup func() lns []net.Listener ) @@ -90,6 +105,18 @@ func main() { if !anySuccess { log.Fatalf("failed to listen on any of %v", st.TailscaleIPs) } + + // tailscaled needs to be setting an HTTP header for funneled requests + // that older versions don't provide. + // TODO(naman): is this the correct check? + if *flagFunnel && !version.AtLeast(st.Version, "1.71.0") { + log.Fatalf("Local tailscaled not new enough to support -funnel. Update Tailscale or use tsnet mode.") + } + cleanup, watcherChan, err = serveOnLocalTailscaled(ctx, lc, st, uint16(*flagPort), *flagFunnel) + if err != nil { + log.Fatalf("could not serve on local tailscaled: %v", err) + } + defer cleanup() } else { ts := &tsnet.Server{ Hostname: "idp", @@ -105,7 +132,15 @@ func main() { if err != nil { log.Fatalf("getting local client: %v", err) } - ln, err := ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort)) + var ln net.Listener + if *flagFunnel { + if err := ipn.CheckFunnelAccess(uint16(*flagPort), st.Self); err != nil { + log.Fatalf("%v", err) + } + ln, err = ts.ListenFunnel("tcp", fmt.Sprintf(":%d", *flagPort)) + } else { + ln, err = ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort)) + } if err != nil { log.Fatal(err) } @@ -113,13 +148,26 @@ func main() { } srv := &idpServer{ - lc: lc, + lc: lc, + funnel: *flagFunnel, + localTSMode: *flagUseLocalTailscaled, } if *flagPort != 443 { srv.serverURL = fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort) } else { srv.serverURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, ".")) } + if *flagFunnel { + f, err := os.Open(funnelClientsFile) + if err == nil { + srv.funnelClients = make(map[string]*funnelClient) + if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil { + log.Fatalf("could not parse %s: %v", funnelClientsFile, err) + } + } else if !errors.Is(err, os.ErrNotExist) { + log.Fatalf("could not open %s: %v", funnelClientsFile, err) + } + } log.Printf("Running tsidp at %s ...", srv.serverURL) @@ -134,35 +182,129 @@ func main() { } for _, ln := range lns { - go http.Serve(ln, srv) + server := http.Server{ + Handler: srv, + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, ctxConn{}, c) + }, + } + go server.Serve(ln) + } + // need to catch os.Interrupt, otherwise deferred cleanup code doesn't run + exitChan := make(chan os.Signal, 1) + signal.Notify(exitChan, os.Interrupt) + select { + case <-exitChan: + log.Printf("interrupt, exiting") + return + case <-watcherChan: + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + log.Printf("watcher closed, exiting") + return + } + log.Fatalf("watcher error: %v", err) + return + } +} + +// serveOnLocalTailscaled starts a serve session using an already-running +// tailscaled instead of starting a fresh tsnet server, making something +// listening on clientDNSName:dstPort accessible over serve/funnel. +func serveOnLocalTailscaled(ctx context.Context, lc *tailscale.LocalClient, st *ipnstate.Status, dstPort uint16, shouldFunnel bool) (cleanup func(), watcherChan chan error, err error) { + // In order to support funneling out in local tailscaled mode, we need + // to add a serve config to forward the listeners we bound above and + // allow those forwarders to be funneled out. + sc, err := lc.GetServeConfig(ctx) + if err != nil { + return nil, nil, fmt.Errorf("could not get serve config: %v", err) + } + if sc == nil { + sc = new(ipn.ServeConfig) + } + + // We watch the IPN bus just to get a session ID. The session expires + // when we stop watching the bus, and that auto-deletes the foreground + // serve/funnel configs we are creating below. + watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) + if err != nil { + return nil, nil, fmt.Errorf("could not set up ipn bus watcher: %v", err) + } + defer func() { + if err != nil { + watcher.Close() + } + }() + n, err := watcher.Next() + if err != nil { + return nil, nil, fmt.Errorf("could not get initial state from ipn bus watcher: %v", err) + } + if n.SessionID == "" { + err = fmt.Errorf("missing sessionID in ipn.Notify") + return nil, nil, err + } + watcherChan = make(chan error) + go func() { + for { + _, err = watcher.Next() + if err != nil { + watcherChan <- err + return + } + } + }() + + // Create a foreground serve config that gets cleaned up when tsidp + // exits and the session ID associated with this config is invalidated. + foregroundSc := new(ipn.ServeConfig) + mak.Set(&sc.Foreground, n.SessionID, foregroundSc) + serverURL := strings.TrimSuffix(st.Self.DNSName, ".") + fmt.Printf("setting funnel for %s:%v\n", serverURL, dstPort) + + foregroundSc.SetFunnel(serverURL, dstPort, shouldFunnel) + foregroundSc.SetWebHandler(&ipn.HTTPHandler{ + Proxy: fmt.Sprintf("https://%s", net.JoinHostPort(serverURL, strconv.Itoa(int(dstPort)))), + }, serverURL, uint16(*flagPort), "/", true) + err = lc.SetServeConfig(ctx, sc) + if err != nil { + return nil, watcherChan, fmt.Errorf("could not set serve config: %v", err) } - select {} + + return func() { watcher.Close() }, watcherChan, nil } type idpServer struct { lc *tailscale.LocalClient loopbackURL string serverURL string // "https://foo.bar.ts.net" + funnel bool + localTSMode bool lazyMux lazy.SyncValue[*http.ServeMux] lazySigningKey lazy.SyncValue[*signingKey] lazySigner lazy.SyncValue[jose.Signer] - mu sync.Mutex // guards the fields below - code map[string]*authRequest // keyed by random hex - accessToken map[string]*authRequest // keyed by random hex + mu sync.Mutex // guards the fields below + code map[string]*authRequest // keyed by random hex + accessToken map[string]*authRequest // keyed by random hex + funnelClients map[string]*funnelClient // keyed by client ID } type authRequest struct { // localRP is true if the request is from a relying party running on the - // same machine as the idp server. It is mutually exclusive with rpNodeID. + // same machine as the idp server. It is mutually exclusive with rpNodeID + // and funnelRP. localRP bool // rpNodeID is the NodeID of the relying party (who requested the auth, such // as Proxmox or Synology), not the user node who is being authenticated. It - // is mutually exclusive with localRP. + // is mutually exclusive with localRP and funnelRP. rpNodeID tailcfg.NodeID + // funnelRP is non-nil if the request is from a relying party outside the + // tailnet, via Tailscale Funnel. It is mutually exclusive with rpNodeID + // and localRP. + funnelRP *funnelClient + // clientID is the "client_id" sent in the authorized request. clientID string @@ -181,9 +323,12 @@ type authRequest struct { validTill time.Time } -func (ar *authRequest) allowRelyingParty(ctx context.Context, remoteAddr string, lc *tailscale.LocalClient) error { +// allowRelyingParty validates that a relying party identified either by a +// known remoteAddr or a valid client ID/secret pair is allowed to proceed +// with the authorization flow associated with this authRequest. +func (ar *authRequest) allowRelyingParty(r *http.Request, lc *tailscale.LocalClient) error { if ar.localRP { - ra, err := netip.ParseAddrPort(remoteAddr) + ra, err := netip.ParseAddrPort(r.RemoteAddr) if err != nil { return err } @@ -192,7 +337,18 @@ func (ar *authRequest) allowRelyingParty(ctx context.Context, remoteAddr string, } return nil } - who, err := lc.WhoIs(ctx, remoteAddr) + if ar.funnelRP != nil { + clientID, clientSecret, ok := r.BasicAuth() + if !ok { + clientID = r.FormValue("client_id") + clientSecret = r.FormValue("client_secret") + } + if ar.funnelRP.ID != clientID || ar.funnelRP.Secret != clientSecret { + return fmt.Errorf("tsidp: invalid client credentials") + } + return nil + } + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) if err != nil { return fmt.Errorf("tsidp: error getting WhoIs: %w", err) } @@ -203,24 +359,60 @@ func (ar *authRequest) allowRelyingParty(ctx context.Context, remoteAddr string, } func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { - who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) + // This URL is visited by the user who is being authenticated. If they are + // visiting the URL over Funnel, that means they are not part of the + // tailnet that they are trying to be authenticated for. + if isFunnelRequest(r) { + http.Error(w, "tsidp: unauthorized", http.StatusUnauthorized) + return + } + + uq := r.URL.Query() + + redirectURI := uq.Get("redirect_uri") + if redirectURI == "" { + http.Error(w, "tsidp: must specify redirect_uri", http.StatusBadRequest) + return + } + + var remoteAddr string + if s.localTSMode { + // in local tailscaled mode, the local tailscaled is forwarding us + // HTTP requests, so reading r.RemoteAddr will just get us our own + // address. + remoteAddr = r.Header.Get("X-Forwarded-For") + } else { + remoteAddr = r.RemoteAddr + } + who, err := s.lc.WhoIs(r.Context(), remoteAddr) if err != nil { log.Printf("Error getting WhoIs: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - uq := r.URL.Query() - code := rands.HexString(32) ar := &authRequest{ nonce: uq.Get("nonce"), remoteUser: who, - redirectURI: uq.Get("redirect_uri"), + redirectURI: redirectURI, clientID: uq.Get("client_id"), } - if r.URL.Path == "/authorize/localhost" { + if r.URL.Path == "/authorize/funnel" { + s.mu.Lock() + c, ok := s.funnelClients[ar.clientID] + s.mu.Unlock() + if !ok { + http.Error(w, "tsidp: invalid client ID", http.StatusBadRequest) + return + } + if ar.redirectURI != c.RedirectURI { + http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) + return + } + ar.funnelRP = c + } else if r.URL.Path == "/authorize/localhost" { ar.localRP = true } else { var ok bool @@ -237,8 +429,10 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { q := make(url.Values) q.Set("code", code) - q.Set("state", uq.Get("state")) - u := uq.Get("redirect_uri") + "?" + q.Encode() + if state := uq.Get("state"); state != "" { + q.Set("state", state) + } + u := redirectURI + "?" + q.Encode() log.Printf("Redirecting to %q", u) http.Redirect(w, r, u, http.StatusFound) @@ -251,6 +445,7 @@ func (s *idpServer) newMux() *http.ServeMux { mux.HandleFunc("/authorize/", s.authorize) mux.HandleFunc("/userinfo", s.serveUserInfo) mux.HandleFunc("/token", s.serveToken) + mux.HandleFunc("/clients/", s.serveClients) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" { io.WriteString(w, "