Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func Test_Client(t *testing.T) {
&overridableJobMiddleware{
workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
middlewareCalled = true
require.Equal(t, `{"name": "inserted name"}`, string(job.EncodedArgs))
require.JSONEq(t, `{"name": "inserted name"}`, string(job.EncodedArgs))
job.EncodedArgs = []byte(`{"name": "middleware name"}`)
return doInner(ctx)
},
Expand Down
54 changes: 46 additions & 8 deletions riverdriver/riverdatabasesql/river_database_sql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/dbsqlc"
"github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/pgtypealias"
"github.com/riverqueue/river/rivershared/sqlctemplate"
"github.com/riverqueue/river/rivershared/util/sliceutil"
"github.com/riverqueue/river/rivershared/util/valutil"
"github.com/riverqueue/river/rivertype"
Expand All @@ -35,7 +36,8 @@ var migrationFS embed.FS

// Driver is an implementation of riverdriver.Driver for database/sql.
type Driver struct {
dbPool *sql.DB
dbPool *sql.DB
replacer sqlctemplate.Replacer
}

// New returns a new database/sql River driver for use with River.
Expand All @@ -44,11 +46,13 @@ type Driver struct {
// configured to use the schema specified in the client's Schema field. The pool
// must not be closed while associated River objects are running.
func New(dbPool *sql.DB) *Driver {
return &Driver{dbPool: dbPool}
return &Driver{
dbPool: dbPool,
}
}

func (d *Driver) GetExecutor() riverdriver.Executor {
return &Executor{d.dbPool, d.dbPool}
return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d}
}

func (d *Driver) GetListener() riverdriver.Listener { panic(riverdriver.ErrNotImplemented) }
Expand All @@ -63,20 +67,29 @@ func (d *Driver) HasPool() bool { return d.dbPool != nil }
func (d *Driver) SupportsListener() bool { return false }

func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.ExecutorTx {
return &ExecutorTx{Executor: Executor{nil, tx}, tx: tx}
// Allows UnwrapExecutor to be invoked even if driver is nil.
var replacer *sqlctemplate.Replacer
if d == nil {
replacer = &sqlctemplate.Replacer{}
} else {
replacer = &d.replacer
}

return &ExecutorTx{Executor: Executor{nil, templateReplaceWrapper{tx, replacer}, d}, tx: tx}
}

type Executor struct {
dbPool *sql.DB
dbtx dbsqlc.DBTX
dbtx templateReplaceWrapper
driver *Driver
}

func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) {
tx, err := e.dbPool.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
return &ExecutorTx{Executor: Executor{nil, tx}, tx: tx}, nil
return &ExecutorTx{Executor: Executor{nil, templateReplaceWrapper{tx, &e.driver.replacer}, e.driver}, tx: tx}, nil
}

func (e *Executor) ColumnExists(ctx context.Context, tableName, columnName string) (bool, error) {
Expand Down Expand Up @@ -846,7 +859,7 @@ type ExecutorTx struct {
}

func (t *ExecutorTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) {
return (&ExecutorSubTx{Executor: Executor{nil, t.tx}, savepointNum: 0, single: &singleTransaction{}, tx: t.tx}).Begin(ctx)
return (&ExecutorSubTx{Executor: Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver}, savepointNum: 0, single: &singleTransaction{}, tx: t.tx}).Begin(ctx)
}

func (t *ExecutorTx) Commit(ctx context.Context) error {
Expand Down Expand Up @@ -878,7 +891,7 @@ func (t *ExecutorSubTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, erro
if err != nil {
return nil, err
}
return &ExecutorSubTx{Executor: Executor{nil, t.tx}, savepointNum: nextSavepointNum, single: &singleTransaction{parent: t.single}, tx: t.tx}, nil
return &ExecutorSubTx{Executor: Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver}, savepointNum: nextSavepointNum, single: &singleTransaction{parent: t.single}, tx: t.tx}, nil
}

func (t *ExecutorSubTx) Commit(ctx context.Context) error {
Expand Down Expand Up @@ -944,6 +957,31 @@ func (t *singleTransaction) setDone() {
}
}

type templateReplaceWrapper struct {
dbtx dbsqlc.DBTX
replacer *sqlctemplate.Replacer
}

func (w templateReplaceWrapper) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) {
sql, args = w.replacer.Run(ctx, sql, args)
return w.dbtx.ExecContext(ctx, sql, args...)
}

func (w templateReplaceWrapper) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) {
sql, _ = w.replacer.Run(ctx, sql, nil)
return w.dbtx.PrepareContext(ctx, sql)
}

