diff --git a/command_run.go b/command_run.go index 24b7935166..44ffc22349 100644 --- a/command_run.go +++ b/command_run.go @@ -314,7 +314,11 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context if err := cmd.checkAllRequiredFlags(); err != nil { cmd.isInError = true - _ = ShowSubcommandHelp(cmd) + if cmd.OnUsageError != nil { + err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil) + } else { + _ = ShowSubcommandHelp(cmd) + } return ctx, err } diff --git a/command_test.go b/command_test.go index b9e69bc307..6df83cdb75 100644 --- a/command_test.go +++ b/command_test.go @@ -4096,6 +4096,21 @@ func TestCheckRequiredFlags(t *testing.T) { } } +func TestCheckRequiredFlagsWithOnUsageError(t *testing.T) { + expectedError := errors.New("OnUsageError") + cmd := &Command{ + Name: "foo", + Flags: []Flag{ + &StringFlag{Name: "requiredFlag", Required: true}, + }, + OnUsageError: func(_ context.Context, _ *Command, _ error, _ bool) error { + return expectedError + }, + } + actualError := cmd.Run(buildTestContext(t), []string{"requiredFlag"}) + require.ErrorIs(t, actualError, expectedError) +} + func TestCommand_ParentCommand_Set(t *testing.T) { cmd := &Command{ parent: &Command{