ssh: add in-memory knownhosts utilities

pull/34/head
Hidde Beydals 5 years ago
parent 9e31bbe716
commit 134d819bdf

@ -0,0 +1,446 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Copyright 2020 The FluxCD contributors. All rights reserved.
// This package provides an in-memory known hosts database
// derived from the golang.org/x/crypto/ssh/knownhosts
// package.
// It has been slightly modified and adapted to work with
// in-memory host keys not related to any known_hosts files
// on disk, and the database can be initialized with just a
// known_hosts byte blob.
// https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts
package knownhosts
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"strings"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
// See the sshd manpage
// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for
// background.
type addr struct{ host, port string }
func (a *addr) String() string {
h := a.host
if strings.Contains(h, ":") {
h = "[" + h + "]"
}
return h + ":" + a.port
}
type matcher interface {
match(addr) bool
}
type hostPattern struct {
negate bool
addr addr
}
func (p *hostPattern) String() string {
n := ""
if p.negate {
n = "!"
}
return n + p.addr.String()
}
type hostPatterns []hostPattern
func (ps hostPatterns) match(a addr) bool {
matched := false
for _, p := range ps {
if !p.match(a) {
continue
}
if p.negate {
return false
}
matched = true
}
return matched
}
// See
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
// The matching of * has no regard for separators, unlike filesystem globs
func wildcardMatch(pat []byte, str []byte) bool {
for {
if len(pat) == 0 {
return len(str) == 0
}
if len(str) == 0 {
return false
}
if pat[0] == '*' {
if len(pat) == 1 {
return true
}
for j := range str {
if wildcardMatch(pat[1:], str[j:]) {
return true
}
}
return false
}
if pat[0] == '?' || pat[0] == str[0] {
pat = pat[1:]
str = str[1:]
} else {
return false
}
}
}
func (p *hostPattern) match(a addr) bool {
return wildcardMatch([]byte(p.addr.host), []byte(a.host)) && p.addr.port == a.port
}
type inMemoryHostKeyDB struct {
hostKeys []hostKey
revoked map[string]*ssh.PublicKey
}
func newInMemoryHostKeyDB() *inMemoryHostKeyDB {
db := &inMemoryHostKeyDB{
revoked: make(map[string]*ssh.PublicKey),
}
return db
}
func keyEq(a, b ssh.PublicKey) bool {
return bytes.Equal(a.Marshal(), b.Marshal())
}
type hostKey struct {
matcher matcher
cert bool
key ssh.PublicKey
}
func (l *hostKey) match(a addr) bool {
return l.matcher.match(a)
}
// IsAuthorityForHost can be used as a callback in ssh.CertChecker
func (db *inMemoryHostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
h, p, err := net.SplitHostPort(address)
if err != nil {
return false
}
a := addr{host: h, port: p}
for _, l := range db.hostKeys {
if l.cert && keyEq(l.key, remote) && l.match(a) {
return true
}
}
return false
}
// IsRevoked can be used as a callback in ssh.CertChecker
func (db *inMemoryHostKeyDB) IsRevoked(key *ssh.Certificate) bool {
_, ok := db.revoked[string(key.Marshal())]
return ok
}
const markerCert = "@cert-authority"
const markerRevoked = "@revoked"
func nextWord(line []byte) (string, []byte) {
i := bytes.IndexAny(line, "\t ")
if i == -1 {
return string(line), nil
}
return string(line[:i]), bytes.TrimSpace(line[i:])
}
func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
if w, next := nextWord(line); w == markerCert || w == markerRevoked {
marker = w
line = next
}
host, line = nextWord(line)
if len(line) == 0 {
return "", "", nil, errors.New("knownhosts: missing host pattern")
}
// ignore the keytype as it's in the key blob anyway.
_, line = nextWord(line)
if len(line) == 0 {
return "", "", nil, errors.New("knownhosts: missing key type pattern")
}
keyBlob, _ := nextWord(line)
keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
if err != nil {
return "", "", nil, err
}
key, err = ssh.ParsePublicKey(keyBytes)
if err != nil {
return "", "", nil, err
}
return marker, host, key, nil
}
func (db *inMemoryHostKeyDB) parseLine(line []byte) error {
marker, pattern, key, err := parseLine(line)
if err != nil {
return err
}
if marker == markerRevoked {
db.revoked[string(key.Marshal())] = &key
return nil
}
entry := hostKey{
key: key,
cert: marker == markerCert,
}
if pattern[0] == '|' {
entry.matcher, err = newHashedHost(pattern)
} else {
entry.matcher, err = newHostnameMatcher(pattern)
}
if err != nil {
return err
}
db.hostKeys = append(db.hostKeys, entry)
return nil
}
func newHostnameMatcher(pattern string) (matcher, error) {
var hps hostPatterns
for _, p := range strings.Split(pattern, ",") {
if len(p) == 0 {
continue
}
var a addr
var negate bool
if p[0] == '!' {
negate = true
p = p[1:]
}
if len(p) == 0 {
return nil, errors.New("knownhosts: negation without following hostname")
}
var err error
if p[0] == '[' {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
return nil, err
}
} else {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
a.host = p
a.port = "22"
}
}
hps = append(hps, hostPattern{
negate: negate,
addr: a,
})
}
return hps, nil
}
// check checks a key against the host database. This should not be
// used for verifying certificates.
func (db *inMemoryHostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error {
if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil {
return &knownhosts.RevokedError{Revoked: knownhosts.KnownKey{Key: *revoked}}
}
host, port, err := net.SplitHostPort(remote.String())
if err != nil {
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err)
}
hostToCheck := addr{host, port}
if address != "" {
// Give preference to the hostname if available.
host, port, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err)
}
hostToCheck = addr{host, port}
}
return db.checkAddr(hostToCheck, remoteKey)
}
// checkAddr checks if we can find the given public key for the
// given address. If we only find an entry for the IP address,
// or only the hostname, then this still succeeds.
func (db *inMemoryHostKeyDB) checkAddr(a addr, remoteKey ssh.PublicKey) error {
// TODO(hanwen): are these the right semantics? What if there
// is just a key for the IP address, but not for the
// hostname?
// Algorithm => key.
knownKeys := map[string]ssh.PublicKey{}
for _, l := range db.hostKeys {
if l.match(a) {
typ := l.key.Type()
if _, ok := knownKeys[typ]; !ok {
knownKeys[typ] = l.key
}
}
}
keyErr := &knownhosts.KeyError{}
for _, v := range knownKeys {
keyErr.Want = append(keyErr.Want, knownhosts.KnownKey{Key: v})
}
// Unknown remote host.
if len(knownKeys) == 0 {
return keyErr
}
// If the remote host starts using a different, unknown key type, we
// also interpret that as a mismatch.
if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known, remoteKey) {
return keyErr
}
return nil
}
// The Read function parses file contents.
func (db *inMemoryHostKeyDB) Read(r io.Reader) error {
scanner := bufio.NewScanner(r)
lineNum := 0
for scanner.Scan() {
lineNum++
line := scanner.Bytes()
line = bytes.TrimSpace(line)
if len(line) == 0 || line[0] == '#' {
continue
}
if err := db.parseLine(line); err != nil {
return fmt.Errorf("knownhosts: %v", err)
}
}
return scanner.Err()
}
// New creates a host key callback from the given OpenSSH host key
// file bytes. The returned callback is for use in
// ssh.ClientConfig.HostKeyCallback. By preference, the key check
// operates on the hostname if available, i.e. if a server changes its
// IP address, the host key check will still succeed, even though a
// record of the new IP address is not available.
func New(b []byte) (ssh.HostKeyCallback, error) {
db := newInMemoryHostKeyDB()
r := bytes.NewReader(b)
if err := db.Read(r); err != nil {
return nil, err
}
var certChecker ssh.CertChecker
certChecker.IsHostAuthority = db.IsHostAuthority
certChecker.IsRevoked = db.IsRevoked
certChecker.HostKeyFallback = db.check
return certChecker.CheckHostKey, nil
}
func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
if len(encoded) == 0 || encoded[0] != '|' {
err = errors.New("knownhosts: hashed host must start with '|'")
return
}
components := strings.Split(encoded, "|")
if len(components) != 4 {
err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
return
}
hashType = components[1]
if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
return
}
if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
return
}
return
}
func encodeHash(typ string, salt []byte, hash []byte) string {
return strings.Join([]string{"",
typ,
base64.StdEncoding.EncodeToString(salt),
base64.StdEncoding.EncodeToString(hash),
}, "|")
}
// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
func hashHost(hostname string, salt []byte) []byte {
mac := hmac.New(sha1.New, salt)
mac.Write([]byte(hostname))
return mac.Sum(nil)
}
type hashedHost struct {
salt []byte
hash []byte
}
const sha1HashType = "1"
func newHashedHost(encoded string) (*hashedHost, error) {
typ, salt, hash, err := decodeHash(encoded)
if err != nil {
return nil, err
}
// The type field seems for future algorithm agility, but it's
// actually hardcoded in openssh currently, see
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
if typ != sha1HashType {
return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
}
return &hashedHost{salt: salt, hash: hash}, nil
}
func (h *hashedHost) match(a addr) bool {
return bytes.Equal(hashHost(knownhosts.Normalize(a.String()), h.salt), h.hash)
}

