diff --git a/util/workgraph/workgraph.go b/util/workgraph/workgraph.go
new file mode 100644
index 000000000..183a736c0
--- /dev/null
+++ b/util/workgraph/workgraph.go
@@ -0,0 +1,450 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package workgraph contains a "workgraph": a data structure that allows
+// defining individual jobs, dependencies between them, and then executing all
+// jobs to completion.
+package workgraph
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"runtime"
+	"slices"
+	"strings"
+	"sync"
+
+	"tailscale.com/util/set"
+)
+
+// ErrCyclic is returned when there is a cycle in the graph.
+var ErrCyclic = errors.New("graph is cyclic")
+
+// Node is the interface that must be implemented by a node in a WorkGraph.
+type Node interface {
+	// ID should return a unique ID for this node. IDs for each Node in a
+	// WorkGraph must be unique.
+	ID() string
+
+	// Run is called when this node in a WorkGraph is executed; it should
+	// return an error if execution fails, which will cause all dependent
+	// Nodes to fail to execute.
+	Run(context.Context) error
+}
+
+type nodeFunc struct {
+	id  string
+	run func(context.Context) error
+}
+
+func (n *nodeFunc) ID() string                    { return n.id }
+func (n *nodeFunc) Run(ctx context.Context) error { return n.run(ctx) }
+
+// NodeFunc is a helper that returns a Node with the given ID that calls the
+// given function when Node.Run is called.
+func NodeFunc(id string, fn func(context.Context) error) Node {
+	return &nodeFunc{id, fn}
+}
+
+// WorkGraph is a directed acyclic graph of individual jobs to be executed,
+// each of which may have dependencies on other jobs. It supports adding a job
+// as a Node (a combination of a unique ID and the function to execute that
+// job) and then running all added Nodes while respecting dependencies.
+type WorkGraph struct {
+	nodes map[string]Node  // keyed by Node.ID
+	edges edgeList[string] // keyed by Node.ID
+
+	// Concurrency is the number of concurrent goroutines to use to process
+	// jobs. If zero, runtime.GOMAXPROCS will be used.
+	//
+	// This field must not be modified after Run has been called.
+	Concurrency int
+}
+
+// NewWorkGraph creates a new empty WorkGraph.
+func NewWorkGraph() *WorkGraph {
+	ret := &WorkGraph{
+		nodes: make(map[string]Node),
+		edges: newEdgeList[string](),
+	}
+	return ret
+}
+
+// AddNodeOpts contains options that can be passed to AddNode.
+type AddNodeOpts struct {
+	// Dependencies are any Node IDs that must be completed before this
+	// Node is started.
+	Dependencies []string
+}
+
+// AddNode adds a new Node to the WorkGraph with the provided options. It
+// returns an error if the given Node.ID was already added to the WorkGraph, or
+// if one of the options provided was invalid.
+func (g *WorkGraph) AddNode(n Node, opts *AddNodeOpts) error {
+	id := n.ID()
+	if _, found := g.nodes[id]; found {
+		return fmt.Errorf("node %q already exists", id)
+	}
+	g.nodes[id] = n
+
+	if opts == nil {
+		return nil
+	}
+
+	// Create an edge from each dependency pointing to this node, forcing
+	// that node to be evaluated first.
+	for _, dep := range opts.Dependencies {
+		if _, found := g.nodes[dep]; !found {
+			return fmt.Errorf("dependency %q not found", dep)
+		}
+		g.edges.Add(dep, id)
+	}
+	return nil
+}
+
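+// For illustration only, a typical call sequence might look like the
+// following (the job IDs and functions here are hypothetical):
+//
+//	g := NewWorkGraph()
+//	g.AddNode(NodeFunc("fetch", fetchFn), nil)
+//	g.AddNode(NodeFunc("build", buildFn), &AddNodeOpts{Dependencies: []string{"fetch"}})
+//	err := g.Run(ctx)
+//
+// Here "build" is only started once "fetch" has completed successfully.
+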
+type queueEntry struct {
+	id   string
+	done chan struct{}
+}
+
+// Run will iterate through all Nodes in this WorkGraph, running them once all
+// their dependencies have been satisfied, and returning any errors that occur.
+func (g *WorkGraph) Run(ctx context.Context) error {
+	groups, err := g.topoSortKahn()
+	if err != nil {
+		return err
+	}
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// Create one goroutine that pushes jobs onto our queue...
+	var wg sync.WaitGroup
+	queue := make(chan queueEntry)
+	publishCtx, publishCancel := context.WithCancel(ctx)
+	defer publishCancel()
+
+	wg.Add(1)
+	go g.runPublisher(publishCtx, &wg, queue, groups)
+
+	firstErr := make(chan error, 1)
+	saveErr := func(err error) {
+		if err == nil {
+			return
+		}
+
+		// Tell the publisher to shut down
+		publishCancel()
+
+		select {
+		case firstErr <- err:
+		default:
+		}
+	}
+
+	// ... and N goroutines that each work on an item from the queue.
+	n := g.Concurrency
+	if n == 0 {
+		n = runtime.GOMAXPROCS(-1)
+	}
+
+	wg.Add(n)
+	for i := 0; i < n; i++ {
+		go g.runWorker(ctx, &wg, queue, saveErr)
+	}
+
+	wg.Wait()
+	select {
+	case err := <-firstErr:
+		return err
+	default:
+	}
+	return nil
+}
+
+func (g *WorkGraph) runPublisher(ctx context.Context, wg *sync.WaitGroup, queue chan queueEntry, groups []set.Set[string]) {
+	defer wg.Done()
+	defer close(queue)
+
+	// For each parallel group...
+	var dones []chan struct{}
+	for _, group := range groups {
+		dones = dones[:0] // re-use existing storage, if any
+
+		// Push all items in this group onto our queue
+		for curr := range group {
+			done := make(chan struct{})
+			dones = append(dones, done)
+
+			select {
+			case <-ctx.Done():
+				return
+			case queue <- queueEntry{curr, done}:
+			}
+		}
+
+		// Now that we've started everything, wait for them all
+		// to complete.
+		for _, done := range dones {
+			select {
+			case <-ctx.Done():
+				return
+			case <-done:
+			}
+		}
+
+		// Now that we've done this entire group, we can
+		// continue with the next one.
+	}
+}
+
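+// runWorker pulls entries off the queue and runs them until the queue is
+// closed or the context is canceled; on the first failure it reports the
+// error via saveErr and exits.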
+func (g *WorkGraph) runWorker(ctx context.Context, wg *sync.WaitGroup, queue chan queueEntry, saveErr func(error)) {
+	defer wg.Done()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case ent, ok := <-queue:
+			if !ok {
+				return
+			}
+			if err := g.runEntry(ctx, ent); err != nil {
+				saveErr(err)
+				return
+			}
+		}
+	}
+}
+
+func (g *WorkGraph) runEntry(ctx context.Context, ent queueEntry) (retErr error) {
+	defer close(ent.done)
+	defer func() {
+		if r := recover(); r != nil {
+			// Ensure that we wrap an existing error with %w so errors.Is works
+			switch v := r.(type) {
+			case error:
+				retErr = fmt.Errorf("node %q: caught panic: %w", ent.id, v)
+			default:
+				retErr = fmt.Errorf("node %q: caught panic: %v", ent.id, v)
+			}
+		}
+	}()
+
+	node := g.nodes[ent.id]
+	return node.Run(ctx)
+}
+
+// topoSortDFS runs a depth-first topological sort; it is only used in tests.
+//
+// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
+func (g *WorkGraph) topoSortDFS() (sorted []string, err error) {
+	const (
+		markTemporary = 1
+		markPermanent = 2
+	)
+	marks := make(map[string]int) // map[node.ID]markType
+
+	var visit func(string) error
+	visit = func(n string) error {
+		// "if n has a permanent mark then"
+		if marks[n] == markPermanent {
+			return nil
+		}
+		// "if n has a temporary mark then"
+		if marks[n] == markTemporary {
+			return ErrCyclic
+		}
+
+		// "mark n with a temporary mark"
+		marks[n] = markTemporary
+
+		// "for each node m with an edge from n to m do"
+		for m := range g.edges.OutgoingNodes(n) {
+			if err := visit(m); err != nil {
+				return err
+			}
+		}
+
+		// "remove temporary mark from n"
+		// "mark n with a permanent mark"
+		//
+		// NOTE: this is safe because if this node had a temporary
+		// mark, we'd have returned above, and the only thing that adds
+		// a mark to a node is this function.
+		marks[n] = markPermanent
+
+		// "add n to head of L"; note that we append for performance
+		// reasons and reverse later
+		sorted = append(sorted, n)
+		return nil
+	}
+
+	// For all nodes, visit them. From the algorithm description:
+	//	while exists nodes without a permanent mark do
+	//		select an unmarked node n
+	//		visit(n)
+	for nid := range g.nodes {
+		if err := visit(nid); err != nil {
+			return nil, err
+		}
+	}
+
+	// We appended to the slice for performance reasons; reverse it to get
+	// our final result.
+	slices.Reverse(sorted)
+	return sorted, nil
+}
+
+// topoSortKahn runs a variant of Kahn's algorithm for topological sorting,
+// which not only returns a sort, but provides individual "groups" of nodes
+// that can be executed concurrently.
+//
+// See:
+// - https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
+// - https://stackoverflow.com/a/67267597
+func (g *WorkGraph) topoSortKahn() (sorted []set.Set[string], err error) {
+	// We mutate the set of edges during this function, so copy it.
+	edges := g.edges.Clone()
+
+	// Create S_0, the set of nodes with no incoming edge
+	s0 := make(set.Set[string])
+	for nid := range g.nodes {
+		if !edges.HasIncoming(nid) {
+			s0.Add(nid)
+		}
+	}
+
+	// Add this set to the returned set of nodes
+	sorted = append(sorted, s0)
+
+	// Repeatedly iterate, starting from the initial set, until we have no
+	// more nodes. The inner loop is essentially Kahn's algorithm.
+	sCurr := s0
+	for {
+		// Initialize the next set
+		sNext := make(set.Set[string])
+
+		// For each node 'n' in the current set...
+		for n := range sCurr {
+			// For each successor 'd' of the current node...
+			for d := range edges.OutgoingNodes(n) {
+				// Remove edge 'n -> d'
+				edges.Remove(n, d)
+
+				// If this node 'd' has no incoming edges left,
+				// we can add it to the next set since it can
+				// now be processed.
+				if !edges.HasIncoming(d) {
+					sNext.Add(d)
+				}
+			}
+		}
+
+		// If the next set is non-empty, then append it to the list
+		// of returned sets, make it the current set, and continue.
+		// Otherwise, we're done.
+		if len(sNext) == 0 {
+			break
+		}
+
+		sorted = append(sorted, sNext)
+		sCurr = sNext
+	}
+
+	if edges.Len() > 0 {
+		return nil, ErrCyclic
+	}
+	return sorted, nil
+}
+
+// Graphviz prints a basic Graphviz representation of the WorkGraph. This is
+// primarily useful for debugging.
+func (g *WorkGraph) Graphviz() string {
+	var buf strings.Builder
+	buf.WriteString("digraph workgraph {\n")
+	for from, edges := range g.edges.outgoing {
+		for to := range edges {
+			fmt.Fprintf(&buf, "\t%s -> %s;\n", from, to)
+		}
+	}
+	buf.WriteString("}")
+	return buf.String()
+}
+
+// edgeList is a helper type that is used to maintain a set of edges, tracking
+// both incoming and outgoing edges for a given node.
+type edgeList[K comparable] struct {
+	incoming map[K]set.Set[K] // for edge A -> B, keyed by B
+	outgoing map[K]set.Set[K] // for edge A -> B, keyed by A
+}
+
+func newEdgeList[K comparable]() edgeList[K] {
+	return edgeList[K]{
+		incoming: make(map[K]set.Set[K]),
+		outgoing: make(map[K]set.Set[K]),
+	}
+}
+
+func (el *edgeList[K]) Clone() edgeList[K] {
+	ret := edgeList[K]{
+		incoming: make(map[K]set.Set[K], len(el.incoming)),
+		outgoing: make(map[K]set.Set[K], len(el.outgoing)),
+	}
+	for k, v := range el.incoming {
+		ret.incoming[k] = v.Clone()
+	}
+	for k, v := range el.outgoing {
+		ret.outgoing[k] = v.Clone()
+	}
+	return ret
+}
+
+func (el *edgeList[K]) Len() int {
+	i := 0
+	for _, set := range el.incoming {
+		i += set.Len()
+	}
+	return i
+}
+
+func (el *edgeList[K]) Add(from, to K) {
+	if _, found := el.incoming[to]; !found {
+		el.incoming[to] = make(set.Set[K])
+	}
+	if _, found := el.outgoing[from]; !found {
+		el.outgoing[from] = make(set.Set[K])
+	}
+
+	el.incoming[to].Add(from)
+	el.outgoing[from].Add(to)
+}
+
+func (el *edgeList[K]) Remove(from, to K) {
+	if m, ok := el.incoming[to]; ok {
+		delete(m, from)
+	}
+	if m, ok := el.outgoing[from]; ok {
+		delete(m, to)
+	}
+}
+
+func (el *edgeList[K]) HasIncoming(id K) bool {
+	return el.incoming[id].Len() > 0
+}
+
+func (el *edgeList[K]) HasOutgoing(id K) bool {
+	return el.outgoing[id].Len() > 0
+}
+
+func (el *edgeList[K]) Exists(from, to K) bool {
+	return el.outgoing[from].Contains(to)
+}
+
+func (el *edgeList[K]) IncomingNodes(id K) set.Set[K] {
+	return el.incoming[id]
+}
+
+func (el *edgeList[K]) OutgoingNodes(id K) set.Set[K] {
+	return el.outgoing[id]
+}
diff --git a/util/workgraph/workgraph_test.go b/util/workgraph/workgraph_test.go
new file mode 100644
index 000000000..c36731f6b
--- /dev/null
+++ b/util/workgraph/workgraph_test.go
@@ -0,0 +1,353 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package workgraph
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"runtime"
+	"sync/atomic"
+	"testing"
+
+	"tailscale.com/util/must"
+	"tailscale.com/util/set"
+)
+
+func debugGraph(tb testing.TB, g *WorkGraph) {
+	before := g.Graphviz()
+	tb.Cleanup(func() {
+		if !tb.Failed() {
+			return
+		}
+
+		after := g.Graphviz()
+		tb.Logf("graphviz at start of test:\n%s", before)
+		tb.Logf("graphviz at end of test:\n%s", after)
+	})
+}
+
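+// makeTestGraph builds a small test graph. For reference, the dependency
+// edges added below are (dependency -> dependent):
+//
+//	one -> three, four, five
+//	two -> four
+//	four -> six
+//	five -> six
+//
+// so "one" and "two" can run immediately, while "six" can only run once
+// "four" and "five" are done.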
+func makeTestGraph(tb testing.TB) *WorkGraph {
+	logFunc := func(s string) func(context.Context) error {
+		return func(_ context.Context) error {
+			tb.Log(s)
+			return nil
+		}
+	}
+	makeNode := func(s string) Node {
+		return NodeFunc(s, logFunc(s+" called"))
+	}
+	withDeps := func(ss ...string) *AddNodeOpts {
+		return &AddNodeOpts{Dependencies: ss}
+	}
+
+	g := NewWorkGraph()
+
+	// Ensure we have at least 2 concurrent goroutines
+	g.Concurrency = runtime.GOMAXPROCS(-1)
+	if g.Concurrency < 2 {
+		g.Concurrency = 2
+	}
+
+	n1 := makeNode("one")
+	n2 := makeNode("two")
+	n3 := makeNode("three")
+	n4 := makeNode("four")
+	n5 := makeNode("five")
+	n6 := makeNode("six")
+
+	must.Do(g.AddNode(n1, nil)) // can execute first
+	must.Do(g.AddNode(n2, nil)) // can execute first
+	must.Do(g.AddNode(n3, withDeps("one")))
+	must.Do(g.AddNode(n4, withDeps("one", "two")))
+	must.Do(g.AddNode(n5, withDeps("one")))
+	must.Do(g.AddNode(n6, withDeps("four", "five")))
+
+	return g
+}
+
+func TestWorkGraph(t *testing.T) {
+	g := makeTestGraph(t)
+	debugGraph(t, g)
+
+	if err := g.Run(context.Background()); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestWorkGroup_Error(t *testing.T) {
+	g := NewWorkGraph()
+
+	terr := errors.New("test error")
+
+	returnsErr := func(_ context.Context) error { return terr }
+	notCalled := func(_ context.Context) error { panic("unused") }
+
+	n1 := NodeFunc("one", returnsErr)
+	n2 := NodeFunc("two", notCalled)
+	n3 := NodeFunc("three", notCalled)
+
+	must.Do(g.AddNode(n1, nil))
+	must.Do(g.AddNode(n2, &AddNodeOpts{Dependencies: []string{"one"}}))
+	must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}}))
+
+	err := g.Run(context.Background())
+	if err == nil {
+		t.Fatal("wanted non-nil error")
+	}
+	if !errors.Is(err, terr) {
+		t.Errorf("got %v, want %v", err, terr)
+	}
+}
+
+func TestWorkGroup_HandlesPanic(t *testing.T) {
+	g := NewWorkGraph()
+
+	terr := errors.New("test error")
+	n1 := NodeFunc("one", func(_ context.Context) error { panic(terr) })
+
+	must.Do(g.AddNode(n1, nil))
+	err := g.Run(context.Background())
+	if err == nil {
+		t.Fatal("wanted non-nil error")
+	}
+	if !errors.Is(err, terr) {
+		t.Errorf("got %v, want %v", err, terr)
+	}
+}
+
+func TestWorkGroup_Cancellation(t *testing.T) {
+	g := NewWorkGraph()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	var running atomic.Int64
+	blocks := func(ctx context.Context) error {
+		running.Add(1)
+		<-ctx.Done()
+		return ctx.Err()
+	}
+
+	n1 := NodeFunc("one", blocks)
+	n2 := NodeFunc("two", blocks)
+	n3 := NodeFunc("three", blocks)
+
+	must.Do(g.AddNode(n1, nil))
+	must.Do(g.AddNode(n2, nil))
+
+	// Ensure that we have a node with dependencies that's also waiting,
+	// since we want to verify that the queue publisher also properly
+	// handles context cancellation.
+	must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}}))
+
+	// call Run in a goroutine since it blocks
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- g.Run(ctx)
+	}()
+
+	// after all goroutines are running, cancel the context to unblock
+	for running.Load() != 2 {
+		// wait
+	}
+	cancel()
+	err := <-errCh
+
+	if err == nil {
+		t.Fatal("wanted non-nil error")
+	}
+	if !errors.Is(err, context.Canceled) {
+		t.Errorf("got %v, want %v", err, context.Canceled)
+	}
+}
+
+func TestTopoSortDFS(t *testing.T) {
+	g := makeTestGraph(t)
+	debugGraph(t, g)
+
+	sorted, err := g.topoSortDFS()
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("DFS topological sort: %v", sorted)
+
+	validateTopologicalSortDFS(t, g, sorted)
+}
+
+func validateTopologicalSortDFS(tb testing.TB, g *WorkGraph, order []string) {
+	// A valid ordering is any one where a node ID earlier in the list does
+	// not depend on a node ID later in the list; that is, there must be no
+	// edge from a later node to an earlier one.
+	for i, node := range order {
+		for j := 0; j < i; j++ {
+			if g.edges.Exists(node, order[j]) {
+				tb.Errorf("invalid edge: %v [%d] -> %v [%d]", node, i, order[j], j)
+			}
+		}
+	}
+}
+
+func TestTopoSortKahn(t *testing.T) {
+	g := makeTestGraph(t)
+	debugGraph(t, g)
+
+	groups, err := g.topoSortKahn()
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("grouped topological sort: %v", groups)
+
+	validateTopologicalSortKahn(t, g, groups)
+}
+
+func validateTopologicalSortKahn(tb testing.TB, g *WorkGraph, groups []set.Set[string]) {
+	// A valid ordering is any one where a node ID earlier in the list does
+	// not depend on a node ID later in the list; that is, there must be no
+	// edge from a later node to an earlier one.
+	prev := make(map[string]bool)
+	for i, group := range groups {
+		for node := range group {
+			for m := range prev {
+				if g.edges.Exists(node, m) {
+					tb.Errorf("group[%d]: invalid edge: %v -> %v", i, node, m)
+				}
+			}
+			prev[node] = true
+		}
+	}
+
+	// Verify that our topologically sorted groups contain all nodes.
+	for nid := range g.nodes {
+		if !prev[nid] {
+			tb.Errorf("topological sort missing node %v", nid)
+		}
+	}
+}
+
+func FuzzTopSortKahn(f *testing.F) {
+	// We can't pass a map[string][]string (or similar) into a fuzz
+	// function, so instead let's create test data by using a combination
+	// of 'n' nodes and an adjacency matrix of edges from node to node.
+	//
+	// We then need to filter this adjacency matrix in the Fuzz function,
+	// since the fuzzer doesn't distinguish between "invalid fuzz inputs
+	// due to logic bugs" and "invalid fuzz data that causes a real
+	// error".
+	f.Add(
+		10, // number of nodes
+		[]byte{
+			1, 0, // 1 depends on 0
+			6, 2, // 6 depends on 2
+			9, 8, // 9 depends on 8
+		},
+	)
+	f.Fuzz(func(t *testing.T, numNodes int, edges []byte) {
+		g := createGraphFromFuzzInput(t, numNodes, edges)
+		if g == nil {
+			return
+		}
+
+		// This should not error
+		groups, err := g.topoSortKahn()
+		if err != nil {
+			t.Fatal(err)
+		}
+		validateTopologicalSortKahn(t, g, groups)
+	})
+}
+
+func FuzzTopSortDFS(f *testing.F) {
+	// We can't pass a map[string][]string (or similar) into a fuzz
+	// function, so instead let's create test data by using a combination
+	// of 'n' nodes and an adjacency matrix of edges from node to node.
+	//
+	// We then need to filter this adjacency matrix in the Fuzz function,
+	// since the fuzzer doesn't distinguish between "invalid fuzz inputs
+	// due to logic bugs" and "invalid fuzz data that causes a real
+	// error".
+	f.Add(
+		10, // number of nodes
+		[]byte{
+			1, 0, // 1 depends on 0
+			6, 2, // 6 depends on 2
+			9, 8, // 9 depends on 8
+		},
+	)
+	f.Fuzz(func(t *testing.T, numNodes int, edges []byte) {
+		g := createGraphFromFuzzInput(t, numNodes, edges)
+		if g == nil {
+			return
+		}
+
+		// This should not error
+		sorted, err := g.topoSortDFS()
+		if err != nil {
+			t.Fatal(err)
+		}
+		validateTopologicalSortDFS(t, g, sorted)
+	})
+}
+
+func createGraphFromFuzzInput(tb testing.TB, numNodes int, edges []byte) *WorkGraph {
+	nodeName := func(i int) string {
+		return fmt.Sprintf("node-%d", i)
+	}
+
+	filterAdjacencyMatrix := func(numNodes int, edges []byte) map[string][]string {
+		deps := make(map[string][]string)
+		for i := 0; i < len(edges); i += 2 {
+			node, dep := int(edges[i]), int(edges[i+1])
+			if node >= numNodes || dep >= numNodes {
+				// invalid node
+				continue
+			}
+			if node == dep {
+				// can't depend on self
+				continue
+			}
+
+			// We add nodes in incrementing order (0, 1, 2, etc.),
+			// so an edge can't point 'forward' or it'll fail to be
+			// added.
+			if dep > node {
+				continue
+			}
+
+			nn := nodeName(node)
+			deps[nn] = append(deps[nn], nodeName(dep))
+		}
+		return deps
+	}
+
+	// Constrain the number of nodes
+	if numNodes <= 0 || numNodes > 1000 {
+		return nil
+	}
+	// Must have pairs of edges (from, to)
+	if len(edges)%2 != 0 {
+		return nil
+	}
+
+	// Convert list of edges into list of dependencies
+	deps := filterAdjacencyMatrix(numNodes, edges)
+	if len(deps) == 0 {
+		return nil
+	}
+
+	// Actually create graph.
+	g := NewWorkGraph()
+	doNothing := func(context.Context) error { return nil }
+	for i := 0; i < numNodes; i++ {
+		nn := nodeName(i)
+		node := NodeFunc(nn, doNothing)
+		if err := g.AddNode(node, &AddNodeOpts{
+			Dependencies: deps[nn],
+		}); err != nil {
+			tb.Error(err) // shouldn't error after we filtered out bad edges above
+		}
+	}
+	if tb.Failed() {
+		return nil
+	}
+	return g
+}