func (w templateReplaceWrapper) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
sql, args = w.replacer.Run(ctx, sql, args)
return w.dbtx.QueryContext(ctx, sql, args...)
}

func (w templateReplaceWrapper) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *sql.Row {
sql, args = w.replacer.Run(ctx, sql, args)
return w.dbtx.QueryRowContext(ctx, sql, args...)
}

func jobRowFromInternal(internal *dbsqlc.RiverJob) (*rivertype.JobRow, error) {
var attemptedAt *time.Time
if internal.AttemptedAt != nil {
Expand Down
62 changes: 53 additions & 9 deletions riverdriver/riverpgxv5/river_pgx_v5_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ import (
"time"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/puddle/v2"

"github.com/riverqueue/river/internal/dbunique"
"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/riverdriver/riverpgxv5/internal/dbsqlc"
"github.com/riverqueue/river/rivershared/sqlctemplate"
"github.com/riverqueue/river/rivershared/util/sliceutil"
"github.com/riverqueue/river/rivertype"
)
Expand All @@ -33,7 +35,8 @@ var migrationFS embed.FS

// Driver is an implementation of riverdriver.Driver for Pgx v5.
type Driver struct {
dbPool *pgxpool.Pool
dbPool *pgxpool.Pool
replacer sqlctemplate.Replacer
}

// New returns a new Pgx v5 River driver for use with River.
Expand All @@ -49,10 +52,14 @@ type Driver struct {
// in testing so that inserts can be performed and verified on a test
// transaction that will be rolled back.
func New(dbPool *pgxpool.Pool) *Driver {
return &Driver{dbPool: dbPool}
return &Driver{
dbPool: dbPool,
}
}

func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool} }
func (d *Driver) GetExecutor() riverdriver.Executor {
return &Executor{templateReplaceWrapper{d.dbPool, &d.replacer}, d}
}
func (d *Driver) GetListener() riverdriver.Listener { return &Listener{dbPool: d.dbPool} }
func (d *Driver) GetMigrationFS(line string) fs.FS {
if line == riverdriver.MigrationLineMain {
Expand All @@ -65,22 +72,28 @@ func (d *Driver) HasPool() bool { return d.dbPool != nil }
func (d *Driver) SupportsListener() bool { return true }

func (d *Driver) UnwrapExecutor(tx pgx.Tx) riverdriver.ExecutorTx {
return &ExecutorTx{Executor: Executor{tx}, tx: tx}
// Allows UnwrapExecutor to be invoked even if driver is nil.
var replacer *sqlctemplate.Replacer
if d == nil {
replacer = &sqlctemplate.Replacer{}
} else {
replacer = &d.replacer
}

return &ExecutorTx{Executor: Executor{templateReplaceWrapper{tx, replacer}, d}, tx: tx}
}

type Executor struct {
dbtx interface {
dbsqlc.DBTX
Begin(ctx context.Context) (pgx.Tx, error)
}
dbtx templateReplaceWrapper
driver *Driver
}

func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) {
tx, err := e.dbtx.Begin(ctx)
if err != nil {
return nil, err
}
return &ExecutorTx{Executor: Executor{tx}, tx: tx}, nil
return &ExecutorTx{Executor: Executor{templateReplaceWrapper{tx, &e.driver.replacer}, e.driver}, tx: tx}, nil
}

func (e *Executor) ColumnExists(ctx context.Context, tableName, columnName string) (bool, error) {
Expand Down Expand Up @@ -814,6 +827,37 @@ func (l *Listener) WaitForNotification(ctx context.Context) (*riverdriver.Notifi
}, nil
}

type templateReplaceWrapper struct {
dbtx interface {
dbsqlc.DBTX
Begin(ctx context.Context) (pgx.Tx, error)
}
replacer *sqlctemplate.Replacer
}

func (w templateReplaceWrapper) Begin(ctx context.Context) (pgx.Tx, error) {
return w.dbtx.Begin(ctx)
}

func (w templateReplaceWrapper) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) {
sql, args = w.replacer.Run(ctx, sql, args)
return w.dbtx.Exec(ctx, sql, args...)
}

func (w templateReplaceWrapper) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
sql, args = w.replacer.Run(ctx, sql, args)
return w.dbtx.Query(ctx, sql, args...)
}

func (w templateReplaceWrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
sql, args = w.replacer.Run(ctx, sql, args)
return w.dbtx.QueryRow(ctx, sql, args...)
}

func (w templateReplaceWrapper) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
return w.dbtx.CopyFrom(ctx, tableName, columnNames, rowSrc)
}

func interpretError(err error) error {
if errors.Is(err, puddle.ErrClosedPool) {
return riverdriver.ErrClosedPool
Expand Down
Loading
Loading