mirror of https://github.com/tailscale/tailscale/
cmd/tsidp: add start of OIDC Tailscale IdP
Updates #10263 Co-authored-by: Brad Fitzpatrick <bradfitz@tailscale.com> Change-Id: I240bc9b5ecf2df6f92c45929d105fde66c06a860 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>pull/10272/head
parent
2dbd546766
commit
7d4221c295
@ -0,0 +1,703 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// The tsidp command is an OpenID Connect Identity Provider server.
|
||||||
|
//
|
||||||
|
// See https://github.com/tailscale/tailscale/issues/10263 for background.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
"tailscale.com/client/tailscale"
|
||||||
|
"tailscale.com/client/tailscale/apitype"
|
||||||
|
"tailscale.com/envknob"
|
||||||
|
"tailscale.com/ipn/ipnstate"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/tsnet"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
"tailscale.com/types/lazy"
|
||||||
|
"tailscale.com/types/logger"
|
||||||
|
"tailscale.com/types/views"
|
||||||
|
"tailscale.com/util/mak"
|
||||||
|
"tailscale.com/util/must"
|
||||||
|
"tailscale.com/util/rands"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
flagVerbose = flag.Bool("verbose", false, "be verbose")
|
||||||
|
flagPort = flag.Int("port", 443, "port to listen on")
|
||||||
|
flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost")
|
||||||
|
flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
ctx := context.Background()
|
||||||
|
if !envknob.UseWIPCode() {
|
||||||
|
log.Fatal("cmd/tsidp is a work in progress and has not been security reviewed;\nits use requires TAILSCALE_USE_WIP_CODE=1 be set in the environment for now.")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
lc *tailscale.LocalClient
|
||||||
|
st *ipnstate.Status
|
||||||
|
err error
|
||||||
|
|
||||||
|
lns []net.Listener
|
||||||
|
)
|
||||||
|
if *flagUseLocalTailscaled {
|
||||||
|
lc = &tailscale.LocalClient{}
|
||||||
|
st, err = lc.StatusWithoutPeers(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("getting status: %v", err)
|
||||||
|
}
|
||||||
|
portStr := fmt.Sprint(*flagPort)
|
||||||
|
anySuccess := false
|
||||||
|
for _, ip := range st.TailscaleIPs {
|
||||||
|
ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), portStr))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to listen on %v: %v", ip, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
anySuccess = true
|
||||||
|
ln = tls.NewListener(ln, &tls.Config{
|
||||||
|
GetCertificate: lc.GetCertificate,
|
||||||
|
})
|
||||||
|
lns = append(lns, ln)
|
||||||
|
}
|
||||||
|
if !anySuccess {
|
||||||
|
log.Fatalf("failed to listen on any of %v", st.TailscaleIPs)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ts := &tsnet.Server{
|
||||||
|
Hostname: "idp",
|
||||||
|
}
|
||||||
|
if !*flagVerbose {
|
||||||
|
ts.Logf = logger.Discard
|
||||||
|
}
|
||||||
|
st, err = ts.Up(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
lc, err = ts.LocalClient()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("getting local client: %v", err)
|
||||||
|
}
|
||||||
|
ln, err := ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
lns = append(lns, ln)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := &idpServer{
|
||||||
|
lc: lc,
|
||||||
|
}
|
||||||
|
if *flagPort != 443 {
|
||||||
|
srv.serverURL = fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort)
|
||||||
|
} else {
|
||||||
|
srv.serverURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, "."))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Running tsidp at %s ...", srv.serverURL)
|
||||||
|
|
||||||
|
if *flagLocalPort != -1 {
|
||||||
|
log.Printf("Also running tsidp at %s ...", srv.loopbackURL)
|
||||||
|
srv.loopbackURL = fmt.Sprintf("http://localhost:%d", *flagLocalPort)
|
||||||
|
ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
lns = append(lns, ln)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ln := range lns {
|
||||||
|
go http.Serve(ln, srv)
|
||||||
|
}
|
||||||
|
select {}
|
||||||
|
}
|
||||||
|
|
||||||
|
type idpServer struct {
|
||||||
|
lc *tailscale.LocalClient
|
||||||
|
loopbackURL string
|
||||||
|
serverURL string // "https://foo.bar.ts.net"
|
||||||
|
|
||||||
|
lazyMux lazy.SyncValue[*http.ServeMux]
|
||||||
|
lazySigningKey lazy.SyncValue[*signingKey]
|
||||||
|
lazySigner lazy.SyncValue[jose.Signer]
|
||||||
|
|
||||||
|
mu sync.Mutex // guards the fields below
|
||||||
|
code map[string]*authRequest // keyed by random hex
|
||||||
|
accessToken map[string]*authRequest // keyed by random hex
|
||||||
|
}
|
||||||
|
|
||||||
|
type authRequest struct {
|
||||||
|
// localRP is true if the request is from a relying party running on the
|
||||||
|
// same machine as the idp server. It is mutually exclusive with rpNodeID.
|
||||||
|
localRP bool
|
||||||
|
|
||||||
|
// rpNodeID is the NodeID of the relying party (who requested the auth, such
|
||||||
|
// as Proxmox or Synology), not the user node who is being authenticated. It
|
||||||
|
// is mutually exclusive with localRP.
|
||||||
|
rpNodeID tailcfg.NodeID
|
||||||
|
|
||||||
|
// clientID is the "client_id" sent in the authorized request.
|
||||||
|
clientID string
|
||||||
|
|
||||||
|
// nonce presented in the request.
|
||||||
|
nonce string
|
||||||
|
|
||||||
|
// redirectURI is the redirect_uri presented in the request.
|
||||||
|
redirectURI string
|
||||||
|
|
||||||
|
// remoteUser is the user who is being authenticated.
|
||||||
|
remoteUser *apitype.WhoIsResponse
|
||||||
|
|
||||||
|
// validTill is the time until which the token is valid.
|
||||||
|
// As of 2023-11-14, it is 5 minutes.
|
||||||
|
// TODO: add routine to delete expired tokens.
|
||||||
|
validTill time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ar *authRequest) allowRelyingParty(ctx context.Context, remoteAddr string, lc *tailscale.LocalClient) error {
|
||||||
|
if ar.localRP {
|
||||||
|
ra, err := netip.ParseAddrPort(remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ra.Addr().IsLoopback() {
|
||||||
|
return fmt.Errorf("tsidp: request from non-loopback address")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
who, err := lc.WhoIs(ctx, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("tsidp: error getting WhoIs: %w", err)
|
||||||
|
}
|
||||||
|
if ar.rpNodeID != who.Node.ID {
|
||||||
|
return fmt.Errorf("tsidp: token for different node")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) {
|
||||||
|
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error getting WhoIs: %v", err)
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
uq := r.URL.Query()
|
||||||
|
|
||||||
|
code := rands.HexString(32)
|
||||||
|
ar := &authRequest{
|
||||||
|
nonce: uq.Get("nonce"),
|
||||||
|
remoteUser: who,
|
||||||
|
redirectURI: uq.Get("redirect_uri"),
|
||||||
|
clientID: uq.Get("client_id"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/authorize/localhost" {
|
||||||
|
ar.localRP = true
|
||||||
|
} else {
|
||||||
|
var ok bool
|
||||||
|
ar.rpNodeID, ok = parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/"))
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
mak.Set(&s.code, code, ar)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
q := make(url.Values)
|
||||||
|
q.Set("code", code)
|
||||||
|
q.Set("state", uq.Get("state"))
|
||||||
|
u := uq.Get("redirect_uri") + "?" + q.Encode()
|
||||||
|
log.Printf("Redirecting to %q", u)
|
||||||
|
|
||||||
|
http.Redirect(w, r, u, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *idpServer) newMux() *http.ServeMux {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc(oidcJWKSPath, s.serveJWKS)
|
||||||
|
mux.HandleFunc(oidcConfigPath, s.serveOpenIDConfig)
|
||||||
|
mux.HandleFunc("/authorize/", s.authorize)
|
||||||
|
mux.HandleFunc("/userinfo", s.serveUserInfo)
|
||||||
|
mux.HandleFunc("/token", s.serveToken)
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/" {
|
||||||
|
io.WriteString(w, "<html><body><h1>Tailscale OIDC IdP</h1>")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Error(w, "tsidp: not found", http.StatusNotFound)
|
||||||
|
})
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Printf("%v %v", r.Method, r.URL)
|
||||||
|
s.lazyMux.Get(s.newMux).ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != "GET" {
|
||||||
|
http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tk, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
ar, ok := s.accessToken[tk]
|
||||||
|
s.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "tsidp: invalid token", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil {
|
||||||
|
log.Printf("Error allowing relying party: %v", err)
|
||||||
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ar.validTill.Before(time.Now()) {
|
||||||
|
http.Error(w, "tsidp: token expired", http.StatusBadRequest)
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.accessToken, tk)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
ui := userInfo{}
|
||||||
|
if ar.remoteUser.Node.IsTagged() {
|
||||||
|
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ui.Sub = ar.remoteUser.Node.User.String()
|
||||||
|
ui.Name = ar.remoteUser.UserProfile.DisplayName
|
||||||
|
ui.Email = ar.remoteUser.UserProfile.LoginName
|
||||||
|
ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL
|
||||||
|
|
||||||
|
// TODO(maisem): not sure if this is the right thing to do
|
||||||
|
ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(ui); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type userInfo struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Picture string `json:"picture"`
|
||||||
|
UserName string `json:"username"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.FormValue("grant_type") != "authorization_code" {
|
||||||
|
http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := r.FormValue("code")
|
||||||
|
if code == "" {
|
||||||
|
http.Error(w, "tsidp: code is required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
ar, ok := s.code[code]
|
||||||
|
if ok {
|
||||||
|
delete(s.code, code)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "tsidp: code not found", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil {
|
||||||
|
log.Printf("Error allowing relying party: %v", err)
|
||||||
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ar.redirectURI != r.FormValue("redirect_uri") {
|
||||||
|
http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
signer, err := s.oidcSigner()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error getting signer: %v", err)
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jti := rands.HexString(32)
|
||||||
|
who := ar.remoteUser
|
||||||
|
|
||||||
|
// TODO(maisem): not sure if this is the right thing to do
|
||||||
|
userName, _, _ := strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
|
||||||
|
n := who.Node.View()
|
||||||
|
if n.IsTagged() {
|
||||||
|
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
_, tcd, _ := strings.Cut(n.Name(), ".")
|
||||||
|
tsClaims := tailscaleClaims{
|
||||||
|
Claims: jwt.Claims{
|
||||||
|
Audience: jwt.Audience{ar.clientID},
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
ID: jti,
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Issuer: s.serverURL,
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
Subject: n.User().String(),
|
||||||
|
},
|
||||||
|
Nonce: ar.nonce,
|
||||||
|
Key: n.Key(),
|
||||||
|
Addresses: n.Addresses(),
|
||||||
|
NodeID: n.ID(),
|
||||||
|
NodeName: n.Name(),
|
||||||
|
Tailnet: tcd,
|
||||||
|
UserID: n.User(),
|
||||||
|
Email: who.UserProfile.LoginName,
|
||||||
|
UserName: userName,
|
||||||
|
}
|
||||||
|
if ar.localRP {
|
||||||
|
tsClaims.Issuer = s.loopbackURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an OIDC token using this issuer's signer.
|
||||||
|
token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error getting token: %v", err)
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
at := rands.HexString(32)
|
||||||
|
s.mu.Lock()
|
||||||
|
ar.validTill = now.Add(5 * time.Minute)
|
||||||
|
mak.Set(&s.accessToken, at, ar)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(oidcTokenResponse{
|
||||||
|
AccessToken: at,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 5 * 60,
|
||||||
|
IDToken: token,
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type oidcTokenResponse struct {
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
oidcJWKSPath = "/.well-known/jwks.json"
|
||||||
|
oidcConfigPath = "/.well-known/openid-configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *idpServer) oidcSigner() (jose.Signer, error) {
|
||||||
|
return s.lazySigner.GetErr(func() (jose.Signer, error) {
|
||||||
|
sk, err := s.oidcPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.RS256,
|
||||||
|
Key: sk.k,
|
||||||
|
}, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]any{
|
||||||
|
jose.HeaderType: "JWT",
|
||||||
|
"kid": fmt.Sprint(sk.kid),
|
||||||
|
}})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *idpServer) oidcPrivateKey() (*signingKey, error) {
|
||||||
|
return s.lazySigningKey.GetErr(func() (*signingKey, error) {
|
||||||
|
var sk signingKey
|
||||||
|
b, err := os.ReadFile("oidc-key.json")
|
||||||
|
if err == nil {
|
||||||
|
if err := sk.UnmarshalJSON(b); err == nil {
|
||||||
|
return &sk, nil
|
||||||
|
} else {
|
||||||
|
log.Printf("Error unmarshaling key: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
id, k := mustGenRSAKey(2048)
|
||||||
|
sk.k = k
|
||||||
|
sk.kid = id
|
||||||
|
b, err = sk.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error marshaling key: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile("oidc-key.json", b, 0600); err != nil {
|
||||||
|
log.Fatalf("Error writing key: %v", err)
|
||||||
|
}
|
||||||
|
return &sk, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != oidcJWKSPath {
|
||||||
|
http.Error(w, "tsidp: not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
sk, err := s.oidcPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// TODO(maisem): maybe only marshal this once and reuse?
|
||||||
|
// TODO(maisem): implement key rotation.
|
||||||
|
je := json.NewEncoder(w)
|
||||||
|
je.SetIndent("", " ")
|
||||||
|
if err := je.Encode(jose.JSONWebKeySet{
|
||||||
|
Keys: []jose.JSONWebKey{
|
||||||
|
{
|
||||||
|
Key: sk.k.Public(),
|
||||||
|
Algorithm: string(jose.RS256),
|
||||||
|
Use: "sig",
|
||||||
|
KeyID: fmt.Sprint(sk.kid),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// openIDProviderMetadata is a partial representation of
|
||||||
|
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata.
|
||||||
|
type openIDProviderMetadata struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
|
||||||
|
TokenEndpoint string `json:"token_endpoint,omitempty"`
|
||||||
|
UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
||||||
|
JWKS_URI string `json:"jwks_uri"`
|
||||||
|
ScopesSupported views.Slice[string] `json:"scopes_supported"`
|
||||||
|
ResponseTypesSupported views.Slice[string] `json:"response_types_supported"`
|
||||||
|
SubjectTypesSupported views.Slice[string] `json:"subject_types_supported"`
|
||||||
|
ClaimsSupported views.Slice[string] `json:"claims_supported"`
|
||||||
|
IDTokenSigningAlgValuesSupported views.Slice[string] `json:"id_token_signing_alg_values_supported"`
|
||||||
|
// TODO(maisem): maybe add other fields?
|
||||||
|
// Currently we fill out the REQUIRED fields, scopes_supported and claims_supported.
|
||||||
|
}
|
||||||
|
|
||||||
|
type tailscaleClaims struct {
|
||||||
|
jwt.Claims `json:",inline"`
|
||||||
|
Nonce string `json:"nonce,omitempty"` // the nonce from the request
|
||||||
|
Key key.NodePublic `json:"key"` // the node public key
|
||||||
|
Addresses views.Slice[netip.Prefix] `json:"addresses"` // the Tailscale IPs of the node
|
||||||
|
NodeID tailcfg.NodeID `json:"nid"` // the stable node ID
|
||||||
|
NodeName string `json:"node"` // name of the node
|
||||||
|
Tailnet string `json:"tailnet"` // tailnet (like tail-scale.ts.net)
|
||||||
|
|
||||||
|
// Email is the "emailish" value with an '@' sign. It might not be a valid email.
|
||||||
|
Email string `json:"email,omitempty"` // user emailish (like "alice@github" or "bob@example.com")
|
||||||
|
UserID tailcfg.UserID `json:"uid,omitempty"`
|
||||||
|
|
||||||
|
// UserName is the local part of Email (without '@' and domain).
|
||||||
|
// It is a temporary (2023-11-15) hack during development.
|
||||||
|
// We should probably let this be configured via grants.
|
||||||
|
UserName string `json:"username,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
openIDSupportedClaims = views.SliceOf([]string{
|
||||||
|
// Standard claims, these correspond to fields in jwt.Claims.
|
||||||
|
"sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email",
|
||||||
|
|
||||||
|
// Tailscale claims, these correspond to fields in tailscaleClaims.
|
||||||
|
"key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid",
|
||||||
|
})
|
||||||
|
|
||||||
|
// As defined in the OpenID spec this should be "openid".
|
||||||
|
openIDSupportedScopes = views.SliceOf([]string{"openid", "email", "profile"})
|
||||||
|
|
||||||
|
// We only support getting the id_token.
|
||||||
|
openIDSupportedReponseTypes = views.SliceOf([]string{"id_token", "code"})
|
||||||
|
|
||||||
|
// The type of the "sub" field in the JWT, which means it is globally unique identifier.
|
||||||
|
// The other option is "pairwise", which means the identifier is different per receiving 3p.
|
||||||
|
openIDSupportedSubjectTypes = views.SliceOf([]string{"public"})
|
||||||
|
|
||||||
|
// The algo used for signing. The OpenID spec says "The algorithm RS256 MUST be included."
|
||||||
|
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||||
|
openIDSupportedSigningAlgos = views.SliceOf([]string{string(jose.RS256)})
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != oidcConfigPath {
|
||||||
|
http.Error(w, "tsidp: not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ap, err := netip.ParseAddrPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error parsing remote addr: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var authorizeEndpoint string
|
||||||
|
rpEndpoint := s.serverURL
|
||||||
|
if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil {
|
||||||
|
authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID)
|
||||||
|
} else if ap.Addr().IsLoopback() {
|
||||||
|
rpEndpoint = s.loopbackURL
|
||||||
|
authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL)
|
||||||
|
} else {
|
||||||
|
log.Printf("Error getting WhoIs: %v", err)
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
je := json.NewEncoder(w)
|
||||||
|
je.SetIndent("", " ")
|
||||||
|
if err := je.Encode(openIDProviderMetadata{
|
||||||
|
AuthorizationEndpoint: authorizeEndpoint,
|
||||||
|
Issuer: rpEndpoint,
|
||||||
|
JWKS_URI: rpEndpoint + oidcJWKSPath,
|
||||||
|
UserInfoEndpoint: rpEndpoint + "/userinfo",
|
||||||
|
TokenEndpoint: rpEndpoint + "/token",
|
||||||
|
ScopesSupported: openIDSupportedScopes,
|
||||||
|
ResponseTypesSupported: openIDSupportedReponseTypes,
|
||||||
|
SubjectTypesSupported: openIDSupportedSubjectTypes,
|
||||||
|
ClaimsSupported: openIDSupportedClaims,
|
||||||
|
IDTokenSigningAlgValuesSupported: openIDSupportedSigningAlgos,
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
minimumRSAKeySize = 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
// mustGenRSAKey generates a new RSA key with the provided number of bits. It
|
||||||
|
// panics on failure. bits must be at least minimumRSAKeySizeBytes * 8.
|
||||||
|
func mustGenRSAKey(bits int) (kid uint64, k *rsa.PrivateKey) {
|
||||||
|
if bits < minimumRSAKeySize {
|
||||||
|
panic("request to generate a too-small RSA key")
|
||||||
|
}
|
||||||
|
kid = must.Get(readUint64(crand.Reader))
|
||||||
|
k = must.Get(rsa.GenerateKey(crand.Reader, bits))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// readUint64 reads from r until 8 bytes represent a non-zero uint64.
|
||||||
|
func readUint64(r io.Reader) (uint64, error) {
|
||||||
|
for {
|
||||||
|
var b [8]byte
|
||||||
|
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if v := binary.BigEndian.Uint64(b[:]); v != 0 {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rsaPrivateKeyJSONWrapper is the the JSON serialization
|
||||||
|
// format used by RSAPrivateKey.
|
||||||
|
type rsaPrivateKeyJSONWrapper struct {
|
||||||
|
Key string
|
||||||
|
ID uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type signingKey struct {
|
||||||
|
k *rsa.PrivateKey
|
||||||
|
kid uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sk *signingKey) MarshalJSON() ([]byte, error) {
|
||||||
|
b := pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(sk.k),
|
||||||
|
}
|
||||||
|
bts := pem.EncodeToMemory(&b)
|
||||||
|
return json.Marshal(rsaPrivateKeyJSONWrapper{
|
||||||
|
Key: base64.URLEncoding.EncodeToString(bts),
|
||||||
|
ID: sk.kid,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sk *signingKey) UnmarshalJSON(b []byte) error {
|
||||||
|
var wrapper rsaPrivateKeyJSONWrapper
|
||||||
|
if err := json.Unmarshal(b, &wrapper); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(wrapper.Key) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
b64dec, err := base64.URLEncoding.DecodeString(wrapper.Key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
blk, _ := pem.Decode(b64dec)
|
||||||
|
k, err := x509.ParsePKCS1PrivateKey(blk.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sk.k = k
|
||||||
|
sk.kid = wrapper.ID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseID takes a string input and returns a typed IntID T and true, or a zero
|
||||||
|
// value and false if the input is unhandled syntax or out of a valid range.
|
||||||
|
func parseID[T ~int64](input string) (_ T, ok bool) {
|
||||||
|
if input == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
i, err := strconv.ParseInt(input, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if i < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return T(i), true
|
||||||
|
}
|
Loading…
Reference in New Issue