// Copyright 2019 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ssh import ( "reflect" "testing" ) func TestFindAgreedAlgorithms(t *testing.T) { initKex := func(k *kexInitMsg) { if k.KexAlgos == nil { k.KexAlgos = []string{"kex1"} } if k.ServerHostKeyAlgos == nil { k.ServerHostKeyAlgos = []string{"hostkey1"} } if k.CiphersClientServer == nil { k.CiphersClientServer = []string{"cipher1"} } if k.CiphersServerClient == nil { k.CiphersServerClient = []string{"cipher1"} } if k.MACsClientServer == nil { k.MACsClientServer = []string{"mac1"} } if k.MACsServerClient == nil { k.MACsServerClient = []string{"mac1"} } if k.CompressionClientServer == nil { k.CompressionClientServer = []string{"compression1"} } if k.CompressionServerClient == nil { k.CompressionServerClient = []string{"compression1"} } if k.LanguagesClientServer == nil { k.LanguagesClientServer = []string{"language1"} } if k.LanguagesServerClient == nil { k.LanguagesServerClient = []string{"language1"} } } initDirAlgs := func(a *directionAlgorithms) { if a.Cipher == "" { a.Cipher = "cipher1" } if a.MAC == "" { a.MAC = "mac1" } if a.Compression == "" { a.Compression = "compression1" } } initAlgs := func(a *algorithms) { if a.kex == "" { a.kex = "kex1" } if a.hostKey == "" { a.hostKey = "hostkey1" } initDirAlgs(&a.r) initDirAlgs(&a.w) } type testcase struct { name string clientIn, serverIn kexInitMsg wantClient, wantServer algorithms wantErr bool } cases := []testcase{ { name: "standard", }, { name: "no common hostkey", serverIn: kexInitMsg{ ServerHostKeyAlgos: []string{"hostkey2"}, }, wantErr: true, }, { name: "no common kex", serverIn: kexInitMsg{ KexAlgos: []string{"kex2"}, }, wantErr: true, }, { name: "no common cipher", serverIn: kexInitMsg{ CiphersClientServer: []string{"cipher2"}, }, wantErr: true, }, { name: "client decides cipher", serverIn: kexInitMsg{ CiphersClientServer: []string{"cipher1", "cipher2"}, CiphersServerClient: []string{"cipher2", "cipher3"}, }, clientIn: kexInitMsg{ CiphersClientServer: []string{"cipher2", "cipher1"}, CiphersServerClient: []string{"cipher3", "cipher2"}, }, wantClient: algorithms{ r: directionAlgorithms{ Cipher: "cipher3", }, w: directionAlgorithms{ Cipher: "cipher2", }, }, wantServer: algorithms{ w: directionAlgorithms{ Cipher: "cipher3", }, r: directionAlgorithms{ Cipher: "cipher2", }, }, }, // TODO(hanwen): fix and add tests for AEAD ignoring // the MACs field } for i := range cases { initKex(&cases[i].clientIn) initKex(&cases[i].serverIn) initAlgs(&cases[i].wantClient) initAlgs(&cases[i].wantServer) } for _, c := range cases { t.Run(c.name, func(t *testing.T) { serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn) clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn) serverHasErr := serverErr != nil clientHasErr := clientErr != nil if c.wantErr != serverHasErr || c.wantErr != clientHasErr { t.Fatalf("got client/server error (%v, %v), want hasError %v", clientErr, serverErr, c.wantErr) } if c.wantErr { return } if !reflect.DeepEqual(serverAlgs, &c.wantServer) { t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer) } if !reflect.DeepEqual(clientAlgs, &c.wantClient) { t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient) } }) } }