diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index f452bba46..722c3fe66 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -36,6 +36,8 @@ import ( "tailscale.com/types/wgkey" ) +const msgLimit = 1 << 20 // encrypted message length limit + // Server is a control plane server. Its zero value is ready for use. // Everything is stored in-memory in one tailnet. type Server struct { @@ -397,15 +399,23 @@ func (s *Server) CompleteAuth(authPathOrURL string) bool { } func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey tailcfg.MachineKey) { + msg, err := ioutil.ReadAll(io.LimitReader(r.Body, msgLimit)) + if err != nil { + r.Body.Close() + http.Error(w, fmt.Sprintf("bad map request read: %v", err), 400) + return + } + r.Body.Close() + var req tailcfg.RegisterRequest - if err := s.decode(mkey, r.Body, &req); err != nil { - panic(fmt.Sprintf("serveRegister: decode: %v", err)) + if err := s.decode(mkey, msg, &req); err != nil { + go panic(fmt.Sprintf("serveRegister: decode: %v", err)) } if req.Version != 1 { - panic(fmt.Sprintf("serveRegister: unsupported version: %d", req.Version)) + go panic(fmt.Sprintf("serveRegister: unsupported version: %d", req.Version)) } if req.NodeKey.IsZero() { - panic("serveRegister: request has zero node key") + go panic("serveRegister: request has zero node key") } if s.Verbose { j, _ := json.MarshalIndent(req, "", "\t") @@ -558,8 +568,16 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey tailcfg.M defer s.incrInServeMap(-1) ctx := r.Context() + msg, err := ioutil.ReadAll(io.LimitReader(r.Body, msgLimit)) + if err != nil { + r.Body.Close() + http.Error(w, fmt.Sprintf("bad map request read: %v", err), 400) + return + } + r.Body.Close() + req := new(tailcfg.MapRequest) - if err := s.decode(mkey, r.Body, req); err != nil { + if err := s.decode(mkey, msg, req); err != nil { go panic(fmt.Sprintf("bad map request: %v", err)) } @@ -747,15 +765,7 @@ func (s *Server) sendMapMsg(w http.ResponseWriter, mkey tailcfg.MachineKey, comp return nil } -func (s *Server) decode(mkey tailcfg.MachineKey, r io.Reader, v interface{}) error { - if c, _ := r.(io.Closer); c != nil { - defer c.Close() - } - const msgLimit = 1 << 20 - msg, err := ioutil.ReadAll(io.LimitReader(r, msgLimit)) - if err != nil { - return err - } +func (s *Server) decode(mkey tailcfg.MachineKey, msg []byte, v interface{}) error { if len(msg) == msgLimit { return errors.New("encrypted message too long") }