package main

import (
	"crypto/elliptic"
	"fmt"
	"sort"
	"strconv"
	"strings"
)

// supportedPublicKeyAlgorithms lists every value PublicKeyAlgorithm.Set
// accepts, in the order they appear in help text and error messages.
var supportedPublicKeyAlgorithms = []string{"rsa", "ecdsa", "ed25519"}

// PublicKeyAlgorithm is a flag value naming the SSH public key algorithm.
// String, Set and Type appear to implement a command-line flag Value
// interface (e.g. spf13/pflag) — NOTE(review): confirm against the caller.
type PublicKeyAlgorithm string

// String returns the currently selected algorithm name.
func (a *PublicKeyAlgorithm) String() string {
	selected := *a
	return string(selected)
}

// Set validates str and stores it in *a.
//
// A blank or whitespace-only str is rejected. Any value that is not an exact
// match for an entry of supportedPublicKeyAlgorithms is also rejected. Both
// errors list the accepted choices, and on error *a is left unchanged.
func (a *PublicKeyAlgorithm) Set(str string) error {
	choices := strings.Join(supportedPublicKeyAlgorithms, ", ")
	if len(strings.TrimSpace(str)) == 0 {
		return fmt.Errorf("no public key algorithm given, must be one of: %s", choices)
	}
	for i := range supportedPublicKeyAlgorithms {
		if supportedPublicKeyAlgorithms[i] != str {
			continue
		}
		*a = PublicKeyAlgorithm(str)
		return nil
	}
	return fmt.Errorf("unsupported public key algorithm '%s', must be one of: %s", str, choices)
}

// Type returns the type name reported for this flag value.
func (a *PublicKeyAlgorithm) Type() string {
	return "publicKeyAlgorithm"
}

// Description returns help text that enumerates the supported algorithms.
func (a *PublicKeyAlgorithm) Description() string {
	return "SSH public key algorithm (" + strings.Join(supportedPublicKeyAlgorithms, ", ") + ")"
}

// defaultRSAKeyBits is the RSA key size RSAKeyBits.Set stores when it is
// given a blank value.
var defaultRSAKeyBits = 2048

// RSAKeyBits is a flag value holding an RSA key size in bits.
type RSAKeyBits int

// String returns the key size as a decimal string.
func (b *RSAKeyBits) String() string {
	return strconv.Itoa(int(*b))
}

// Set parses str as a decimal RSA key size and stores it in *b.
//
// Surrounding whitespace is ignored, and a blank value selects
// defaultRSAKeyBits. Otherwise the size must be a positive multiple of 8.
// On any error *b is left unchanged. Parse failures wrap the underlying
// strconv error, so callers can still inspect it with errors.As.
func (b *RSAKeyBits) Set(str string) error {
	str = strings.TrimSpace(str)
	if str == "" {
		*b = RSAKeyBits(defaultRSAKeyBits)
		return nil
	}
	bits, err := strconv.Atoi(str)
	if err != nil {
		return fmt.Errorf("invalid RSA key bit size: %w", err)
	}
	// Zero and negative values pass the multiple-of-8 check below, but they
	// are not meaningful key sizes, so reject them explicitly.
	if bits <= 0 {
		return fmt.Errorf("RSA key bit size must be positive, got %d", bits)
	}
	if bits%8 != 0 {
		return fmt.Errorf("RSA key bit size should be a multiple of 8, got %d", bits)
	}
	*b = RSAKeyBits(bits)
	return nil
}

// Type returns the type name reported for this flag value.
func (b *RSAKeyBits) Type() string {
	return "rsaKeyBits"
}

// Description returns help text describing the accepted key sizes.
func (b *RSAKeyBits) Description() string {
	return "SSH RSA public key bit size (multiples of 8)"
}

// ECDSACurve is a flag value that selects an elliptic curve for SSH ECDSA
// keys by short name (a key of supportedECDSACurves). The zero value holds a
// nil Curve and prints as the empty string.
type ECDSACurve struct {
	elliptic.Curve
}

// supportedECDSACurves maps each accepted flag value to its curve.
var supportedECDSACurves = map[string]elliptic.Curve{
	"p256": elliptic.P256(),
	"p384": elliptic.P384(),
	"p521": elliptic.P521(),
}

// String returns the short name of the selected curve, or "" when none is
// set. It builds the name from the curve's Params().Name by removing the
// first "-" and lower-casing, so "P-256" becomes "p256".
func (c *ECDSACurve) String() string {
	if c.Curve == nil {
		return ""
	}
	return strings.ToLower(strings.Replace(c.Curve.Params().Name, "-", "", 1))
}

// Set selects the curve registered under str in supportedECDSACurves. The
// match is exact: case-sensitive, with no whitespace trimming. For an unknown
// name it returns an error listing the supported names and leaves *c
// unchanged.
func (c *ECDSACurve) Set(str string) error {
	if v, ok := supportedECDSACurves[str]; ok {
		*c = ECDSACurve{Curve: v}
		return nil
	}
	return fmt.Errorf("unsupported curve '%s', should be one of: %s", str, strings.Join(ecdsaCurves(), ", "))
}

// Type returns the type name reported for this flag value.
func (c *ECDSACurve) Type() string {
	return "ecdsaCurve"
}

// Description returns help text listing the supported curve names.
func (c *ECDSACurve) Description() string {
	return fmt.Sprintf("SSH ECDSA public key curve (%s)", strings.Join(ecdsaCurves(), ", "))
}

// ecdsaCurves returns the keys of supportedECDSACurves in sorted order.
// Go randomizes map iteration order, so without sorting the help text and
// error messages would list the curves in a different order on each run.
func ecdsaCurves() []string {
	names := make([]string, 0, len(supportedECDSACurves))
	for name := range supportedECDSACurves {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}