tsnet: cleanup resources upon start failure (#5301)

In a partially initialized state, we should cleanup
all prior resources when an error occurs.

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
pull/5304/head
Joe Tsai 2 years ago committed by GitHub
parent f0d6f173c9
commit b1fff4499f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,6 +10,7 @@ package tsnet
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
@ -182,7 +183,10 @@ func (s *Server) getAuthKey() string {
return os.Getenv("TS_AUTHKEY") return os.Getenv("TS_AUTHKEY")
} }
func (s *Server) start() error { func (s *Server) start() (reterr error) {
var closePool closeOnErrorPool
defer closePool.closeAllIfError(&reterr)
exe, err := os.Executable() exe, err := os.Executable()
if err != nil { if err != nil {
return err return err
@ -244,6 +248,7 @@ func (s *Server) start() error {
if err != nil { if err != nil {
return fmt.Errorf("error creating filch: %w", err) return fmt.Errorf("error creating filch: %w", err)
} }
closePool.add(s.logbuffer)
c := logtail.Config{ c := logtail.Config{
Collection: lpc.Collection, Collection: lpc.Collection,
PrivateID: lpc.PrivateID, PrivateID: lpc.PrivateID,
@ -259,11 +264,13 @@ func (s *Server) start() error {
HTTPC: &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost)}, HTTPC: &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost)},
} }
s.logtail = logtail.NewLogger(c, logf) s.logtail = logtail.NewLogger(c, logf)
closePool.addFunc(func() { s.logtail.Shutdown(context.Background()) })
s.linkMon, err = monitor.New(logf) s.linkMon, err = monitor.New(logf)
if err != nil { if err != nil {
return err return err
} }
closePool.add(s.linkMon)
s.dialer = new(tsdial.Dialer) // mutated below (before used) s.dialer = new(tsdial.Dialer) // mutated below (before used)
eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
@ -274,6 +281,7 @@ func (s *Server) start() error {
if err != nil { if err != nil {
return err return err
} }
closePool.add(s.dialer)
tunDev, magicConn, dns, ok := eng.(wgengine.InternalsGetter).GetInternals() tunDev, magicConn, dns, ok := eng.(wgengine.InternalsGetter).GetInternals()
if !ok { if !ok {
@ -317,6 +325,7 @@ func (s *Server) start() error {
lb.SetVarRoot(s.rootPath) lb.SetVarRoot(s.rootPath)
logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath) logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath)
s.lb = lb s.lb = lb
closePool.addFunc(func() { s.lb.Shutdown() })
lb.SetDecompressor(func() (controlclient.Decompressor, error) { lb.SetDecompressor(func() (controlclient.Decompressor, error) {
return smallzstd.NewDecoder(nil) return smallzstd.NewDecoder(nil)
}) })
@ -357,9 +366,22 @@ func (s *Server) start() error {
logf("localapi serve error: %v", err) logf("localapi serve error: %v", err)
} }
}() }()
closePool.add(s.localAPIListener)
return nil return nil
} }
type closeOnErrorPool []func()
func (p *closeOnErrorPool) add(c io.Closer) { *p = append(*p, func() { c.Close() }) }
func (p *closeOnErrorPool) addFunc(fn func()) { *p = append(*p, fn) }
func (p closeOnErrorPool) closeAllIfError(errp *error) {
if *errp != nil {
for _, closeFn := range p {
closeFn()
}
}
}
func (s *Server) logf(format string, a ...interface{}) { func (s *Server) logf(format string, a ...interface{}) {
if s.logtail != nil { if s.logtail != nil {
s.logtail.Logf(format, a...) s.logtail.Logf(format, a...)

Loading…
Cancel
Save