// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package tka (WIP) implements the Tailnet Key Authority. package tka import ( "bytes" "crypto/ed25519" "errors" "fmt" "os" "sort" ) // Authority is a Tailnet Key Authority. This type is the main coupling // point to the rest of the tailscale client. // // Authority objects can either be created from an existing, non-empty // tailchonk (via tka.Open()), or created from scratch using tka.Bootstrap() // or tka.Create(). type Authority struct { head AUM oldestAncestor AUM state State storage Chonk } // A chain describes a linear sequence of updates from Oldest to Head, // resulting in some State at Head. type chain struct { Oldest AUM Head AUM state State // Set to true if the AUM chain intersects with the active // chain from a previous run. chainsThroughActive bool } // computeChainCandidates returns all possible chains based on AUMs stored // in the given tailchonk. A chain is defined as a unique (oldest, newest) // AUM tuple. chain.state is not yet populated in returned chains. // // If lastKnownOldest is provided, any chain that includes the given AUM // has the chainsThroughActive field set to true. This bit is leveraged // in computeActiveAncestor() to filter out irrelevant chains when determining // the active ancestor from a list of distinct chains. func computeChainCandidates(storage Chonk, lastKnownOldest *AUMHash, maxIter int) ([]chain, error) { heads, err := storage.Heads() if err != nil { return nil, fmt.Errorf("reading heads: %v", err) } candidates := make([]chain, len(heads)) for i := range heads { // Oldest is iteratively computed below. candidates[i] = chain{Oldest: heads[i], Head: heads[i]} } // Not strictly necessary, but simplifies checks in tests. sort.Slice(candidates, func(i, j int) bool { ih, jh := candidates[i].Oldest.Hash(), candidates[j].Oldest.Hash() return bytes.Compare(ih[:], jh[:]) < 0 }) // candidates.Oldest needs to be computed by working backwards from // head as far as we can. iterAgain := true // if theres still work to be done. for i := 0; iterAgain; i++ { if i >= maxIter { return nil, fmt.Errorf("iteration limit exceeded (%d)", maxIter) } iterAgain = false for j := range candidates { parent, hasParent := candidates[j].Oldest.Parent() if hasParent { parent, err := storage.AUM(parent) if err != nil { if err == os.ErrNotExist { continue } return nil, fmt.Errorf("reading parent: %v", err) } candidates[j].Oldest = parent if lastKnownOldest != nil && *lastKnownOldest == parent.Hash() { candidates[j].chainsThroughActive = true } iterAgain = true } } } return candidates, nil } // pickNextAUM returns the AUM which should be used as the next // AUM in the chain, possibly applying fork resolution logic. // // In other words: given an AUM with 3 children like this: // / - 1 // P - 2 // \ - 3 // // pickNextAUM will determine and return the correct branch. // // This method takes ownership of the provided slice. func pickNextAUM(state State, candidates []AUM) AUM { switch len(candidates) { case 0: panic("pickNextAUM called with empty candidate set") case 1: return candidates[0] } // Oooof, we have some forks in the chain. We need to pick which // one to use by applying the Fork Resolution Algorithm ✨ // // The rules are this: // 1. The child with the highest signature weight is chosen. // 2. If equal, the child which is a RemoveKey AUM is chosen. // 3. If equal, the child with the lowest AUM hash is chosen. sort.Slice(candidates, func(j, i int) bool { // Rule 1. iSigWeight, jSigWeight := candidates[i].Weight(state), candidates[j].Weight(state) if iSigWeight != jSigWeight { return iSigWeight < jSigWeight } // Rule 2. if iKind, jKind := candidates[i].MessageKind, candidates[j].MessageKind; iKind != jKind && (iKind == AUMRemoveKey || jKind == AUMRemoveKey) { return jKind == AUMRemoveKey } // Rule 3. iHash, jHash := candidates[i].Hash(), candidates[j].Hash() return bytes.Compare(iHash[:], jHash[:]) > 0 }) return candidates[0] } // advanceChain computes the next AUM to advance with based on all child // AUMs, returning the chosen AUM & the state obtained by applying that // AUM. // // The return value for next is nil if there are no children AUMs, hence // the provided state is at head (up to date). func advanceChain(state State, candidates []AUM) (next *AUM, out State, err error) { if len(candidates) == 0 { return nil, state, nil } aum := pickNextAUM(state, candidates) if state, err = state.applyVerifiedAUM(aum); err != nil { return nil, State{}, fmt.Errorf("advancing state: %v", err) } return &aum, state, nil } // fastForward iteratively advances the current state based on known AUMs until // the given termination function returns true or there is no more progress possible. // // The last-processed AUM, and the state computed after applying the last AUM, // are returned. func fastForward(storage Chonk, maxIter int, startState State, done func(curAUM AUM, curState State) bool) (AUM, State, error) { if startState.LastAUMHash == nil { return AUM{}, State{}, errors.New("invalid initial state") } nextAUM, err := storage.AUM(*startState.LastAUMHash) if err != nil { return AUM{}, State{}, fmt.Errorf("reading next: %v", err) } curs := nextAUM state := startState for i := 0; i < maxIter; i++ { if done != nil && done(curs, state) { return curs, state, nil } children, err := storage.ChildAUMs(curs.Hash()) if err != nil { return AUM{}, State{}, fmt.Errorf("getting children of %X: %v", curs.Hash(), err) } next, nextState, err := advanceChain(state, children) if err != nil { return AUM{}, State{}, fmt.Errorf("advance %X: %v", curs.Hash(), err) } if next == nil { // There were no more children, we are at 'head'. return curs, state, nil } curs = *next state = nextState } return AUM{}, State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter) } // computeStateAt returns the State at wantHash. func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) { // TODO(tom): This is going to get expensive for really long // chains. We should make nodes emit a checkpoint every // X updates or something. topAUM, err := storage.AUM(wantHash) if err != nil { return State{}, err } // Iterate backwards till we find a starting point to compute // the state from. // // Valid starting points are either a checkpoint AUM, or a // genesis AUM. curs := topAUM var state State for i := 0; true; i++ { if i > maxIter { return State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter) } // Checkpoints encapsulate the state at that point, dope. if curs.MessageKind == AUMCheckpoint { state = curs.State.cloneForUpdate(&curs) break } parent, hasParent := curs.Parent() if !hasParent { // This is a 'genesis' update: there are none before it, so // this AUM can be applied to the empty state to determine // the state at this AUM. // // It is only valid for NoOp, AddKey, and Checkpoint AUMs // to be a genesis update. Checkpoint was handled earlier. if mk := curs.MessageKind; mk == AUMNoOp || mk == AUMAddKey { var err error if state, err = (State{}).applyVerifiedAUM(curs); err != nil { return State{}, fmt.Errorf("applying genesis (%+v): %v", curs, err) } break } return State{}, fmt.Errorf("invalid genesis update: %+v", curs) } // If we got here, the current state is dependent on the previous. // Keep iterating backwards till thats not the case. if curs, err = storage.AUM(parent); err != nil { return State{}, fmt.Errorf("reading parent: %v", err) } } // We now know some starting point state. Iterate forward till we // are at the AUM we want state for. _, state, err = fastForward(storage, maxIter, state, func(curs AUM, _ State) bool { return curs.Hash() == wantHash }) // fastForward only terminates before the done condition if it // doesnt have any later AUMs to process. This cant be the case // as we've already iterated through them above so they must exist, // but we check anyway to be super duper sure. if err == nil && *state.LastAUMHash != wantHash { panic("unexpected fastForward outcome") } return state, err } // computeActiveAncestor determines which ancestor AUM to use as the // ancestor of the valid chain. // // If all the chains end up having the same ancestor, then thats the // only possible ancestor, ezpz. However if there are multiple distinct // ancestors, that means there are distinct chains, and we need some // hint to choose what to use. For that, we rely on the chainsThroughActive // bit, which signals to us that that ancestor was part of the // chain in a previous run. func computeActiveAncestor(storage Chonk, chains []chain) (AUMHash, error) { // Dedupe possible ancestors, tracking if they were part of // the active chain on a previous run. ancestors := make(map[AUMHash]bool, len(chains)) for _, c := range chains { ancestors[c.Oldest.Hash()] = c.chainsThroughActive } if len(ancestors) == 1 { // There's only one. DOPE. for k, _ := range ancestors { return k, nil } } // Theres more than one, so we need to use the ancestor that was // part of the active chain in a previous iteration. // Note that there can only be one distinct ancestor that was // formerly part of the active chain, because AUMs can only have // one parent and would have converged to a common ancestor. for k, chainsThroughActive := range ancestors { if chainsThroughActive { return k, nil } } return AUMHash{}, errors.New("multiple distinct chains") } // computeActiveChain bootstraps the runtime state of the Authority when // starting entirely off stored state. // // TODO(tom): Don't look at head states, just iterate forward from // the ancestor. // // The algorithm is as follows: // 1. Determine all possible 'head' (like in git) states. // 2. Filter these possible chains based on whether the ancestor was // formerly (in a previous run) part of the chain. // 3. Compute the state of the state machine at this ancestor. This is // needed for fast-forward, as each update operates on the state of // the update preceeding it. // 4. Iteratively apply updates till we reach head ('fast forward'). func computeActiveChain(storage Chonk, lastKnownOldest *AUMHash, maxIter int) (chain, error) { chains, err := computeChainCandidates(storage, lastKnownOldest, maxIter) if err != nil { return chain{}, fmt.Errorf("computing candidates: %v", err) } // Find the right ancestor. oldestHash, err := computeActiveAncestor(storage, chains) if err != nil { return chain{}, fmt.Errorf("computing ancestor: %v", err) } ancestor, err := storage.AUM(oldestHash) if err != nil { return chain{}, err } // At this stage we know the ancestor AUM, so we have excluded distinct // chains but we might still have forks (so we don't know the head AUM). // // We iterate forward from the ancestor AUM, handling any forks as we go // till we arrive at a head. out := chain{Oldest: ancestor, Head: ancestor} if out.state, err = computeStateAt(storage, maxIter, oldestHash); err != nil { return chain{}, fmt.Errorf("bootstrapping state: %v", err) } out.Head, out.state, err = fastForward(storage, maxIter, out.state, nil) if err != nil { return chain{}, fmt.Errorf("fast forward: %v", err) } return out, nil } // aumVerify verifies if an AUM is well-formed, correctly signed, and // can be accepted for storage. func aumVerify(aum AUM, state State, isGenesisAUM bool) error { if err := aum.StaticValidate(); err != nil { return fmt.Errorf("invalid: %v", err) } if !isGenesisAUM { if err := checkParent(aum, state); err != nil { return err } } if len(aum.Signatures) == 0 { return errors.New("unsigned AUM") } sigHash := aum.SigHash() for i, sig := range aum.Signatures { key, err := state.GetKey(sig.KeyID) if err != nil { return fmt.Errorf("bad keyID on signature %d: %v", i, err) } if err := sig.Verify(sigHash, key); err != nil { return fmt.Errorf("signature %d: %v", i, err) } } return nil } func checkParent(aum AUM, state State) error { parent, hasParent := aum.Parent() if !hasParent { return errors.New("aum has no parent") } if state.LastAUMHash == nil { return errors.New("cannot check update parent hash against a state with no previous AUM") } if *state.LastAUMHash != parent { return fmt.Errorf("aum with parent %x cannot be applied to a state with parent %x", state.LastAUMHash, parent) } return nil } // Head returns the AUM digest of the latest update applied to the state // machine. func (a *Authority) Head() AUMHash { return *a.state.LastAUMHash } // Open initializes an existing TKA from the given tailchonk. // // Only use this if the current node has initialized an Authority before. // If a TKA exists on other nodes but theres nothing locally, use Bootstrap(). // If no TKA exists anywhere and you are creating it for the first // time, use New(). func Open(storage Chonk) (*Authority, error) { a, err := storage.LastActiveAncestor() if err != nil { return nil, fmt.Errorf("reading last ancestor: %v", err) } c, err := computeActiveChain(storage, a, 2000) if err != nil { return nil, fmt.Errorf("active chain: %v", err) } return &Authority{ head: c.Head, oldestAncestor: c.Oldest, storage: storage, state: c.state, }, nil } // Create initializes a brand-new TKA, generating a genesis update // and committing it to the given storage. // // The given signer must also be present in state as a trusted key. // // Do not use this to initialize a TKA that already exists, use Open() // or Bootstrap() instead. func Create(storage Chonk, state State, signer ed25519.PrivateKey) (*Authority, AUM, error) { // Generate & sign a checkpoint, our genesis update. genesis := AUM{ MessageKind: AUMCheckpoint, State: &state, } if err := genesis.StaticValidate(); err != nil { // This serves as an easy way to validate the given state. return nil, AUM{}, fmt.Errorf("invalid state: %v", err) } genesis.sign25519(signer) a, err := Bootstrap(storage, genesis) return a, genesis, err } // Bootstrap initializes a TKA based on the given checkpoint. // // Call this when setting up a new nodes' TKA, but other nodes // with initialized TKA's exist. // // Pass the returned genesis AUM from Create(), or a later checkpoint AUM. // // TODO(tom): We should test an authority bootstrapped from a later checkpoint // works fine with sync and everything. func Bootstrap(storage Chonk, bootstrap AUM) (*Authority, error) { heads, err := storage.Heads() if err != nil { return nil, fmt.Errorf("reading heads: %v", err) } if len(heads) != 0 { return nil, errors.New("tailchonk is not empty") } // Check the AUM is well-formed. if bootstrap.MessageKind != AUMCheckpoint { return nil, fmt.Errorf("bootstrap AUMs must be checkpoint messages, got %v", bootstrap.MessageKind) } if bootstrap.State == nil { return nil, errors.New("bootstrap AUM is missing state") } if err := aumVerify(bootstrap, *bootstrap.State, true); err != nil { return nil, fmt.Errorf("invalid bootstrap: %v", err) } // Everything looks good, write it to storage. if err := storage.CommitVerifiedAUMs([]AUM{bootstrap}); err != nil { return nil, fmt.Errorf("commit: %v", err) } if err := storage.SetLastActiveAncestor(bootstrap.Hash()); err != nil { return nil, fmt.Errorf("set ancestor: %v", err) } return Open(storage) } // Inform is called to tell the authority about new updates. Updates // should be ordered oldest to newest. An error is returned if any // of the updates could not be processed. func (a *Authority) Inform(updates []AUM) error { stateAt := make(map[AUMHash]State, len(updates)+1) toCommit := make([]AUM, 0, len(updates)) for i, update := range updates { hash := update.Hash() if _, err := a.storage.AUM(hash); err == nil { // Already have this AUM. continue } parent, hasParent := update.Parent() if !hasParent { return fmt.Errorf("update %d: missing parent", i) } state, hasState := stateAt[parent] var err error if !hasState { if state, err = computeStateAt(a.storage, 2000, parent); err != nil { return fmt.Errorf("update %d computing state: %v", i, err) } stateAt[parent] = state } if err := aumVerify(update, state, false); err != nil { return fmt.Errorf("update %d invalid: %v", i, err) } if stateAt[hash], err = state.applyVerifiedAUM(update); err != nil { return fmt.Errorf("update %d cannot be applied: %v", i, err) } toCommit = append(toCommit, update) } if err := a.storage.CommitVerifiedAUMs(toCommit); err != nil { return fmt.Errorf("commit: %v", err) } // TODO(tom): Theres no need to recompute the state from scratch // in every case. We should detect when updates were // a linear, non-forking series applied to head, and // just use the last State we computed. oldestAncestor := a.oldestAncestor.Hash() c, err := computeActiveChain(a.storage, &oldestAncestor, 2000) if err != nil { return fmt.Errorf("recomputing active chain: %v", err) } a.head = c.Head a.oldestAncestor = c.Oldest a.state = c.state return nil }