client/web: move auth session creation out of /api/auth

Splits auth session creation into two new endpoints:

/api/auth/session/new - to request a new auth session

/api/auth/session/wait - to block until user has completed auth url

Updates tailscale/corp#14335

Signed-off-by: Sonia Appasamy <sonia@tailscale.com>
pull/10037/head
Sonia Appasamy 1 year ago committed by Sonia Appasamy
parent 658971d7c0
commit e5dcf7bdde

@ -3,19 +3,19 @@ import React from "react"
import LegacyClientView from "src/components/views/legacy-client-view" import LegacyClientView from "src/components/views/legacy-client-view"
import LoginClientView from "src/components/views/login-client-view" import LoginClientView from "src/components/views/login-client-view"
import ReadonlyClientView from "src/components/views/readonly-client-view" import ReadonlyClientView from "src/components/views/readonly-client-view"
import useAuth, { AuthResponse } from "src/hooks/auth" import useAuth, { AuthResponse, SessionsCallbacks } from "src/hooks/auth"
import useNodeData from "src/hooks/node-data" import useNodeData from "src/hooks/node-data"
import ManagementClientView from "./views/management-client-view" import ManagementClientView from "./views/management-client-view"
export default function App() { export default function App() {
const { data: auth, loading: loadingAuth, waitOnAuth } = useAuth() const { data: auth, loading: loadingAuth, sessions } = useAuth()
return ( return (
<div className="flex flex-col items-center min-w-sm max-w-lg mx-auto py-14"> <div className="flex flex-col items-center min-w-sm max-w-lg mx-auto py-14">
{loadingAuth ? ( {loadingAuth ? (
<div className="text-center py-14">Loading...</div> // TODO(sonia): add a loading view <div className="text-center py-14">Loading...</div> // TODO(sonia): add a loading view
) : ( ) : (
<WebClient auth={auth} waitOnAuth={waitOnAuth} /> <WebClient auth={auth} sessions={sessions} />
)} )}
</div> </div>
) )
@ -23,10 +23,10 @@ export default function App() {
function WebClient({ function WebClient({
auth, auth,
waitOnAuth, sessions,
}: { }: {
auth?: AuthResponse auth?: AuthResponse
waitOnAuth: () => Promise<void> sessions: SessionsCallbacks
}) { }) {
const { data, refreshData, updateNode } = useNodeData() const { data, refreshData, updateNode } = useNodeData()
@ -45,7 +45,7 @@ function WebClient({
<ManagementClientView {...data} /> <ManagementClientView {...data} />
) : data.DebugMode === "login" || data.DebugMode === "full" ? ( ) : data.DebugMode === "login" || data.DebugMode === "full" ? (
// Render new client interface in readonly mode. // Render new client interface in readonly mode.
<ReadonlyClientView data={data} auth={auth} waitOnAuth={waitOnAuth} /> <ReadonlyClientView data={data} auth={auth} sessions={sessions} />
) : ( ) : (
// Render legacy client interface. // Render legacy client interface.
<LegacyClientView <LegacyClientView

@ -1,5 +1,5 @@
import React from "react" import React from "react"
import { AuthResponse } from "src/hooks/auth" import { AuthResponse, AuthType, SessionsCallbacks } from "src/hooks/auth"
import { NodeData } from "src/hooks/node-data" import { NodeData } from "src/hooks/node-data"
import { ReactComponent as ConnectedDeviceIcon } from "src/icons/connected-device.svg" import { ReactComponent as ConnectedDeviceIcon } from "src/icons/connected-device.svg"
import { ReactComponent as TailscaleLogo } from "src/icons/tailscale-logo.svg" import { ReactComponent as TailscaleLogo } from "src/icons/tailscale-logo.svg"
@ -17,11 +17,11 @@ import ProfilePic from "src/ui/profile-pic"
export default function ReadonlyClientView({ export default function ReadonlyClientView({
data, data,
auth, auth,
waitOnAuth, sessions,
}: { }: {
data: NodeData data: NodeData
auth?: AuthResponse auth?: AuthResponse
waitOnAuth: () => Promise<void> sessions: SessionsCallbacks
}) { }) {
return ( return (
<> <>
@ -51,12 +51,14 @@ export default function ReadonlyClientView({
<div className="text-sm leading-tight">{data.IP}</div> <div className="text-sm leading-tight">{data.IP}</div>
</div> </div>
</div> </div>
{data.DebugMode === "full" && ( {auth?.authNeeded == AuthType.tailscale && (
<button <button
className="button button-blue ml-6" className="button button-blue ml-6"
onClick={() => { onClick={() => {
window.open(auth?.authUrl, "_blank") sessions
waitOnAuth() .new()
.then((url) => window.open(url, "_blank"))
.then(() => sessions.wait())
}} }}
> >
Access Access

@ -8,20 +8,23 @@ export enum AuthType {
export type AuthResponse = { export type AuthResponse = {
ok: boolean ok: boolean
authUrl?: string
authNeeded?: AuthType authNeeded?: AuthType
} }
export type SessionsCallbacks = {
new: () => Promise<string> // creates new auth session and returns authURL
wait: () => Promise<void> // blocks until auth is completed
}
// useAuth reports and refreshes Tailscale auth status // useAuth reports and refreshes Tailscale auth status
// for the web client. // for the web client.
export default function useAuth() { export default function useAuth() {
const [data, setData] = useState<AuthResponse>() const [data, setData] = useState<AuthResponse>()
const [loading, setLoading] = useState<boolean>(true) const [loading, setLoading] = useState<boolean>(true)
const loadAuth = useCallback((wait?: boolean) => { const loadAuth = useCallback(() => {
const url = wait ? "/auth?wait=true" : "/auth"
setLoading(true) setLoading(true)
return apiFetch(url, "GET") return apiFetch("/auth", "GET")
.then((r) => r.json()) .then((r) => r.json())
.then((d) => { .then((d) => {
setData(d) setData(d)
@ -44,11 +47,33 @@ export default function useAuth() {
}) })
}, []) }, [])
const newSession = useCallback(() => {
return apiFetch("/auth/session/new", "GET")
.then((r) => r.json())
.then((d) => d.authUrl)
.catch((error) => {
console.error(error)
})
}, [])
const waitForSessionCompletion = useCallback(() => {
return apiFetch("/auth/session/wait", "GET")
.then(() => loadAuth()) // refresh auth data
.catch((error) => {
console.error(error)
})
}, [])
useEffect(() => { useEffect(() => {
loadAuth() loadAuth()
}, []) }, [])
const waitOnAuth = useCallback(() => loadAuth(true), []) return {
data,
return { data, loading, waitOnAuth } loading,
sessions: {
new: newSession,
wait: waitForSessionCompletion,
},
}
} }

@ -203,8 +203,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (s *Server) serve(w http.ResponseWriter, r *http.Request) { func (s *Server) serve(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/api/") { if strings.HasPrefix(r.URL.Path, "/api/") {
if r.Method == httpm.GET && r.URL.Path == "/api/auth" { switch {
s.serveAPIAuth(w, r) case r.URL.Path == "/api/auth" && r.Method == httpm.GET:
s.serveAPIAuth(w, r) // serve auth status
return
case r.URL.Path == "/api/auth/session/new" && r.Method == httpm.GET:
s.serveAPIAuthSessionNew(w, r) // create new session
return
case r.URL.Path == "/api/auth/session/wait" && r.Method == httpm.GET:
s.serveAPIAuthSessionWait(w, r) // wait for session to be authorized
return return
} }
if ok := s.authorizeRequest(w, r); !ok { if ok := s.authorizeRequest(w, r); !ok {
@ -295,20 +302,15 @@ var (
type authResponse struct { type authResponse struct {
OK bool `json:"ok"` // true when user has valid auth session OK bool `json:"ok"` // true when user has valid auth session
AuthURL string `json:"authUrl,omitempty"` // filled when user has control auth action to take
AuthNeeded authType `json:"authNeeded,omitempty"` // filled when user needs to complete a specific type of auth AuthNeeded authType `json:"authNeeded,omitempty"` // filled when user needs to complete a specific type of auth
} }
// serverAPIAuth handles requests to the /api/auth endpoint // serverAPIAuth handles requests to the /api/auth endpoint
// and returns an authResponse indicating the current auth state and any steps the user needs to take. // and returns an authResponse indicating the current auth state and any steps the user needs to take.
func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) { func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) {
if r.Method != httpm.GET {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var resp authResponse var resp authResponse
session, whois, err := s.getSession(r) session, _, err := s.getSession(r)
switch { switch {
case err != nil && errors.Is(err, errNotUsingTailscale): case err != nil && errors.Is(err, errNotUsingTailscale):
// not using tailscale, so perform platform auth // not using tailscale, so perform platform auth
@ -328,12 +330,43 @@ func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) {
default: default:
resp.OK = true // no additional auth for this distro resp.OK = true // no additional auth for this distro
} }
case err != nil && (errors.Is(err, errNotOwner) ||
errors.Is(err, errNotUsingTailscale) ||
errors.Is(err, errTaggedLocalSource) ||
errors.Is(err, errTaggedRemoteSource)):
// These cases are all restricted to the readonly view.
// No auth action to take.
resp = authResponse{OK: false}
case err != nil && !errors.Is(err, errNoSession): case err != nil && !errors.Is(err, errNoSession):
// Any other error.
http.Error(w, err.Error(), http.StatusInternalServerError)
return
case session.isAuthorized(s.timeNow()):
resp = authResponse{OK: true}
default:
resp = authResponse{OK: false, AuthNeeded: tailscaleAuth}
}
writeJSON(w, resp)
}
type newSessionAuthResponse struct {
AuthURL string `json:"authUrl,omitempty"`
}
// serveAPIAuthSessionNew handles requests to the /api/auth/session/new endpoint.
func (s *Server) serveAPIAuthSessionNew(w http.ResponseWriter, r *http.Request) {
session, whois, err := s.getSession(r)
if err != nil && !errors.Is(err, errNoSession) {
// Source associated with request not allowed to create
// a session for this web client.
http.Error(w, err.Error(), http.StatusUnauthorized) http.Error(w, err.Error(), http.StatusUnauthorized)
return return
case session == nil: }
if session == nil {
// Create a new session. // Create a new session.
session, err := s.newSession(r.Context(), whois) // If one already existed, we return that authURL rather than creating a new one.
session, err = s.newSession(r.Context(), whois)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -346,29 +379,25 @@ func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) {
Path: "/", Path: "/",
Expires: session.expires(), Expires: session.expires(),
}) })
resp = authResponse{OK: false, AuthURL: session.AuthURL}
case !session.isAuthorized(s.timeNow()):
if r.URL.Query().Get("wait") == "true" {
// Client requested we block until user completes auth.
if err := s.awaitUserAuth(r.Context(), session); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
}
if session.isAuthorized(s.timeNow()) {
resp = authResponse{OK: true}
} else {
resp = authResponse{OK: false, AuthURL: session.AuthURL}
}
default:
resp = authResponse{OK: true}
} }
if err := json.NewEncoder(w).Encode(resp); err != nil { writeJSON(w, newSessionAuthResponse{AuthURL: session.AuthURL})
http.Error(w, err.Error(), http.StatusInternalServerError) }
// serveAPIAuthSessionWait handles requests to the /api/auth/session/wait endpoint.
func (s *Server) serveAPIAuthSessionWait(w http.ResponseWriter, r *http.Request) {
session, _, err := s.getSession(r)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
if session.isAuthorized(s.timeNow()) {
return // already authorized
}
if err := s.awaitUserAuth(r.Context(), session); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return return
} }
w.Header().Set("Content-Type", "application/json")
} }
// serveAPI serves requests for the web client api. // serveAPI serves requests for the web client api.
@ -458,11 +487,7 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) {
if len(st.TailscaleIPs) != 0 { if len(st.TailscaleIPs) != 0 {
data.IP = st.TailscaleIPs[0].String() data.IP = st.TailscaleIPs[0].String()
} }
if err := json.NewEncoder(w).Encode(*data); err != nil { writeJSON(w, *data)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
} }
type nodeUpdate struct { type nodeUpdate struct {
@ -720,3 +745,12 @@ func enforcePrefix(prefix string, h http.HandlerFunc) http.HandlerFunc {
http.StripPrefix(prefix, h).ServeHTTP(w, r) http.StripPrefix(prefix, h).ServeHTTP(w, r)
} }
} }
func writeJSON(w http.ResponseWriter, data any) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(data); err != nil {
w.Header().Set("Content-Type", "text/plain")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

@ -408,7 +408,7 @@ func TestAuthorizeRequest(t *testing.T) {
} }
} }
func TestServeTailscaleAuth(t *testing.T) { func TestServeAuth(t *testing.T) {
user := &tailcfg.UserProfile{ID: tailcfg.UserID(1)} user := &tailcfg.UserProfile{ID: tailcfg.UserID(1)}
self := &ipnstate.PeerStatus{ID: "self", UserID: user.ID} self := &ipnstate.PeerStatus{ID: "self", UserID: user.ID}
remoteNode := &apitype.WhoIsResponse{Node: &tailcfg.Node{ID: 1}, UserProfile: user} remoteNode := &apitype.WhoIsResponse{Node: &tailcfg.Node{ID: 1}, UserProfile: user}
@ -462,18 +462,29 @@ func TestServeTailscaleAuth(t *testing.T) {
}) })
tests := []struct { tests := []struct {
name string name string
cookie string
query string cookie string // cookie attached to request
wantStatus int wantNewCookie bool // want new cookie generated during request
wantResp *authResponse wantSession *browserSession // session associated w/ cookie after request
wantNewCookie bool // new cookie generated
wantSession *browserSession // session associated w/ cookie at end of request path string
wantStatus int
wantResp any
}{ }{
{ {
name: "new-session-created", name: "no-session",
path: "/api/auth",
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantResp: &authResponse{OK: false, AuthURL: testControlURL + testAuthPath}, wantResp: &authResponse{OK: false, AuthNeeded: tailscaleAuth},
wantNewCookie: false,
wantSession: nil,
},
{
name: "new-session",
path: "/api/auth/session/new",
wantStatus: http.StatusOK,
wantResp: &newSessionAuthResponse{AuthURL: testControlURL + testAuthPath},
wantNewCookie: true, wantNewCookie: true,
wantSession: &browserSession{ wantSession: &browserSession{
ID: "GENERATED_ID", // gets swapped for newly created ID by test ID: "GENERATED_ID", // gets swapped for newly created ID by test
@ -487,9 +498,10 @@ func TestServeTailscaleAuth(t *testing.T) {
}, },
{ {
name: "query-existing-incomplete-session", name: "query-existing-incomplete-session",
path: "/api/auth",
cookie: successCookie, cookie: successCookie,
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantResp: &authResponse{OK: false, AuthURL: testControlURL + testAuthPathSuccess}, wantResp: &authResponse{OK: false, AuthNeeded: tailscaleAuth},
wantSession: &browserSession{ wantSession: &browserSession{
ID: successCookie, ID: successCookie,
SrcNode: remoteNode.Node.ID, SrcNode: remoteNode.Node.ID,
@ -501,13 +513,27 @@ func TestServeTailscaleAuth(t *testing.T) {
}, },
}, },
{ {
name: "transition-to-successful-session", name: "existing-session-used",
cookie: successCookie, path: "/api/auth/session/new", // should not create new session
// query "wait" indicates the FE wants to make cookie: successCookie,
// local api call to wait until session completed.
query: "wait=true",
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantResp: &authResponse{OK: true}, wantResp: &newSessionAuthResponse{AuthURL: testControlURL + testAuthPathSuccess},
wantSession: &browserSession{
ID: successCookie,
SrcNode: remoteNode.Node.ID,
SrcUser: user.ID,
Created: oneHourAgo,
AuthID: testAuthPathSuccess,
AuthURL: testControlURL + testAuthPathSuccess,
Authenticated: false,
},
},
{
name: "transition-to-successful-session",
path: "/api/auth/session/wait",
cookie: successCookie,
wantStatus: http.StatusOK,
wantResp: nil,
wantSession: &browserSession{ wantSession: &browserSession{
ID: successCookie, ID: successCookie,
SrcNode: remoteNode.Node.ID, SrcNode: remoteNode.Node.ID,
@ -520,6 +546,7 @@ func TestServeTailscaleAuth(t *testing.T) {
}, },
{ {
name: "query-existing-complete-session", name: "query-existing-complete-session",
path: "/api/auth",
cookie: successCookie, cookie: successCookie,
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantResp: &authResponse{OK: true}, wantResp: &authResponse{OK: true},
@ -535,17 +562,18 @@ func TestServeTailscaleAuth(t *testing.T) {
}, },
{ {
name: "transition-to-failed-session", name: "transition-to-failed-session",
path: "/api/auth/session/wait",
cookie: failureCookie, cookie: failureCookie,
query: "wait=true",
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantResp: nil, wantResp: nil,
wantSession: nil, // session deleted wantSession: nil, // session deleted
}, },
{ {
name: "failed-session-cleaned-up", name: "failed-session-cleaned-up",
path: "/api/auth/session/new",
cookie: failureCookie, cookie: failureCookie,
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantResp: &authResponse{OK: false, AuthURL: testControlURL + testAuthPath}, wantResp: &newSessionAuthResponse{AuthURL: testControlURL + testAuthPath},
wantNewCookie: true, wantNewCookie: true,
wantSession: &browserSession{ wantSession: &browserSession{
ID: "GENERATED_ID", ID: "GENERATED_ID",
@ -559,9 +587,10 @@ func TestServeTailscaleAuth(t *testing.T) {
}, },
{ {
name: "expired-cookie-gets-new-session", name: "expired-cookie-gets-new-session",
path: "/api/auth/session/new",
cookie: expiredCookie, cookie: expiredCookie,
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantResp: &authResponse{OK: false, AuthURL: testControlURL + testAuthPath}, wantResp: &newSessionAuthResponse{AuthURL: testControlURL + testAuthPath},
wantNewCookie: true, wantNewCookie: true,
wantSession: &browserSession{ wantSession: &browserSession{
ID: "GENERATED_ID", ID: "GENERATED_ID",
@ -576,12 +605,11 @@ func TestServeTailscaleAuth(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "/api/auth", nil) r := httptest.NewRequest("GET", tt.path, nil)
r.URL.RawQuery = tt.query
r.RemoteAddr = remoteIP r.RemoteAddr = remoteIP
r.AddCookie(&http.Cookie{Name: sessionCookieName, Value: tt.cookie}) r.AddCookie(&http.Cookie{Name: sessionCookieName, Value: tt.cookie})
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.serveAPIAuth(w, r) s.serve(w, r)
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
@ -589,17 +617,20 @@ func TestServeTailscaleAuth(t *testing.T) {
if gotStatus := res.StatusCode; tt.wantStatus != gotStatus { if gotStatus := res.StatusCode; tt.wantStatus != gotStatus {
t.Errorf("wrong status; want=%v, got=%v", tt.wantStatus, gotStatus) t.Errorf("wrong status; want=%v, got=%v", tt.wantStatus, gotStatus)
} }
var gotResp *authResponse var gotResp string
if res.StatusCode == http.StatusOK { if res.StatusCode == http.StatusOK {
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := json.Unmarshal(body, &gotResp); err != nil { gotResp = strings.Trim(string(body), "\n")
t.Fatal(err) }
} var wantResp string
if tt.wantResp != nil {
b, _ := json.Marshal(tt.wantResp)
wantResp = string(b)
} }
if diff := cmp.Diff(gotResp, tt.wantResp); diff != "" { if diff := cmp.Diff(gotResp, string(wantResp)); diff != "" {
t.Errorf("wrong response; (-got+want):%v", diff) t.Errorf("wrong response; (-got+want):%v", diff)
} }
// Validate cookie creation. // Validate cookie creation.
@ -654,22 +685,13 @@ func mockLocalAPI(t *testing.T, whoIs map[string]*apitype.WhoIsResponse, self fu
t.Fatalf("/whois call missing \"addr\" query") t.Fatalf("/whois call missing \"addr\" query")
} }
if node := whoIs[addr]; node != nil { if node := whoIs[addr]; node != nil {
if err := json.NewEncoder(w).Encode(&node); err != nil { writeJSON(w, &node)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
return return
} }
http.Error(w, "not a node", http.StatusUnauthorized) http.Error(w, "not a node", http.StatusUnauthorized)
return return
case "/localapi/v0/status": case "/localapi/v0/status":
status := ipnstate.Status{Self: self()} writeJSON(w, ipnstate.Status{Self: self()})
if err := json.NewEncoder(w).Encode(status); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
return return
case "/localapi/v0/debug-web-client": // used by TestServeTailscaleAuth case "/localapi/v0/debug-web-client": // used by TestServeTailscaleAuth
type reqData struct { type reqData struct {
@ -694,11 +716,7 @@ func mockLocalAPI(t *testing.T, whoIs map[string]*apitype.WhoIsResponse, self fu
http.Error(w, "authenticated as wrong user", http.StatusUnauthorized) http.Error(w, "authenticated as wrong user", http.StatusUnauthorized)
return return
} }
if err := json.NewEncoder(w).Encode(resp); err != nil { writeJSON(w, resp)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
return return
default: default:
t.Fatalf("unhandled localapi test endpoint %q, add to localapi handler func in test", r.URL.Path) t.Fatalf("unhandled localapi test endpoint %q, add to localapi handler func in test", r.URL.Path)

Loading…
Cancel
Save