diff --git a/changelog/16379.txt b/changelog/16379.txt new file mode 100644 index 0000000000000..b5903e3ea11c2 --- /dev/null +++ b/changelog/16379.txt @@ -0,0 +1,3 @@ +```release-note:bug +core: Validate input parameters for vault operator init command +``` \ No newline at end of file diff --git a/command/operator_init.go b/command/operator_init.go index a8b8e56010245..6d67dcd9b6a45 100644 --- a/command/operator_init.go +++ b/command/operator_init.go @@ -40,8 +40,10 @@ type OperatorInitCommand struct { } const ( - defKeyShares = 5 - defKeyThreshold = 3 + defKeyShares = 5 + defKeyThreshold = 3 + defRecoveryShares = 5 + defRecoveryThreshold = 3 ) func (c *OperatorInitCommand) Synopsis() string { @@ -103,7 +105,6 @@ func (c *OperatorInitCommand) Flags() *FlagSets { Name: "key-shares", Aliases: []string{"n"}, Target: &c.flagKeyShares, - Default: defKeyShares, Completion: complete.PredictAnything, Usage: "Number of key shares to split the generated root key into. " + "This is the number of \"unseal keys\" to generate.", @@ -113,7 +114,6 @@ func (c *OperatorInitCommand) Flags() *FlagSets { Name: "key-threshold", Aliases: []string{"t"}, Target: &c.flagKeyThreshold, - Default: defKeyThreshold, Completion: complete.PredictAnything, Usage: "Number of key shares required to reconstruct the root key. " + "This must be less than or equal to -key-shares.", @@ -182,7 +182,6 @@ func (c *OperatorInitCommand) Flags() *FlagSets { f.IntVar(&IntVar{ Name: "recovery-shares", Target: &c.flagRecoveryShares, - Default: 5, Completion: complete.PredictAnything, Usage: "Number of key shares to split the recovery key into. " + "This is only used in auto-unseal mode.", @@ -191,7 +190,6 @@ func (c *OperatorInitCommand) Flags() *FlagSets { f.IntVar(&IntVar{ Name: "recovery-threshold", Target: &c.flagRecoveryThreshold, - Default: 3, Completion: complete.PredictAnything, Usage: "Number of key shares required to reconstruct the recovery key. " + "This is only used in Auto Unseal mode.", @@ -233,6 +231,35 @@ func (c *OperatorInitCommand) Run(args []string) int { if c.flagStoredShares != -1 { c.UI.Warn("-stored-shares has no effect and will be removed in Vault 1.3.\n") } + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + // Set defaults based on use of auto unseal seal + sealInfo, err := client.Sys().SealStatus() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + switch sealInfo.RecoverySeal { + case true: + if c.flagRecoveryShares == 0 { + c.flagRecoveryShares = defRecoveryShares + } + if c.flagRecoveryThreshold == 0 { + c.flagRecoveryThreshold = defRecoveryThreshold + } + default: + if c.flagKeyShares == 0 { + c.flagKeyShares = defKeyShares + } + if c.flagKeyThreshold == 0 { + c.flagKeyThreshold = defKeyThreshold + } + } // Build the initial init request initReq := &api.InitRequest{ @@ -246,12 +273,6 @@ func (c *OperatorInitCommand) Run(args []string) int { RecoveryPGPKeys: c.flagRecoveryPGPKeys, } - client, err := c.Client() - if err != nil { - c.UI.Error(err.Error()) - return 2 - } - // Check auto mode switch { case c.flagStatus: @@ -471,14 +492,6 @@ func (c *OperatorInitCommand) init(client *api.Client, req *api.InitRequest) int req.RecoveryThreshold))) } - if len(resp.RecoveryKeys) > 0 && (req.SecretShares != defKeyShares || req.SecretThreshold != defKeyThreshold) { - c.UI.Output("") - c.UI.Warn(wrapAtLength( - "WARNING! -key-shares and -key-threshold is ignored when " + - "Auto Unseal is used. Use -recovery-shares and -recovery-threshold instead.", - )) - } - return 0 } diff --git a/command/operator_init_test.go b/command/operator_init_test.go index 491d623a14732..ec02873587dfc 100644 --- a/command/operator_init_test.go +++ b/command/operator_init_test.go @@ -355,7 +355,7 @@ func TestOperatorInitCommand_Run(t *testing.T) { t.Errorf("expected %d to be %d", code, exp) } - expected := "Error initializing: " + expected := "Error making API request" combined := ui.OutputWriter.String() + ui.ErrorWriter.String() if !strings.Contains(combined, expected) { t.Errorf("expected %q to contain %q", combined, expected) diff --git a/http/sys_init.go b/http/sys_init.go index b21e5363ea020..ae3059462bef4 100644 --- a/http/sys_init.go +++ b/http/sys_init.go @@ -4,7 +4,9 @@ import ( "context" "encoding/base64" "encoding/hex" + "fmt" "net/http" + "strings" "github.com/hashicorp/vault/vault" ) @@ -44,6 +46,12 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) return } + // Validate init request parameters + if err := validateInitParameters(core, req); err != nil { + respondError(w, http.StatusBadRequest, err) + return + } + // Initialize barrierConfig := &vault.SealConfig{ SecretShares: req.SecretShares, @@ -128,3 +136,41 @@ type InitResponse struct { type InitStatusResponse struct { Initialized bool `json:"initialized"` } + +// Validates if the right parameters are used based on AutoUnseal +func validateInitParameters(core *vault.Core, req InitRequest) error { + recoveryFlags := make([]string, 0) + barrierFlags := make([]string, 0) + + if req.SecretShares != 0 { + barrierFlags = append(barrierFlags, "secret_shares") + } + if req.SecretThreshold != 0 { + barrierFlags = append(barrierFlags, "secret_threshold") + } + if len(req.PGPKeys) != 0 { + barrierFlags = append(barrierFlags, "pgp_keys") + } + if req.RecoveryShares != 0 { + recoveryFlags = append(recoveryFlags, "recovery_shares") + } + if req.RecoveryThreshold != 0 { + recoveryFlags = append(recoveryFlags, "recovery_threshold") + } + if len(req.RecoveryPGPKeys) != 0 { + recoveryFlags = append(recoveryFlags, "recovery_pgp_keys") + } + + switch core.SealAccess().RecoveryKeySupported() { + case true: + if len(barrierFlags) > 0 { + return fmt.Errorf("parameters %s not applicable to seal type %s", strings.Join(barrierFlags, ","), core.SealAccess().BarrierType()) + } + default: + if len(recoveryFlags) > 0 { + return fmt.Errorf("parameters %s not applicable to seal type %s", strings.Join(recoveryFlags, ","), core.SealAccess().BarrierType()) + } + + } + return nil +} diff --git a/http/sys_init_test.go b/http/sys_init_test.go index f4be3413a9022..38a15f6ccc42e 100644 --- a/http/sys_init_test.go +++ b/http/sys_init_test.go @@ -4,9 +4,15 @@ import ( "encoding/hex" "net/http" "reflect" + "strconv" "testing" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/builtin/logical/transit" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/vault/seal" ) func TestSysInit_get(t *testing.T) { @@ -123,3 +129,63 @@ func TestSysInit_put(t *testing.T) { t.Fatal("should not be sealed") } } + +func TestSysInit_Put_ValidateParams(t *testing.T) { + core := vault.TestCore(t) + ln, addr := TestServer(t, core) + defer ln.Close() + + resp := testHttpPut(t, "", addr+"/v1/sys/init", map[string]interface{}{ + "secret_shares": 5, + "secret_threshold": 3, + "recovery_shares": 5, + "recovery_threshold": 3, + }) + testResponseStatus(t, resp, http.StatusBadRequest) + body := map[string][]string{} + testResponseBody(t, resp, &body) + if body["errors"][0] != "parameters recovery_shares,recovery_threshold not applicable to seal type shamir" { + t.Fatal(body) + } +} + +func TestSysInit_Put_ValidateParams_AutoUnseal(t *testing.T) { + testSeal := seal.NewTestSeal(nil) + autoSeal := vault.NewAutoSeal(testSeal) + autoSeal.SetType("transit") + + // Create the transit server. + conf := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "transit": transit.Factory, + }, + Seal: autoSeal, + } + opts := &vault.TestClusterOptions{ + NumCores: 1, + HandlerFunc: Handler, + Logger: logging.NewVaultLogger(hclog.Trace).Named(t.Name()).Named("transit-seal" + strconv.Itoa(0)), + } + cluster := vault.NewTestCluster(t, conf, opts) + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + core := cores[0].Core + + ln, addr := TestServer(t, core) + defer ln.Close() + + resp := testHttpPut(t, "", addr+"/v1/sys/init", map[string]interface{}{ + "secret_shares": 5, + "secret_threshold": 3, + "recovery_shares": 5, + "recovery_threshold": 3, + }) + testResponseStatus(t, resp, http.StatusBadRequest) + body := map[string][]string{} + testResponseBody(t, resp, &body) + if body["errors"][0] != "parameters secret_shares,secret_threshold not applicable to seal type transit" { + t.Fatal(body) + } +}