@ -0,0 +1,327 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Copyright 2020 The FluxCD contributors. All rights reserved.
// This package provides an in-memory known hosts database
// derived from the golang.org/x/crypto/ssh/knownhosts
// package.
// It has been slightly modified and adapted to work with
// in-memory host keys not related to any known_hosts files
// on disk, and the database can be initialized with just a
// known_hosts byte blob.
// https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts
package knownhosts
import (
"bytes"
"fmt"
"net"
"reflect"
"testing"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX"
const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx"
const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs="
var ecKey, alternateEdKey, edKey ssh.PublicKey
var testAddr = &net.TCPAddr{
IP: net.IP{198, 41, 30, 196},
Port: 22,
}
var testAddr6 = &net.TCPAddr{
IP: net.IP{198, 41, 30, 196,
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4,
},
Port: 22,
}
func init() {
var err error
ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr))
if err != nil {
panic(err)
}
edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr))
if err != nil {
panic(err)
}
alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr))
if err != nil {
panic(err)
}
}
func testDB(t *testing.T, s string) *inMemoryHostKeyDB {
db := newInMemoryHostKeyDB()
if err := db.Read(bytes.NewBufferString(s)); err != nil {
t.Fatalf("Read: %v", err)
}
return db
}
func TestRevoked(t *testing.T) {
db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n")
want := &knownhosts.RevokedError{
Revoked: knownhosts.KnownKey{
Key: edKey,
},
}
if err := db.check("", &net.TCPAddr{
Port: 42,
}, edKey); err == nil {
t.Fatal("no error for revoked key")
} else if !reflect.DeepEqual(want, err) {
t.Fatalf("got %#v, want %#v", want, err)
}
}
func TestHostAuthority(t *testing.T) {
for _, m := range []struct {
authorityFor string
address string
good bool
}{
{authorityFor: "localhost", address: "localhost:22", good: true},
{authorityFor: "localhost", address: "localhost", good: false},
{authorityFor: "localhost", address: "localhost:1234", good: false},
{authorityFor: "[localhost]:1234", address: "localhost:1234", good: true},
{authorityFor: "[localhost]:1234", address: "localhost:22", good: false},
{authorityFor: "[localhost]:1234", address: "localhost", good: false},
} {
db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr)
if ok := db.IsHostAuthority(db.hostKeys[0].key, m.address); ok != m.good {
t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v",
m.authorityFor, m.address, m.good, ok)
}
}
}
func TestBracket(t *testing.T) {
db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr)
if err := db.check("git.eclipse.org:29418", &net.TCPAddr{
IP: net.IP{198, 41, 30, 196},
Port: 29418,
}, edKey); err != nil {
t.Errorf("got error %v, want none", err)
}
if err := db.check("git.eclipse.org:29419", &net.TCPAddr{
Port: 42,
}, edKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*knownhosts.KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) > 0 {
t.Fatalf("got Want %v, want []", ke.Want)
}
}
func TestNewKeyType(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, ecKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*knownhosts.KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) == 0 {
t.Fatalf("got empty KeyError.Want")
}
}
func TestSameKeyType(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, alternateEdKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*knownhosts.KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) == 0 {
t.Fatalf("got empty KeyError.Want")
} else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) {
t.Fatalf("got key %q, want %q", got, want)
}
}
func TestIPAddress(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
}
func TestIPv6Address(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr6, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr6, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
}
func TestBasic(t *testing.T) {
str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, edKey); err != nil {
t.Errorf("got error %v, want none", err)
}
want := knownhosts.KnownKey{
Key: edKey,
}
if err := db.check("server.org:22", testAddr, ecKey); err == nil {
t.Errorf("succeeded, want KeyError")
} else if ke, ok := err.(*knownhosts.KeyError); !ok {
t.Errorf("got %T, want *KeyError", err)
} else if len(ke.Want) != 1 {
t.Errorf("got %v, want 1 entry", ke)
} else if !reflect.DeepEqual(ke.Want[0], want) {
t.Errorf("got %v, want %v", ke.Want[0], want)
}
}
func TestHostNamePrecedence(t *testing.T) {
var evilAddr = &net.TCPAddr{
IP: net.IP{66, 66, 66, 66},
Port: 22,
}
str := fmt.Sprintf("server.org,%s %s\nevil.org,%s %s", testAddr, edKeyStr, evilAddr, ecKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", evilAddr, ecKey); err == nil {
t.Errorf("check succeeded")
} else if _, ok := err.(*knownhosts.KeyError); !ok {
t.Errorf("got %T, want *KeyError", err)
}
}
func TestDBOrderingPrecedenceKeyType(t *testing.T) {
str := fmt.Sprintf("server.org,%s %s\nserver.org,%s %s", testAddr, edKeyStr, testAddr, alternateEdKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, alternateEdKey); err == nil {
t.Errorf("check succeeded")
} else if _, ok := err.(*knownhosts.KeyError); !ok {
t.Errorf("got %T, want *KeyError", err)
}
}
func TestNegate(t *testing.T) {
str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, ecKey); err == nil {
t.Errorf("succeeded")
} else if ke, ok := err.(*knownhosts.KeyError); !ok {
t.Errorf("got error type %T, want *KeyError", err)
} else if len(ke.Want) != 0 {
t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type())
}
}
func TestWildcard(t *testing.T) {
str := fmt.Sprintf("server*.domain %s", edKeyStr)
db := testDB(t, str)
want := &knownhosts.KeyError{
Want: []knownhosts.KnownKey{{
Key: edKey,
}},
}
got := db.check("server.domain:22", &net.TCPAddr{}, ecKey)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s, want %s", got, want)
}
}
func TestWildcardMatch(t *testing.T) {
for _, c := range []struct {
pat, str string
want bool
}{
{"a?b", "abb", true},
{"ab", "abc", false},
{"abc", "ab", false},
{"a*b", "axxxb", true},
{"a*b", "axbxb", true},
{"a*b", "axbxbc", false},
{"a*?", "axbxc", true},
{"a*b*", "axxbxxxxxx", true},
{"a*b*c", "axxbxxxxxxc", true},
{"a*b*?", "axxbxxxxxxc", true},
{"a*b*z", "axxbxxbxxxz", true},
{"a*b*z", "axxbxxzxxxz", true},
{"a*b*z", "axxbxxzxxx", false},
} {
got := wildcardMatch([]byte(c.pat), []byte(c.str))
if got != c.want {
t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want)
}
}
}
// TODO(hanwen): test coverage for certificates.
const testHostname = "hostname"
// generated with keygen -H -f
const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
func TestHostHash(t *testing.T) {
testHostHash(t, testHostname, encodedTestHostnameHash)
}
func TestHashList(t *testing.T) {
encoded := knownhosts.HashHostname(testHostname)
testHostHash(t, testHostname, encoded)
}
func testHostHash(t *testing.T, hostname, encoded string) {
typ, salt, hash, err := decodeHash(encoded)
if err != nil {
t.Fatalf("decodeHash: %v", err)
}
if got := encodeHash(typ, salt, hash); got != encoded {
t.Errorf("got encoding %s want %s", got, encoded)
}
if typ != sha1HashType {
t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
}
got := hashHost(hostname, salt)
if !bytes.Equal(got, hash) {
t.Errorf("got hash %x want %x", got, hash)
}
}
func TestHashedHostkeyCheck(t *testing.T) {
str := fmt.Sprintf("%s %s", knownhosts.HashHostname(testHostname), edKeyStr)
db := testDB(t, str)
if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
t.Errorf("check(%s): %v", testHostname, err)
}
want := &knownhosts.KeyError{
Want: []knownhosts.KnownKey{{
Key: edKey,
}},
}
if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
t.Errorf("got error %v, want %v", got, want)
}
}
Loading…
Cancel
Save