From 4f1374ec9eec11cf474a6c33ff22ee268d9ec136 Mon Sep 17 00:00:00 2001 From: Tom DNetto Date: Mon, 11 Jul 2022 11:28:18 -0700 Subject: [PATCH] tka: implement consensus & state computation internals Signed-off-by: Tom DNetto --- tka/chaintest_test.go | 365 ++++++++++++++++++++++++++++++++++++++++++ tka/tka.go | 332 ++++++++++++++++++++++++++++++++++++++ tka/tka_test.go | 187 ++++++++++++++++++++++ 3 files changed, 884 insertions(+) create mode 100644 tka/chaintest_test.go create mode 100644 tka/tka_test.go diff --git a/tka/chaintest_test.go b/tka/chaintest_test.go new file mode 100644 index 000000000..3b3ce4c4c --- /dev/null +++ b/tka/chaintest_test.go @@ -0,0 +1,365 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tka + +import ( + "bytes" + "crypto/ed25519" + "fmt" + "strconv" + "strings" + "testing" + "text/scanner" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +// chaintest_test.go implements test helpers for concisely describing +// chains of possibly signed AUMs, to assist in making tests shorter and +// easier to read. + +// parsed representation of a named AUM in a test chain. +type testchainNode struct { + Name string + Parent string + Uses []scanner.Position + + HashSeed int + Template string + SignedWith string +} + +// testChain represents a constructed web of AUMs for testing purposes. +type testChain struct { + Nodes map[string]*testchainNode + AUMs map[string]AUM + AUMHashes map[string]AUMHash + + // Configured by options to NewTestchain() + Template map[string]AUM + Key map[string]*Key + KeyPrivs map[string]ed25519.PrivateKey + SignAllKeys []string +} + +// newTestchain constructs a web of AUMs based on the provided input and +// options. +// +// Input is expected to be a graph & tweaks, looking like this: +// +// G1 -> A -> B +// | -> C +// +// which defines AUMs G1, A, B, and C; with G1 having no parent, A having +// G1 as a parent, and both B & C having A as a parent. +// +// Tweaks are specified like this: +// +// . = +// +// for example: G1.hashSeed = 2 +// +// There are 3 available tweaks: +// - hashSeed: Set to an integer to tweak the AUM hash of that AUM. +// - template: Set to the name of a template provided via optTemplate(). +// The template is copied and use as the content for that AUM. +// - signedWith: Set to the name of a key provided via optKey(). This +// key is used to sign that AUM. +func newTestchain(t *testing.T, input string, options ...testchainOpt) *testChain { + t.Helper() + + var ( + s scanner.Scanner + out = testChain{ + Nodes: map[string]*testchainNode{}, + Template: map[string]AUM{}, + Key: map[string]*Key{}, + KeyPrivs: map[string]ed25519.PrivateKey{}, + } + ) + + // Process any options + for _, o := range options { + if o.Template != nil { + out.Template[o.Name] = *o.Template + } + if o.Key != nil { + out.Key[o.Name] = o.Key + out.KeyPrivs[o.Name] = o.Private + } + if o.SignAllWith { + out.SignAllKeys = append(out.SignAllKeys, o.Name) + } + } + + s.Init(strings.NewReader(input)) + s.Mode = scanner.ScanIdents | scanner.SkipComments | scanner.ScanComments | scanner.ScanChars | scanner.ScanInts + s.Whitespace ^= 1 << '\t' // clear tabs + var ( + lastIdent string + lastWasChain bool // if the last token was '->' + ) + for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() { + switch tok { + case '\t': + t.Fatalf("tabs disallowed, use spaces (seen at %v)", s.Pos()) + + case '.': // tweaks, like .hashSeed = + s.Scan() + tweak := s.TokenText() + if tok := s.Scan(); tok == '=' { + s.Scan() + switch tweak { + case "hashSeed": + out.Nodes[lastIdent].HashSeed, _ = strconv.Atoi(s.TokenText()) + case "template": + out.Nodes[lastIdent].Template = s.TokenText() + case "signedWith": + out.Nodes[lastIdent].SignedWith = s.TokenText() + } + } + + case scanner.Ident: + out.recordPos(s.TokenText(), s.Pos()) + // If the last token was '->', that means + // that the next identifier has a child relationship + // with the identifier preceeding '->'. + if lastWasChain { + out.recordParent(t, s.TokenText(), lastIdent) + } + lastIdent = s.TokenText() + + case '-': // handle '->' + switch s.Peek() { + case '>': + s.Scan() + lastWasChain = true + continue + } + + case '|': // handle '|' + line, col := s.Pos().Line, s.Pos().Column + nodeLoop: + for _, n := range out.Nodes { + for _, p := range n.Uses { + // Find the identifier used right here on the line above. + if p.Line == line-1 && col <= p.Column && col > p.Column-len(n.Name) { + lastIdent = n.Name + out.recordPos(n.Name, s.Pos()) + break nodeLoop + } + } + } + } + lastWasChain = false + // t.Logf("tok = %v, %q", tok, s.TokenText()) + } + + out.buildChain() + return &out +} + +// called from the parser to record the location of an +// identifier (a named AUM). +func (c *testChain) recordPos(ident string, pos scanner.Position) { + n := c.Nodes[ident] + if n == nil { + n = &testchainNode{Name: ident} + } + + n.Uses = append(n.Uses, pos) + c.Nodes[ident] = n +} + +// called from the parser to record a parent relationship between +// two AUMs. +func (c *testChain) recordParent(t *testing.T, child, parent string) { + if p := c.Nodes[child].Parent; p != "" && p != parent { + t.Fatalf("differing parent specified for %s: %q != %q", child, p, parent) + } + c.Nodes[child].Parent = parent +} + +// called after parsing to build the web of AUM structures. +// This method populates c.AUMs and c.AUMHashes. +func (c *testChain) buildChain() { + pending := make(map[string]*testchainNode, len(c.Nodes)) + for k, v := range c.Nodes { + pending[k] = v + } + + // AUMs with a parent need to know their hash, so we + // only compute AUMs who's parents have been computed + // each iteration. Since at least the genesis AUM + // had no parent, theres always a path to completion + // in O(n+1) where n is the number of AUMs. + c.AUMs = make(map[string]AUM, len(c.Nodes)) + c.AUMHashes = make(map[string]AUMHash, len(c.Nodes)) + for i := 0; i < len(c.Nodes)+1; i++ { + if len(pending) == 0 { + return + } + + next := make([]*testchainNode, 0, 10) + for _, v := range pending { + if _, parentPending := pending[v.Parent]; !parentPending { + next = append(next, v) + } + } + + for _, v := range next { + aum := c.makeAUM(v) + h := aum.Hash() + + c.AUMHashes[v.Name] = h + c.AUMs[v.Name] = aum + delete(pending, v.Name) + } + } + panic("unexpected: incomplete despite len(Nodes)+1 iterations") +} + +func (c *testChain) makeAUM(v *testchainNode) AUM { + // By default, the AUM used is just a no-op AUM + // with a parent hash set (if any). + // + // If .template is set to the same name as in + // a provided optTemplate(), the AUM is built + // from a copy of that instead. + // + // If .hashSeed = is set, the KeyID is + // tweaked to effect tweaking the hash. This is useful + // if you want one AUM to have a lower hash than another. + aum := AUM{MessageKind: AUMNoOp} + if template := v.Template; template != "" { + aum = c.Template[template] + } + if v.Parent != "" { + parentHash := c.AUMHashes[v.Parent] + aum.PrevAUMHash = parentHash[:] + } + if seed := v.HashSeed; seed != 0 { + aum.KeyID = []byte{byte(seed)} + } + if err := aum.StaticValidate(); err != nil { + // Usually caused by a test writer specifying a template + // AUM which is ultimately invalid. + panic(fmt.Sprintf("aum %+v failed static validation: %v", aum, err)) + } + + sigHash := aum.SigHash() + for _, key := range c.SignAllKeys { + aum.Signatures = append(aum.Signatures, Signature{ + KeyID: c.Key[key].ID(), + Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]), + }) + } + + // If the aum was specified as being signed by some key, then + // sign it using that key. + if key := v.SignedWith; key != "" { + aum.Signatures = append(aum.Signatures, Signature{ + KeyID: c.Key[key].ID(), + Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]), + }) + } + + return aum +} + +// Chonk returns a tailchonk containing all AUMs. +func (c *testChain) Chonk() Chonk { + var out Mem + for _, update := range c.AUMs { + if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil { + panic(err) + } + } + return &out +} + +// ChonkWith returns a tailchonk containing the named AUMs. +func (c *testChain) ChonkWith(names ...string) Chonk { + var out Mem + for _, name := range names { + update := c.AUMs[name] + if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil { + panic(err) + } + } + return &out +} + +type testchainOpt struct { + Name string + Template *AUM + Key *Key + Private ed25519.PrivateKey + SignAllWith bool +} + +func optTemplate(name string, template AUM) testchainOpt { + return testchainOpt{ + Name: name, + Template: &template, + } +} + +func optKey(name string, key Key, priv ed25519.PrivateKey) testchainOpt { + return testchainOpt{ + Name: name, + Key: &key, + Private: priv, + } +} + +func optSignAllUsing(keyName string) testchainOpt { + return testchainOpt{ + Name: keyName, + SignAllWith: true, + } +} + +func TestNewTestchain(t *testing.T) { + c := newTestchain(t, ` + genesis -> B -> C + | -> D + | -> E -> F + + E.hashSeed = 12 // tweak E to have the lowest hash so its chosen + F.template = test + `, optTemplate("test", AUM{MessageKind: AUMNoOp, KeyID: []byte{10}})) + + want := map[string]*testchainNode{ + "genesis": &testchainNode{Name: "genesis", Uses: []scanner.Position{{Line: 2, Column: 16}}}, + "B": &testchainNode{ + Name: "B", + Parent: "genesis", + Uses: []scanner.Position{{Line: 2, Column: 21}, {Line: 3, Column: 21}, {Line: 4, Column: 21}}, + }, + "C": &testchainNode{Name: "C", Parent: "B", Uses: []scanner.Position{{Line: 2, Column: 26}}}, + "D": &testchainNode{Name: "D", Parent: "B", Uses: []scanner.Position{{Line: 3, Column: 26}}}, + "E": &testchainNode{Name: "E", Parent: "B", HashSeed: 12, Uses: []scanner.Position{{Line: 4, Column: 26}, {Line: 6, Column: 10}}}, + "F": &testchainNode{Name: "F", Parent: "E", Template: "test", Uses: []scanner.Position{{Line: 4, Column: 31}, {Line: 7, Column: 10}}}, + } + + if diff := cmp.Diff(want, c.Nodes, cmpopts.IgnoreFields(scanner.Position{}, "Offset")); diff != "" { + t.Errorf("decoded state differs (-want, +got):\n%s", diff) + } + if !bytes.Equal(c.AUMs["F"].KeyID, []byte{10}) { + t.Errorf("AUM 'F' missing KeyID from template: %v", c.AUMs["F"]) + } + + // chonk := c.Chonk() + // authority, err := Open(chonk) + // if err != nil { + // t.Errorf("failed to initialize from chonk: %v", err) + // } + + // if authority.Head() != c.AUMHashes["F"] { + // t.Errorf("head = %X, want %X", authority.Head(), c.AUMHashes["F"]) + // } +} diff --git a/tka/tka.go b/tka/tka.go index cec790d99..5974950e6 100644 --- a/tka/tka.go +++ b/tka/tka.go @@ -4,3 +4,335 @@ // Package tka (WIP) implements the Tailnet Key Authority. package tka + +import ( + "bytes" + "errors" + "fmt" + "os" + "sort" +) + +// A chain describes a linear sequence of updates from Oldest to Head, +// resulting in some State at Head. +type chain struct { + Oldest AUM + Head AUM + + state State + + // Set to true if the AUM chain intersects with the active + // chain from a previous run. + chainsThroughActive bool +} + +// computeChainCandidates returns all possible chains based on AUMs stored +// in the given tailchonk. A chain is defined as a unique (oldest, newest) +// AUM tuple. chain.state is not yet populated in returned chains. +// +// If lastKnownOldest is provided, any chain that includes the given AUM +// has the chainsThroughActive field set to true. This bit is leveraged +// in computeActiveAncestor() to filter out irrelevant chains when determining +// the active ancestor from a list of distinct chains. +func computeChainCandidates(storage Chonk, lastKnownOldest *AUMHash, maxIter int) ([]chain, error) { + heads, err := storage.Heads() + if err != nil { + return nil, fmt.Errorf("reading heads: %v", err) + } + candidates := make([]chain, len(heads)) + for i := range heads { + // Oldest is iteratively computed below. + candidates[i] = chain{Oldest: heads[i], Head: heads[i]} + } + // Not strictly necessary, but simplifies checks in tests. + sort.Slice(candidates, func(i, j int) bool { + ih, jh := candidates[i].Oldest.Hash(), candidates[j].Oldest.Hash() + return bytes.Compare(ih[:], jh[:]) < 0 + }) + + // candidates.Oldest needs to be computed by working backwards from + // head as far as we can. + iterAgain := true // if theres still work to be done. + for i := 0; iterAgain; i++ { + if i >= maxIter { + return nil, fmt.Errorf("iteration limit exceeded (%d)", maxIter) + } + + iterAgain = false + for j := range candidates { + parent, hasParent := candidates[j].Oldest.Parent() + if hasParent { + parent, err := storage.AUM(parent) + if err != nil { + if err == os.ErrNotExist { + continue + } + return nil, fmt.Errorf("reading parent: %v", err) + } + candidates[j].Oldest = parent + if lastKnownOldest != nil && *lastKnownOldest == parent.Hash() { + candidates[j].chainsThroughActive = true + } + iterAgain = true + } + } + } + return candidates, nil +} + +// pickNextAUM returns the AUM which should be used as the next +// AUM in the chain, possibly applying fork resolution logic. +// +// In other words: given an AUM with 3 children like this: +// / - 1 +// P - 2 +// \ - 3 +// +// pickNextAUM will determine and return the correct branch. +// +// This method takes ownership of the provided slice. +func pickNextAUM(state State, candidates []AUM) AUM { + switch len(candidates) { + case 0: + panic("pickNextAUM called with empty candidate set") + case 1: + return candidates[0] + } + + // Oooof, we have some forks in the chain. We need to pick which + // one to use by applying the Fork Resolution Algorithm ✨ + // + // The rules are this: + // 1. The child with the highest signature weight is chosen. + // 2. If equal, the child which is a RemoveKey AUM is chosen. + // 3. If equal, the child with the lowest AUM hash is chosen. + sort.Slice(candidates, func(j, i int) bool { + // Rule 1. + iSigWeight, jSigWeight := candidates[i].Weight(state), candidates[j].Weight(state) + if iSigWeight != jSigWeight { + return iSigWeight < jSigWeight + } + + // Rule 2. + if iKind, jKind := candidates[i].MessageKind, candidates[j].MessageKind; iKind != jKind && + (iKind == AUMRemoveKey || jKind == AUMRemoveKey) { + return jKind == AUMRemoveKey + } + + // Rule 3. + iHash, jHash := candidates[i].Hash(), candidates[j].Hash() + return bytes.Compare(iHash[:], jHash[:]) > 0 + }) + + return candidates[0] +} + +// advanceChain computes the next AUM to advance with based on all child +// AUMs, returning the chosen AUM & the state obtained by applying that +// AUM. +// +// The return value for next is nil if there are no children AUMs, hence +// the provided state is at head (up to date). +func advanceChain(state State, candidates []AUM) (next *AUM, out State, err error) { + if len(candidates) == 0 { + return nil, state, nil + } + + aum := pickNextAUM(state, candidates) + if state, err = state.applyVerifiedAUM(aum); err != nil { + return nil, State{}, fmt.Errorf("advancing state: %v", err) + } + return &aum, state, nil +} + +// fastForward iteratively advances the current state based on known AUMs until +// the given termination function returns true or there is no more progress possible. +// +// The last-processed AUM, and the state computed after applying the last AUM, +// are returned. +func fastForward(storage Chonk, maxIter int, startState State, done func(curAUM AUM, curState State) bool) (AUM, State, error) { + if startState.LastAUMHash == nil { + return AUM{}, State{}, errors.New("invalid initial state") + } + nextAUM, err := storage.AUM(*startState.LastAUMHash) + if err != nil { + return AUM{}, State{}, fmt.Errorf("reading next: %v", err) + } + + curs := nextAUM + state := startState + for i := 0; i < maxIter; i++ { + if done != nil && done(curs, state) { + return curs, state, nil + } + + children, err := storage.ChildAUMs(curs.Hash()) + if err != nil { + return AUM{}, State{}, fmt.Errorf("getting children of %X: %v", curs.Hash(), err) + } + next, nextState, err := advanceChain(state, children) + if err != nil { + return AUM{}, State{}, fmt.Errorf("advance %X: %v", curs.Hash(), err) + } + if next == nil { + // There were no more children, we are at 'head'. + return curs, state, nil + } + curs = *next + state = nextState + } + + return AUM{}, State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter) +} + +// computeStateAt returns the State at wantHash. +func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) { + // TODO(tom): This is going to get expensive for really long + // chains. We should make nodes emit a checkpoint every + // X updates or something. + + topAUM, err := storage.AUM(wantHash) + if err != nil { + return State{}, err + } + + // Iterate backwards till we find a starting point to compute + // the state from. + // + // Valid starting points are either a checkpoint AUM, or a + // genesis AUM. + curs := topAUM + var state State + for i := 0; true; i++ { + if i > maxIter { + return State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter) + } + + // Checkpoints encapsulate the state at that point, dope. + if curs.MessageKind == AUMCheckpoint { + state = curs.State.cloneForUpdate(&curs) + break + } + parent, hasParent := curs.Parent() + if !hasParent { + // This is a 'genesis' update: there are none before it, so + // this AUM can be applied to the empty state to determine + // the state at this AUM. + // + // It is only valid for NoOp, AddKey, and Checkpoint AUMs + // to be a genesis update. Checkpoint was handled earlier. + if mk := curs.MessageKind; mk == AUMNoOp || mk == AUMAddKey { + var err error + if state, err = (State{}).applyVerifiedAUM(curs); err != nil { + return State{}, fmt.Errorf("applying genesis (%+v): %v", curs, err) + } + break + } + return State{}, fmt.Errorf("invalid genesis update: %+v", curs) + } + + // If we got here, the current state is dependent on the previous. + // Keep iterating backwards till thats not the case. + if curs, err = storage.AUM(parent); err != nil { + return State{}, fmt.Errorf("reading parent: %v", err) + } + } + + // We now know some starting point state. Iterate forward till we + // are at the AUM we want state for. + _, state, err = fastForward(storage, maxIter, state, func(curs AUM, _ State) bool { + return curs.Hash() == wantHash + }) + // fastForward only terminates before the done condition if it + // doesnt have any later AUMs to process. This cant be the case + // as we've already iterated through them above so they must exist, + // but we check anyway to be super duper sure. + if err == nil && *state.LastAUMHash != wantHash { + panic("unexpected fastForward outcome") + } + return state, err +} + +// computeActiveAncestor determines which ancestor AUM to use as the +// ancestor of the valid chain. +// +// If all the chains end up having the same ancestor, then thats the +// only possible ancestor, ezpz. However if there are multiple distinct +// ancestors, that means there are distinct chains, and we need some +// hint to choose what to use. For that, we rely on the chainsThroughActive +// bit, which signals to us that that ancestor was part of the +// chain in a previous run. +func computeActiveAncestor(storage Chonk, chains []chain) (AUMHash, error) { + // Dedupe possible ancestors, tracking if they were part of + // the active chain on a previous run. + ancestors := make(map[AUMHash]bool, len(chains)) + for _, c := range chains { + ancestors[c.Oldest.Hash()] = c.chainsThroughActive + } + + if len(ancestors) == 1 { + // There's only one. DOPE. + for k, _ := range ancestors { + return k, nil + } + } + + // Theres more than one, so we need to use the ancestor that was + // part of the active chain in a previous iteration. + // Note that there can only be one distinct ancestor that was + // formerly part of the active chain, because AUMs can only have + // one parent and would have converged to a common ancestor. + for k, chainsThroughActive := range ancestors { + if chainsThroughActive { + return k, nil + } + } + + return AUMHash{}, errors.New("multiple distinct chains") +} + +// computeActiveChain bootstraps the runtime state of the Authority when +// starting entirely off stored state. +// +// TODO(tom): Don't look at head states, just iterate forward from +// the ancestor. +// +// The algorithm is as follows: +// 1. Determine all possible 'head' (like in git) states. +// 2. Filter these possible chains based on whether the ancestor was +// formerly (in a previous run) part of the chain. +// 3. Compute the state of the state machine at this ancestor. This is +// needed for fast-forward, as each update operates on the state of +// the update preceeding it. +// 4. Iteratively apply updates till we reach head ('fast forward'). +func computeActiveChain(storage Chonk, lastKnownOldest *AUMHash, maxIter int) (chain, error) { + chains, err := computeChainCandidates(storage, lastKnownOldest, maxIter) + if err != nil { + return chain{}, fmt.Errorf("computing candidates: %v", err) + } + + // Find the right ancestor. + oldestHash, err := computeActiveAncestor(storage, chains) + if err != nil { + return chain{}, fmt.Errorf("computing ancestor: %v", err) + } + ancestor, err := storage.AUM(oldestHash) + if err != nil { + return chain{}, err + } + + // At this stage we know the ancestor AUM, so we have excluded distinct + // chains but we might still have forks (so we don't know the head AUM). + // + // We iterate forward from the ancestor AUM, handling any forks as we go + // till we arrive at a head. + out := chain{Oldest: ancestor, Head: ancestor} + if out.state, err = computeStateAt(storage, maxIter, oldestHash); err != nil { + return chain{}, fmt.Errorf("bootstrapping state: %v", err) + } + out.Head, out.state, err = fastForward(storage, maxIter, out.state, nil) + if err != nil { + return chain{}, fmt.Errorf("fast forward: %v", err) + } + return out, nil +} diff --git a/tka/tka_test.go b/tka/tka_test.go new file mode 100644 index 000000000..72b6d3476 --- /dev/null +++ b/tka/tka_test.go @@ -0,0 +1,187 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tka + +import ( + "bytes" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestComputeChainCandidates(t *testing.T) { + c := newTestchain(t, ` + G1 -> I1 -> I2 -> I3 -> L2 + | -> L1 | -> L3 + + G2 -> L4 + + // We tweak these AUMs so they are different hashes. + G2.hashSeed = 2 + L1.hashSeed = 2 + L3.hashSeed = 2 + L4.hashSeed = 3 + `) + // Should result in 4 chains: + // G1->L1, G1->L2, G1->L3, G2->L4 + + i1H := c.AUMHashes["I1"] + got, err := computeChainCandidates(c.Chonk(), &i1H, 50) + if err != nil { + t.Fatalf("computeChainCandidates() failed: %v", err) + } + + want := []chain{ + {Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true}, + {Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]}, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" { + t.Errorf("chains differ (-want, +got):\n%s", diff) + } +} + +func TestForkResolutionHash(t *testing.T) { + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + // tweak hashes so L1 & L2 are not identical + L1.hashSeed = 2 + L2.hashSeed = 3 + `) + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // The fork with the lowest AUM hash should have been chosen. + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + want := l1H + if bytes.Compare(l2H[:], l1H[:]) < 0 { + want = l2H + } + + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestForkResolutionSigWeight(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + G1.template = addKey + L1.hashSeed = 2 + L2.signedWith = key + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), + optKey("key", key, priv)) + + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + if bytes.Compare(l2H[:], l1H[:]) < 0 { + t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") + } + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // Based on the hash, l1H should be chosen. + // But based on the signature weight (which has higher + // precedence), it should be l2H + want := l2H + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestForkResolutionMessageType(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + | -> L3 + + G1.template = addKey + L1.hashSeed = 11 + L2.template = removeKey + L3.hashSeed = 18 + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), + optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.ID()})) + + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + l3H := c.AUMHashes["L3"] + if bytes.Compare(l2H[:], l1H[:]) < 0 { + t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") + } + if bytes.Compare(l2H[:], l3H[:]) < 0 { + t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes") + } + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // Based on the hash, L1 or L3 should be chosen. + // But based on the preference for AUMRemoveKey messages, + // it should be L2. + want := l2H + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestComputeStateAt(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> I1 -> I2 + I1.template = addKey + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key})) + + // G1 is before the key, so there shouldn't be a key there. + state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"]) + if err != nil { + t.Fatalf("computeStateAt(G1) failed: %v", err) + } + if _, err := state.GetKey(key.ID()); err != ErrNoSuchKey { + t.Errorf("expected key to be missing: err = %v", err) + } + if *state.LastAUMHash != c.AUMHashes["G1"] { + t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"]) + } + + // I1 & I2 are after the key, so the computed state should contain + // the key. + for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} { + state, err = computeStateAt(c.Chonk(), 500, wantHash) + if err != nil { + t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err) + } + if *state.LastAUMHash != wantHash { + t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash) + } + if _, err := state.GetKey(key.ID()); err != nil { + t.Errorf("expected key to be present at state: err = %v", err) + } + } +}