diff --git a/internal/flags/rsa_key_bits.go b/internal/flags/rsa_key_bits.go index e214e1bf..716a4617 100644 --- a/internal/flags/rsa_key_bits.go +++ b/internal/flags/rsa_key_bits.go @@ -39,6 +39,9 @@ func (b *RSAKeyBits) Set(str string) error { if err != nil { return err } + if bits < 1024 { + return fmt.Errorf("RSA key bit size must be at least 1024") + } if bits == 0 || bits%8 != 0 { return fmt.Errorf("RSA key bit size must be a multiples of 8") } @@ -51,5 +54,5 @@ func (b *RSAKeyBits) Type() string { } func (b *RSAKeyBits) Description() string { - return "SSH RSA public key bit size (multiplies of 8)" + return "SSH RSA public key bit size (multiplies of 8, min 1024)" } diff --git a/internal/flags/rsa_key_bits_test.go b/internal/flags/rsa_key_bits_test.go index d4bd2965..7199be73 100644 --- a/internal/flags/rsa_key_bits_test.go +++ b/internal/flags/rsa_key_bits_test.go @@ -32,8 +32,8 @@ func TestRSAKeyBits_Set(t *testing.T) { }{ {"supported", "4096", "4096", false}, {"empty (default)", "", "2048", false}, - {"unsupported", "0", "0", true}, - {"unsupported", "123", "0", true}, + {"unsupported", "512", "0", true}, + {"unsupported", "1025", "0", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {