From 134d819bdf377afd13c329f25ea4f06a1a2dafb4 Mon Sep 17 00:00:00 2001 From: Hidde Beydals Date: Tue, 9 Jun 2020 18:23:45 +0200 Subject: [PATCH] ssh: add in-memory knownhosts utilities --- pkg/ssh/knownhosts/knownhosts.go | 446 ++++++++++++++++++++++++++ pkg/ssh/knownhosts/knownhosts_test.go | 327 +++++++++++++++++++ 2 files changed, 773 insertions(+) create mode 100644 pkg/ssh/knownhosts/knownhosts.go create mode 100644 pkg/ssh/knownhosts/knownhosts_test.go diff --git a/pkg/ssh/knownhosts/knownhosts.go b/pkg/ssh/knownhosts/knownhosts.go new file mode 100644 index 00000000..b83d11fd --- /dev/null +++ b/pkg/ssh/knownhosts/knownhosts.go @@ -0,0 +1,446 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Copyright 2020 The FluxCD contributors. All rights reserved. +// This package provides an in-memory known hosts database +// derived from the golang.org/x/crypto/ssh/knownhosts +// package. +// It has been slightly modified and adapted to work with +// in-memory host keys not related to any known_hosts files +// on disk, and the database can be initialized with just a +// known_hosts byte blob. +// https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts + +package knownhosts + +import ( + "bufio" + "bytes" + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strings" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +// See the sshd manpage +// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for +// background. + +type addr struct{ host, port string } + +func (a *addr) String() string { + h := a.host + if strings.Contains(h, ":") { + h = "[" + h + "]" + } + return h + ":" + a.port +} + +type matcher interface { + match(addr) bool +} + +type hostPattern struct { + negate bool + addr addr +} + +func (p *hostPattern) String() string { + n := "" + if p.negate { + n = "!" + } + + return n + p.addr.String() +} + +type hostPatterns []hostPattern + +func (ps hostPatterns) match(a addr) bool { + matched := false + for _, p := range ps { + if !p.match(a) { + continue + } + if p.negate { + return false + } + matched = true + } + return matched +} + +// See +// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c +// The matching of * has no regard for separators, unlike filesystem globs +func wildcardMatch(pat []byte, str []byte) bool { + for { + if len(pat) == 0 { + return len(str) == 0 + } + if len(str) == 0 { + return false + } + + if pat[0] == '*' { + if len(pat) == 1 { + return true + } + + for j := range str { + if wildcardMatch(pat[1:], str[j:]) { + return true + } + } + return false + } + + if pat[0] == '?' || pat[0] == str[0] { + pat = pat[1:] + str = str[1:] + } else { + return false + } + } +} + +func (p *hostPattern) match(a addr) bool { + return wildcardMatch([]byte(p.addr.host), []byte(a.host)) && p.addr.port == a.port +} + +type inMemoryHostKeyDB struct { + hostKeys []hostKey + revoked map[string]*ssh.PublicKey +} + +func newInMemoryHostKeyDB() *inMemoryHostKeyDB { + db := &inMemoryHostKeyDB{ + revoked: make(map[string]*ssh.PublicKey), + } + + return db +} + +func keyEq(a, b ssh.PublicKey) bool { + return bytes.Equal(a.Marshal(), b.Marshal()) +} + +type hostKey struct { + matcher matcher + cert bool + key ssh.PublicKey +} + +func (l *hostKey) match(a addr) bool { + return l.matcher.match(a) +} + +// IsAuthorityForHost can be used as a callback in ssh.CertChecker +func (db *inMemoryHostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool { + h, p, err := net.SplitHostPort(address) + if err != nil { + return false + } + a := addr{host: h, port: p} + + for _, l := range db.hostKeys { + if l.cert && keyEq(l.key, remote) && l.match(a) { + return true + } + } + return false +} + +// IsRevoked can be used as a callback in ssh.CertChecker +func (db *inMemoryHostKeyDB) IsRevoked(key *ssh.Certificate) bool { + _, ok := db.revoked[string(key.Marshal())] + return ok +} + +const markerCert = "@cert-authority" +const markerRevoked = "@revoked" + +func nextWord(line []byte) (string, []byte) { + i := bytes.IndexAny(line, "\t ") + if i == -1 { + return string(line), nil + } + + return string(line[:i]), bytes.TrimSpace(line[i:]) +} + +func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) { + if w, next := nextWord(line); w == markerCert || w == markerRevoked { + marker = w + line = next + } + + host, line = nextWord(line) + if len(line) == 0 { + return "", "", nil, errors.New("knownhosts: missing host pattern") + } + + // ignore the keytype as it's in the key blob anyway. + _, line = nextWord(line) + if len(line) == 0 { + return "", "", nil, errors.New("knownhosts: missing key type pattern") + } + + keyBlob, _ := nextWord(line) + + keyBytes, err := base64.StdEncoding.DecodeString(keyBlob) + if err != nil { + return "", "", nil, err + } + key, err = ssh.ParsePublicKey(keyBytes) + if err != nil { + return "", "", nil, err + } + + return marker, host, key, nil +} + +func (db *inMemoryHostKeyDB) parseLine(line []byte) error { + marker, pattern, key, err := parseLine(line) + if err != nil { + return err + } + + if marker == markerRevoked { + db.revoked[string(key.Marshal())] = &key + return nil + } + + entry := hostKey{ + key: key, + cert: marker == markerCert, + } + + if pattern[0] == '|' { + entry.matcher, err = newHashedHost(pattern) + } else { + entry.matcher, err = newHostnameMatcher(pattern) + } + + if err != nil { + return err + } + + db.hostKeys = append(db.hostKeys, entry) + return nil +} + +func newHostnameMatcher(pattern string) (matcher, error) { + var hps hostPatterns + for _, p := range strings.Split(pattern, ",") { + if len(p) == 0 { + continue + } + + var a addr + var negate bool + if p[0] == '!' { + negate = true + p = p[1:] + } + + if len(p) == 0 { + return nil, errors.New("knownhosts: negation without following hostname") + } + + var err error + if p[0] == '[' { + a.host, a.port, err = net.SplitHostPort(p) + if err != nil { + return nil, err + } + } else { + a.host, a.port, err = net.SplitHostPort(p) + if err != nil { + a.host = p + a.port = "22" + } + } + hps = append(hps, hostPattern{ + negate: negate, + addr: a, + }) + } + return hps, nil +} + +// check checks a key against the host database. This should not be +// used for verifying certificates. +func (db *inMemoryHostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error { + if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil { + return &knownhosts.RevokedError{Revoked: knownhosts.KnownKey{Key: *revoked}} + } + + host, port, err := net.SplitHostPort(remote.String()) + if err != nil { + return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err) + } + + hostToCheck := addr{host, port} + if address != "" { + // Give preference to the hostname if available. + host, port, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err) + } + + hostToCheck = addr{host, port} + } + + return db.checkAddr(hostToCheck, remoteKey) +} + +// checkAddr checks if we can find the given public key for the +// given address. If we only find an entry for the IP address, +// or only the hostname, then this still succeeds. +func (db *inMemoryHostKeyDB) checkAddr(a addr, remoteKey ssh.PublicKey) error { + // TODO(hanwen): are these the right semantics? What if there + // is just a key for the IP address, but not for the + // hostname? + + // Algorithm => key. + knownKeys := map[string]ssh.PublicKey{} + for _, l := range db.hostKeys { + if l.match(a) { + typ := l.key.Type() + if _, ok := knownKeys[typ]; !ok { + knownKeys[typ] = l.key + } + } + } + + keyErr := &knownhosts.KeyError{} + for _, v := range knownKeys { + keyErr.Want = append(keyErr.Want, knownhosts.KnownKey{Key: v}) + } + + // Unknown remote host. + if len(knownKeys) == 0 { + return keyErr + } + + // If the remote host starts using a different, unknown key type, we + // also interpret that as a mismatch. + if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known, remoteKey) { + return keyErr + } + + return nil +} + +// The Read function parses file contents. +func (db *inMemoryHostKeyDB) Read(r io.Reader) error { + scanner := bufio.NewScanner(r) + + lineNum := 0 + for scanner.Scan() { + lineNum++ + line := scanner.Bytes() + line = bytes.TrimSpace(line) + if len(line) == 0 || line[0] == '#' { + continue + } + + if err := db.parseLine(line); err != nil { + return fmt.Errorf("knownhosts: %v", err) + } + } + return scanner.Err() +} + +// New creates a host key callback from the given OpenSSH host key +// file bytes. The returned callback is for use in +// ssh.ClientConfig.HostKeyCallback. By preference, the key check +// operates on the hostname if available, i.e. if a server changes its +// IP address, the host key check will still succeed, even though a +// record of the new IP address is not available. +func New(b []byte) (ssh.HostKeyCallback, error) { + db := newInMemoryHostKeyDB() + r := bytes.NewReader(b) + if err := db.Read(r); err != nil { + return nil, err + } + + var certChecker ssh.CertChecker + certChecker.IsHostAuthority = db.IsHostAuthority + certChecker.IsRevoked = db.IsRevoked + certChecker.HostKeyFallback = db.check + + return certChecker.CheckHostKey, nil +} + +func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) { + if len(encoded) == 0 || encoded[0] != '|' { + err = errors.New("knownhosts: hashed host must start with '|'") + return + } + components := strings.Split(encoded, "|") + if len(components) != 4 { + err = fmt.Errorf("knownhosts: got %d components, want 3", len(components)) + return + } + + hashType = components[1] + if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil { + return + } + if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil { + return + } + return +} + +func encodeHash(typ string, salt []byte, hash []byte) string { + return strings.Join([]string{"", + typ, + base64.StdEncoding.EncodeToString(salt), + base64.StdEncoding.EncodeToString(hash), + }, "|") +} + +// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120 +func hashHost(hostname string, salt []byte) []byte { + mac := hmac.New(sha1.New, salt) + mac.Write([]byte(hostname)) + return mac.Sum(nil) +} + +type hashedHost struct { + salt []byte + hash []byte +} + +const sha1HashType = "1" + +func newHashedHost(encoded string) (*hashedHost, error) { + typ, salt, hash, err := decodeHash(encoded) + if err != nil { + return nil, err + } + + // The type field seems for future algorithm agility, but it's + // actually hardcoded in openssh currently, see + // https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120 + if typ != sha1HashType { + return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ) + } + + return &hashedHost{salt: salt, hash: hash}, nil +} + +func (h *hashedHost) match(a addr) bool { + return bytes.Equal(hashHost(knownhosts.Normalize(a.String()), h.salt), h.hash) +} diff --git a/pkg/ssh/knownhosts/knownhosts_test.go b/pkg/ssh/knownhosts/knownhosts_test.go new file mode 100644 index 00000000..8d057c5a --- /dev/null +++ b/pkg/ssh/knownhosts/knownhosts_test.go @@ -0,0 +1,327 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Copyright 2020 The FluxCD contributors. All rights reserved. +// This package provides an in-memory known hosts database +// derived from the golang.org/x/crypto/ssh/knownhosts +// package. +// It has been slightly modified and adapted to work with +// in-memory host keys not related to any known_hosts files +// on disk, and the database can be initialized with just a +// known_hosts byte blob. +// https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts + +package knownhosts + +import ( + "bytes" + "fmt" + "net" + "reflect" + "testing" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX" +const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx" +const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs=" + +var ecKey, alternateEdKey, edKey ssh.PublicKey +var testAddr = &net.TCPAddr{ + IP: net.IP{198, 41, 30, 196}, + Port: 22, +} + +var testAddr6 = &net.TCPAddr{ + IP: net.IP{198, 41, 30, 196, + 1, 2, 3, 4, + 1, 2, 3, 4, + 1, 2, 3, 4, + }, + Port: 22, +} + +func init() { + var err error + ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr)) + if err != nil { + panic(err) + } + edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr)) + if err != nil { + panic(err) + } + alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr)) + if err != nil { + panic(err) + } +} + +func testDB(t *testing.T, s string) *inMemoryHostKeyDB { + db := newInMemoryHostKeyDB() + if err := db.Read(bytes.NewBufferString(s)); err != nil { + t.Fatalf("Read: %v", err) + } + + return db +} + +func TestRevoked(t *testing.T) { + db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n") + want := &knownhosts.RevokedError{ + Revoked: knownhosts.KnownKey{ + Key: edKey, + }, + } + if err := db.check("", &net.TCPAddr{ + Port: 42, + }, edKey); err == nil { + t.Fatal("no error for revoked key") + } else if !reflect.DeepEqual(want, err) { + t.Fatalf("got %#v, want %#v", want, err) + } +} + +func TestHostAuthority(t *testing.T) { + for _, m := range []struct { + authorityFor string + address string + + good bool + }{ + {authorityFor: "localhost", address: "localhost:22", good: true}, + {authorityFor: "localhost", address: "localhost", good: false}, + {authorityFor: "localhost", address: "localhost:1234", good: false}, + {authorityFor: "[localhost]:1234", address: "localhost:1234", good: true}, + {authorityFor: "[localhost]:1234", address: "localhost:22", good: false}, + {authorityFor: "[localhost]:1234", address: "localhost", good: false}, + } { + db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr) + if ok := db.IsHostAuthority(db.hostKeys[0].key, m.address); ok != m.good { + t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v", + m.authorityFor, m.address, m.good, ok) + } + } +} + +func TestBracket(t *testing.T) { + db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr) + + if err := db.check("git.eclipse.org:29418", &net.TCPAddr{ + IP: net.IP{198, 41, 30, 196}, + Port: 29418, + }, edKey); err != nil { + t.Errorf("got error %v, want none", err) + } + + if err := db.check("git.eclipse.org:29419", &net.TCPAddr{ + Port: 42, + }, edKey); err == nil { + t.Fatalf("no error for unknown address") + } else if ke, ok := err.(*knownhosts.KeyError); !ok { + t.Fatalf("got type %T, want *KeyError", err) + } else if len(ke.Want) > 0 { + t.Fatalf("got Want %v, want []", ke.Want) + } +} + +func TestNewKeyType(t *testing.T) { + str := fmt.Sprintf("%s %s", testAddr, edKeyStr) + db := testDB(t, str) + if err := db.check("", testAddr, ecKey); err == nil { + t.Fatalf("no error for unknown address") + } else if ke, ok := err.(*knownhosts.KeyError); !ok { + t.Fatalf("got type %T, want *KeyError", err) + } else if len(ke.Want) == 0 { + t.Fatalf("got empty KeyError.Want") + } +} + +func TestSameKeyType(t *testing.T) { + str := fmt.Sprintf("%s %s", testAddr, edKeyStr) + db := testDB(t, str) + if err := db.check("", testAddr, alternateEdKey); err == nil { + t.Fatalf("no error for unknown address") + } else if ke, ok := err.(*knownhosts.KeyError); !ok { + t.Fatalf("got type %T, want *KeyError", err) + } else if len(ke.Want) == 0 { + t.Fatalf("got empty KeyError.Want") + } else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) { + t.Fatalf("got key %q, want %q", got, want) + } +} + +func TestIPAddress(t *testing.T) { + str := fmt.Sprintf("%s %s", testAddr, edKeyStr) + db := testDB(t, str) + if err := db.check("", testAddr, edKey); err != nil { + t.Errorf("got error %q, want none", err) + } +} + +func TestIPv6Address(t *testing.T) { + str := fmt.Sprintf("%s %s", testAddr6, edKeyStr) + db := testDB(t, str) + + if err := db.check("", testAddr6, edKey); err != nil { + t.Errorf("got error %q, want none", err) + } +} + +func TestBasic(t *testing.T) { + str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr) + db := testDB(t, str) + if err := db.check("server.org:22", testAddr, edKey); err != nil { + t.Errorf("got error %v, want none", err) + } + + want := knownhosts.KnownKey{ + Key: edKey, + } + if err := db.check("server.org:22", testAddr, ecKey); err == nil { + t.Errorf("succeeded, want KeyError") + } else if ke, ok := err.(*knownhosts.KeyError); !ok { + t.Errorf("got %T, want *KeyError", err) + } else if len(ke.Want) != 1 { + t.Errorf("got %v, want 1 entry", ke) + } else if !reflect.DeepEqual(ke.Want[0], want) { + t.Errorf("got %v, want %v", ke.Want[0], want) + } +} + +func TestHostNamePrecedence(t *testing.T) { + var evilAddr = &net.TCPAddr{ + IP: net.IP{66, 66, 66, 66}, + Port: 22, + } + + str := fmt.Sprintf("server.org,%s %s\nevil.org,%s %s", testAddr, edKeyStr, evilAddr, ecKeyStr) + db := testDB(t, str) + + if err := db.check("server.org:22", evilAddr, ecKey); err == nil { + t.Errorf("check succeeded") + } else if _, ok := err.(*knownhosts.KeyError); !ok { + t.Errorf("got %T, want *KeyError", err) + } +} + +func TestDBOrderingPrecedenceKeyType(t *testing.T) { + str := fmt.Sprintf("server.org,%s %s\nserver.org,%s %s", testAddr, edKeyStr, testAddr, alternateEdKeyStr) + db := testDB(t, str) + + if err := db.check("server.org:22", testAddr, alternateEdKey); err == nil { + t.Errorf("check succeeded") + } else if _, ok := err.(*knownhosts.KeyError); !ok { + t.Errorf("got %T, want *KeyError", err) + } +} + +func TestNegate(t *testing.T) { + str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr) + db := testDB(t, str) + if err := db.check("server.org:22", testAddr, ecKey); err == nil { + t.Errorf("succeeded") + } else if ke, ok := err.(*knownhosts.KeyError); !ok { + t.Errorf("got error type %T, want *KeyError", err) + } else if len(ke.Want) != 0 { + t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type()) + } +} + +func TestWildcard(t *testing.T) { + str := fmt.Sprintf("server*.domain %s", edKeyStr) + db := testDB(t, str) + + want := &knownhosts.KeyError{ + Want: []knownhosts.KnownKey{{ + Key: edKey, + }}, + } + + got := db.check("server.domain:22", &net.TCPAddr{}, ecKey) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s, want %s", got, want) + } +} + +func TestWildcardMatch(t *testing.T) { + for _, c := range []struct { + pat, str string + want bool + }{ + {"a?b", "abb", true}, + {"ab", "abc", false}, + {"abc", "ab", false}, + {"a*b", "axxxb", true}, + {"a*b", "axbxb", true}, + {"a*b", "axbxbc", false}, + {"a*?", "axbxc", true}, + {"a*b*", "axxbxxxxxx", true}, + {"a*b*c", "axxbxxxxxxc", true}, + {"a*b*?", "axxbxxxxxxc", true}, + {"a*b*z", "axxbxxbxxxz", true}, + {"a*b*z", "axxbxxzxxxz", true}, + {"a*b*z", "axxbxxzxxx", false}, + } { + got := wildcardMatch([]byte(c.pat), []byte(c.str)) + if got != c.want { + t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want) + } + + } +} + +// TODO(hanwen): test coverage for certificates. + +const testHostname = "hostname" + +// generated with keygen -H -f +const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y=" + +func TestHostHash(t *testing.T) { + testHostHash(t, testHostname, encodedTestHostnameHash) +} + +func TestHashList(t *testing.T) { + encoded := knownhosts.HashHostname(testHostname) + testHostHash(t, testHostname, encoded) +} + +func testHostHash(t *testing.T, hostname, encoded string) { + typ, salt, hash, err := decodeHash(encoded) + if err != nil { + t.Fatalf("decodeHash: %v", err) + } + + if got := encodeHash(typ, salt, hash); got != encoded { + t.Errorf("got encoding %s want %s", got, encoded) + } + + if typ != sha1HashType { + t.Fatalf("got hash type %q, want %q", typ, sha1HashType) + } + + got := hashHost(hostname, salt) + if !bytes.Equal(got, hash) { + t.Errorf("got hash %x want %x", got, hash) + } +} + +func TestHashedHostkeyCheck(t *testing.T) { + str := fmt.Sprintf("%s %s", knownhosts.HashHostname(testHostname), edKeyStr) + db := testDB(t, str) + if err := db.check(testHostname+":22", testAddr, edKey); err != nil { + t.Errorf("check(%s): %v", testHostname, err) + } + want := &knownhosts.KeyError{ + Want: []knownhosts.KnownKey{{ + Key: edKey, + }}, + } + if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) { + t.Errorf("got error %v, want %v", got, want) + } +}