diff --git a/CHANGELOG.md b/CHANGELOG.md index f5d137da..d0c37bc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- Fixed a riverpro CLI integration point broken in v0.23.0. [PR #945](https://github.com/riverqueue/river/pull/945) + ## [0.23.0] - 2025-06-04 ⚠️ Internal APIs used for communication between River and River Pro have changed. If using River Pro, make sure to update River and River Pro to latest at the same time to get compatible versions. River v0.23.0 is compatible with River Pro v0.15.0. diff --git a/cmd/river/rivercli/command.go b/cmd/river/rivercli/command.go index 24a2e681..4f96f51c 100644 --- a/cmd/river/rivercli/command.go +++ b/cmd/river/rivercli/command.go @@ -46,10 +46,11 @@ type CommandOpts interface { // RunCommandBundle is a bundle of utilities for RunCommand. type RunCommandBundle struct { - DatabaseURL *string - Logger *slog.Logger - OutStd io.Writer - Schema string + DatabaseURL *string + DriverProcurer DriverProcurer + Logger *slog.Logger + OutStd io.Writer + Schema string } // RunCommand bootstraps and runs a River CLI subcommand. @@ -76,7 +77,7 @@ func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle } } - var driverProcurer DriverProcurer + driverProcurer := bundle.DriverProcurer if databaseURL != nil { switch protocol { case "postgres", "postgresql": @@ -86,7 +87,12 @@ func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle } defer dbPool.Close() - driverProcurer = &pgxV5DriverProcurer{dbPool: dbPool} + driverProcurerPgxV5, isPgxV5Procurer := driverProcurer.(DriverProcurerPgxV5) + if driverProcurer != nil && isPgxV5Procurer { + driverProcurerPgxV5.InitPgxV5(dbPool) + } else { + driverProcurer = &pgxV5DriverProcurer{dbPool: dbPool} + } case "sqlite": dbPool, err := openSQLitePool(protocol, urlWithoutProtocol) diff --git a/cmd/river/rivercli/driver_procurer.go b/cmd/river/rivercli/driver_procurer.go index df59b6b8..40fc6ec9 100644 --- a/cmd/river/rivercli/driver_procurer.go +++ b/cmd/river/rivercli/driver_procurer.go @@ -23,6 +23,10 @@ type DriverProcurer interface { QueryRow(ctx context.Context, sql string, args ...any) riverdriver.Row } +type DriverProcurerPgxV5 interface { + InitPgxV5(pool *pgxpool.Pool) +} + // BenchmarkerInterface is an interface to a Benchmarker. Its reason for // existence is to wrap a benchmarker to strip it of its generic parameter, // letting us pass it around without having to know the transaction type. diff --git a/cmd/river/rivercli/river_cli.go b/cmd/river/rivercli/river_cli.go index 7073d09c..b549b567 100644 --- a/cmd/river/rivercli/river_cli.go +++ b/cmd/river/rivercli/river_cli.go @@ -29,6 +29,10 @@ import ( ) type Config struct { + // DriverProcurer is used to procure a driver for the database. If not + // specified, a default one will be initialized based on the database URL + // scheme. + DriverProcurer DriverProcurer // Name is the human-friendly named of the executable, used while showing // version output. Usually this is just "River", but it could be "River // Pro". @@ -37,14 +41,16 @@ type Config struct { // CLI provides a common base of commands for the River CLI. type CLI struct { - name string - out io.Writer + driverProcurer DriverProcurer + name string + out io.Writer } func NewCLI(config *Config) *CLI { return &CLI{ - name: config.Name, - out: os.Stdout, + driverProcurer: config.DriverProcurer, + name: config.Name, + out: os.Stdout, } } @@ -72,10 +78,11 @@ func (c *CLI) BaseCommandSet() *cobra.Command { // Make a bundle for RunCommand. Takes a database URL pointer because not every command is required to take a database URL. makeCommandBundle := func(databaseURL *string, schema string) *RunCommandBundle { return &RunCommandBundle{ - DatabaseURL: databaseURL, - Logger: makeLogger(), - OutStd: c.out, - Schema: schema, + DatabaseURL: databaseURL, + DriverProcurer: c.driverProcurer, + Logger: makeLogger(), + OutStd: c.out, + Schema: schema, } } diff --git a/cmd/river/rivercli/river_cli_test.go b/cmd/river/rivercli/river_cli_test.go index 03151cbe..7d1b0682 100644 --- a/cmd/river/rivercli/river_cli_test.go +++ b/cmd/river/rivercli/river_cli_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/spf13/cobra" "github.com/stretchr/testify/require" @@ -27,6 +26,7 @@ import ( type DriverProcurerStub struct { getBenchmarkerStub func(config *riverbench.Config) BenchmarkerInterface getMigratorStub func(config *rivermigrate.Config) (MigratorInterface, error) + initPgxV5Stub func(pool *pgxpool.Pool) queryRowStub func(ctx context.Context, sql string, args ...any) riverdriver.Row } @@ -46,6 +46,14 @@ func (p *DriverProcurerStub) GetMigrator(config *rivermigrate.Config) (MigratorI return p.getMigratorStub(config) } +func (p *DriverProcurerStub) InitPgxV5(pool *pgxpool.Pool) { + if p.initPgxV5Stub == nil { + panic("InitPgxV5 is not stubbed") + } + + p.initPgxV5Stub(pool) +} + func (p *DriverProcurerStub) QueryRow(ctx context.Context, sql string, args ...any) riverdriver.Row { if p.queryRowStub == nil { panic("QueryRow is not stubbed") @@ -110,12 +118,6 @@ var ( testMigrationAll = []rivermigrate.Migration{testMigration01, testMigration02, testMigration03} //nolint:gochecknoglobals ) -type TestDriverProcurer struct{} - -func (p *TestDriverProcurer) ProcurePgxV5(pool *pgxpool.Pool) riverdriver.Driver[pgx.Tx] { - return riverpgxv5.New(pool) -} - // High level integration tests that operate on the Cobra command directly. This // isn't always appropriate because there's no way to inject a test transaction. func TestBaseCommandSetIntegration(t *testing.T) { @@ -259,6 +261,51 @@ func TestBaseCommandSetNonParallel(t *testing.T) { }) } +func TestBaseCommandSetDriverProcurerPgxV5(t *testing.T) { + t.Parallel() + + calledStub := false + + migratorStub := &MigratorStub{} + migratorStub.allVersionsStub = func() []rivermigrate.Migration { return []rivermigrate.Migration{testMigration01} } + migratorStub.getVersionStub = func(version int) (rivermigrate.Migration, error) { + calledStub = true + if version == 1 { + return testMigration01, nil + } + + return rivermigrate.Migration{}, fmt.Errorf("unknown version: %d", version) + } + migratorStub.existingVersionsStub = func(ctx context.Context) ([]rivermigrate.Migration, error) { return nil, nil } + + cli := NewCLI(&Config{ + DriverProcurer: &DriverProcurerStub{ + getMigratorStub: func(config *rivermigrate.Config) (MigratorInterface, error) { + calledStub = true + return migratorStub, nil + }, + initPgxV5Stub: func(pool *pgxpool.Pool) { + calledStub = true + }, + }, + Name: "River", + }) + + var out bytes.Buffer + cli.SetOut(&out) + + cmd := cli.BaseCommandSet() + cmd.SetArgs([]string{"migrate-get", "--up", "--version", "1"}) + require.NoError(t, cmd.Execute()) + + require.True(t, calledStub) + + require.Equal(t, strings.TrimSpace(` +-- River main migration 001 [up] +SELECT 'up 1' FROM river_table + `), strings.TrimSpace(out.String())) +} + func TestMigrateGet(t *testing.T) { t.Parallel() diff --git a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go index f5fea393..f291eed2 100644 --- a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go +++ b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go @@ -3105,7 +3105,7 @@ func Exercise[TTx any](ctx context.Context, t *testing.T, now := time.Now().UTC() - leader := testfactory.Leader(ctx, t, exec, &testfactory.LeaderOpts{ + _ = testfactory.Leader(ctx, t, exec, &testfactory.LeaderOpts{ LeaderID: ptrutil.Ptr(clientID), Now: &now, })