diff --git a/client_test.go b/client_test.go index 9d2af82b..31864be5 100644 --- a/client_test.go +++ b/client_test.go @@ -710,7 +710,7 @@ func Test_Client(t *testing.T) { &overridableJobMiddleware{ workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { middlewareCalled = true - require.Equal(t, `{"name": "inserted name"}`, string(job.EncodedArgs)) + require.JSONEq(t, `{"name": "inserted name"}`, string(job.EncodedArgs)) job.EncodedArgs = []byte(`{"name": "middleware name"}`) return doInner(ctx) }, diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver.go b/riverdriver/riverdatabasesql/river_database_sql_driver.go index 34918e19..195609ec 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver.go @@ -25,6 +25,7 @@ import ( "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/dbsqlc" "github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/pgtypealias" + "github.com/riverqueue/river/rivershared/sqlctemplate" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivershared/util/valutil" "github.com/riverqueue/river/rivertype" @@ -35,7 +36,8 @@ var migrationFS embed.FS // Driver is an implementation of riverdriver.Driver for database/sql. type Driver struct { - dbPool *sql.DB + dbPool *sql.DB + replacer sqlctemplate.Replacer } // New returns a new database/sql River driver for use with River. @@ -44,11 +46,13 @@ type Driver struct { // configured to use the schema specified in the client's Schema field. The pool // must not be closed while associated River objects are running. func New(dbPool *sql.DB) *Driver { - return &Driver{dbPool: dbPool} + return &Driver{ + dbPool: dbPool, + } } func (d *Driver) GetExecutor() riverdriver.Executor { - return &Executor{d.dbPool, d.dbPool} + return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d} } func (d *Driver) GetListener() riverdriver.Listener { panic(riverdriver.ErrNotImplemented) } @@ -63,12 +67,21 @@ func (d *Driver) HasPool() bool { return d.dbPool != nil } func (d *Driver) SupportsListener() bool { return false } func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.ExecutorTx { - return &ExecutorTx{Executor: Executor{nil, tx}, tx: tx} + // Allows UnwrapExecutor to be invoked even if driver is nil. + var replacer *sqlctemplate.Replacer + if d == nil { + replacer = &sqlctemplate.Replacer{} + } else { + replacer = &d.replacer + } + + return &ExecutorTx{Executor: Executor{nil, templateReplaceWrapper{tx, replacer}, d}, tx: tx} } type Executor struct { dbPool *sql.DB - dbtx dbsqlc.DBTX + dbtx templateReplaceWrapper + driver *Driver } func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { @@ -76,7 +89,7 @@ func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { if err != nil { return nil, err } - return &ExecutorTx{Executor: Executor{nil, tx}, tx: tx}, nil + return &ExecutorTx{Executor: Executor{nil, templateReplaceWrapper{tx, &e.driver.replacer}, e.driver}, tx: tx}, nil } func (e *Executor) ColumnExists(ctx context.Context, tableName, columnName string) (bool, error) { @@ -846,7 +859,7 @@ type ExecutorTx struct { } func (t *ExecutorTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { - return (&ExecutorSubTx{Executor: Executor{nil, t.tx}, savepointNum: 0, single: &singleTransaction{}, tx: t.tx}).Begin(ctx) + return (&ExecutorSubTx{Executor: Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver}, savepointNum: 0, single: &singleTransaction{}, tx: t.tx}).Begin(ctx) } func (t *ExecutorTx) Commit(ctx context.Context) error { @@ -878,7 +891,7 @@ func (t *ExecutorSubTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, erro if err != nil { return nil, err } - return &ExecutorSubTx{Executor: Executor{nil, t.tx}, savepointNum: nextSavepointNum, single: &singleTransaction{parent: t.single}, tx: t.tx}, nil + return &ExecutorSubTx{Executor: Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver}, savepointNum: nextSavepointNum, single: &singleTransaction{parent: t.single}, tx: t.tx}, nil } func (t *ExecutorSubTx) Commit(ctx context.Context) error { @@ -944,6 +957,31 @@ func (t *singleTransaction) setDone() { } } +type templateReplaceWrapper struct { + dbtx dbsqlc.DBTX + replacer *sqlctemplate.Replacer +} + +func (w templateReplaceWrapper) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.ExecContext(ctx, sql, args...) +} + +func (w templateReplaceWrapper) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) { + sql, _ = w.replacer.Run(ctx, sql, nil) + return w.dbtx.PrepareContext(ctx, sql) +} + +func (w templateReplaceWrapper) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryContext(ctx, sql, args...) +} + +func (w templateReplaceWrapper) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *sql.Row { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryRowContext(ctx, sql, args...) +} + func jobRowFromInternal(internal *dbsqlc.RiverJob) (*rivertype.JobRow, error) { var attemptedAt *time.Time if internal.AttemptedAt != nil { diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index c446c673..3db9d3af 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -17,6 +17,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/puddle/v2" @@ -24,6 +25,7 @@ import ( "github.com/riverqueue/river/internal/dbunique" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverpgxv5/internal/dbsqlc" + "github.com/riverqueue/river/rivershared/sqlctemplate" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" ) @@ -33,7 +35,8 @@ var migrationFS embed.FS // Driver is an implementation of riverdriver.Driver for Pgx v5. type Driver struct { - dbPool *pgxpool.Pool + dbPool *pgxpool.Pool + replacer sqlctemplate.Replacer } // New returns a new Pgx v5 River driver for use with River. @@ -49,10 +52,14 @@ type Driver struct { // in testing so that inserts can be performed and verified on a test // transaction that will be rolled back. func New(dbPool *pgxpool.Pool) *Driver { - return &Driver{dbPool: dbPool} + return &Driver{ + dbPool: dbPool, + } } -func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool} } +func (d *Driver) GetExecutor() riverdriver.Executor { + return &Executor{templateReplaceWrapper{d.dbPool, &d.replacer}, d} +} func (d *Driver) GetListener() riverdriver.Listener { return &Listener{dbPool: d.dbPool} } func (d *Driver) GetMigrationFS(line string) fs.FS { if line == riverdriver.MigrationLineMain { @@ -65,14 +72,20 @@ func (d *Driver) HasPool() bool { return d.dbPool != nil } func (d *Driver) SupportsListener() bool { return true } func (d *Driver) UnwrapExecutor(tx pgx.Tx) riverdriver.ExecutorTx { - return &ExecutorTx{Executor: Executor{tx}, tx: tx} + // Allows UnwrapExecutor to be invoked even if driver is nil. + var replacer *sqlctemplate.Replacer + if d == nil { + replacer = &sqlctemplate.Replacer{} + } else { + replacer = &d.replacer + } + + return &ExecutorTx{Executor: Executor{templateReplaceWrapper{tx, replacer}, d}, tx: tx} } type Executor struct { - dbtx interface { - dbsqlc.DBTX - Begin(ctx context.Context) (pgx.Tx, error) - } + dbtx templateReplaceWrapper + driver *Driver } func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { @@ -80,7 +93,7 @@ func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { if err != nil { return nil, err } - return &ExecutorTx{Executor: Executor{tx}, tx: tx}, nil + return &ExecutorTx{Executor: Executor{templateReplaceWrapper{tx, &e.driver.replacer}, e.driver}, tx: tx}, nil } func (e *Executor) ColumnExists(ctx context.Context, tableName, columnName string) (bool, error) { @@ -814,6 +827,37 @@ func (l *Listener) WaitForNotification(ctx context.Context) (*riverdriver.Notifi }, nil } +type templateReplaceWrapper struct { + dbtx interface { + dbsqlc.DBTX + Begin(ctx context.Context) (pgx.Tx, error) + } + replacer *sqlctemplate.Replacer +} + +func (w templateReplaceWrapper) Begin(ctx context.Context) (pgx.Tx, error) { + return w.dbtx.Begin(ctx) +} + +func (w templateReplaceWrapper) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.Exec(ctx, sql, args...) +} + +func (w templateReplaceWrapper) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.Query(ctx, sql, args...) +} + +func (w templateReplaceWrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryRow(ctx, sql, args...) +} + +func (w templateReplaceWrapper) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return w.dbtx.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + func interpretError(err error) error { if errors.Is(err, puddle.ErrClosedPool) { return riverdriver.ErrClosedPool diff --git a/rivershared/sqlctemplate/sqlc_template.go b/rivershared/sqlctemplate/sqlc_template.go new file mode 100644 index 00000000..36986b24 --- /dev/null +++ b/rivershared/sqlctemplate/sqlc_template.go @@ -0,0 +1,239 @@ +// Package sqlctemplate provides a way of making arbitrary text replacement in +// sqlc queries which normally only allow parameters which are in places valid +// in a prepared statement. For example, it can be used to insert a schema name +// as a prefix to tables referenced in sqlc, which is otherwise impossible. +// +// Replacement is carried out from within invocations of sqlc's generated DBTX +// interface, after sqlc generated code runs, but before queries are executed. +// This is accomplished by implementing DBTX, calling Replacer.Run from within +// them, and injecting parameters in with WithReplacements (which is unfortunately +// the only way of injecting them). +// +// Templates are modeled as SQL comments so that they're still parseable as +// valid SQL. An example use of the basic /* TEMPLATE ... */ syntax: +// +// -- name: JobCountByState :one +// SELECT count(*) +// FROM /* TEMPLATE: schema */river_job +// WHERE state = @state; +// +// An open/close syntax is also available for when SQL is required before +// processing for the query to be valid. For example, a WHERE or ORDER BY clause +// can't be empty, so the SQL includes a sentinel value that's parseable which +// is then replaced later with template values: +// +// -- name: JobList :many +// SELECT * +// FROM river_job +// WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ +// ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +// LIMIT @max::int; +// +// Be careful not to place a template on a line by itself because sqlc will +// strip any lines that start with a comment. For example, this does NOT work: +// +// -- name: JobList :many +// SELECT * +// FROM river_job +// /* TEMPLATE_BEGIN: where_clause */ +// LIMIT @max::int; +package sqlctemplate + +import ( + "context" + "errors" + "fmt" + "regexp" + "slices" + "strconv" + "strings" + "sync" + + "github.com/riverqueue/river/rivershared/util/maputil" +) + +type contextContainer struct { + NamedArgs map[string]any + Templates map[string]Replacement +} + +type contextKey struct{} + +// Replacement defines a replacement for a template value in some input SQL. +type Replacement struct { + // Stable is whether the replacement value is expected to be stable for any + // number of times Replacer.Run is called with the same given input SQL. If + // all replacements are stable, then the output of Replacer.Run is cached so + // that it doesn't have to be processed again. Replacements should be not be + // stable if they depend on input parameters. + Stable bool + + // Value is the value which the template should be replaced with. For a /* + // TEMPLATE ... */ tag, replaces template and the comment containing it. For + // a /* TEMPLATE_BEGIN ... */ ... /* TEMPLATE_END */ tag pair, replaces both + // templates, comments, and the value between them. + Value string +} + +// Replacer replaces templates with template values. As an optimization, it +// contains an internal cache to short circuit SQL that has entirely stable +// template replacements and whose output is invariant of input parameters. +// +// The struct is written so that it's safe to use as a value and doesn't need to +// be initialized with a constructor. This lets it default to a usable instance +// on drivers that may themselves not be initialized. +type Replacer struct { + cache map[string]string + cacheMu sync.RWMutex +} + +var ( + templateBeginEndRE = regexp.MustCompile(`/\* TEMPLATE_BEGIN: (.*?) \*/ .*? /\* TEMPLATE_END \*/`) + templateRE = regexp.MustCompile(`/\* TEMPLATE: (.*?) \*/`) +) + +// Run replaces any tempates in input SQL with values from context added via +// WithReplacements. +// +// args aren't used for replacements in the input SQL, but are needed to +// determine which placeholder number (e.g. $1, $2, $3, ...) we should start +// with to replace any template named args. The returned args value should then +// be used as query input as named args from context may have been added to it. +func (r *Replacer) Run(ctx context.Context, sql string, args []any) (string, []any) { + sql, namedArgs, err := r.RunSafely(ctx, sql, args) + if err != nil { + panic(err) + } + return sql, namedArgs +} + +// RunSafely is the same as Run, but returns an error in case of missing or +// extra templates. +func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (string, []any, error) { + // If nothing present in context, short circuit quickly. + container, containerOK := ctx.Value(contextKey{}).(*contextContainer) + if !containerOK { + return sql, args, nil + } + + r.cacheMu.RLock() + var ( + cachedSQL string + cachedSQLOK bool + ) + if r.cache != nil { // protect against map not initialized yet + cachedSQL, cachedSQLOK = r.cache[sql] + } + r.cacheMu.RUnlock() + + // If all input templates were stable, the finished SQL will have been + if cachedSQLOK { + if len(container.NamedArgs) > 0 { + args = append(args, maputil.Values(container.NamedArgs)...) + } + return cachedSQL, args, nil + } + + if !strings.Contains(sql, "/* TEMPLATE") { + return sql, args, nil + } + + var ( + templatesExpected = maputil.Keys(container.Templates) + templatesMissing []string // not preallocated because we don't expect any missing parameters in the common case + ) + + replaceTemplate := func(sql string, templateRE *regexp.Regexp) string { + return templateRE.ReplaceAllStringFunc(sql, func(templateStr string) string { + // Really dumb, but Go doesn't provide any way to get submatches in a + // function, so we have to match twice. + // https://github.com/golang/go/issues/5690 + matches := templateRE.FindStringSubmatch(templateStr) + + template := matches[1] + + if tmpl, ok := container.Templates[template]; ok { + templatesExpected = slices.DeleteFunc(templatesExpected, func(p string) bool { return p == template }) + return tmpl.Value + } else { + templatesMissing = append(templatesMissing, template) + } + + return templateStr + }) + } + + updatedSQL := sql + updatedSQL = replaceTemplate(updatedSQL, templateBeginEndRE) + updatedSQL = replaceTemplate(updatedSQL, templateRE) + + if len(templatesExpected) > 0 { + return "", nil, errors.New("sqlctemplate params present in context but missing in SQL: " + strings.Join(templatesExpected, ", ")) + } + + if len(templatesMissing) > 0 { + return "", nil, errors.New("sqlctemplate params present in SQL but missing in context: " + strings.Join(templatesMissing, ", ")) + } + + if len(container.NamedArgs) > 0 { + placeholderNum := len(args) + for arg, val := range container.NamedArgs { + placeholderNum++ + + var ( + symbol = "@" + arg + symbolIndex = strings.Index(updatedSQL, symbol) + ) + + if symbolIndex == -1 { + return "", nil, fmt.Errorf("sqltemplate expected to find named arg %q, but it wasn't present", symbol) + } + + // ReplaceAll because an input parameter may appear multiple times. + updatedSQL = strings.ReplaceAll(updatedSQL, symbol, "$"+strconv.Itoa(placeholderNum)) + args = append(args, val) + } + } + + for _, tmpl := range container.Templates { + if !tmpl.Stable { + return updatedSQL, args, nil + } + } + + r.cacheMu.Lock() + if r.cache == nil { + r.cache = make(map[string]string) + } + r.cache[sql] = updatedSQL + r.cacheMu.Unlock() + + return updatedSQL, args, nil +} + +// WithReplacements adds sqlctemplate templates to the given context (they go in +// context because it's the only way to get them down into the innards of sqlc). +// namedArgs can also be passed in to replace arguments found in +// +// If sqlctemplate params are already present in context, the two sets are +// merged, with the new params taking precedent. +func WithReplacements(ctx context.Context, templates map[string]Replacement, namedArgs map[string]any) context.Context { + if container, ok := ctx.Value(contextKey{}).(*contextContainer); ok { + for arg, val := range namedArgs { + container.NamedArgs[arg] = val + } + for template, tmpl := range templates { + container.Templates[template] = tmpl + } + return ctx + } + + if namedArgs == nil { + namedArgs = make(map[string]any) + } + + return context.WithValue(ctx, contextKey{}, &contextContainer{ + NamedArgs: namedArgs, + Templates: templates, + }) +} diff --git a/rivershared/sqlctemplate/sqlc_template_test.go b/rivershared/sqlctemplate/sqlc_template_test.go new file mode 100644 index 00000000..f10fa30f --- /dev/null +++ b/rivershared/sqlctemplate/sqlc_template_test.go @@ -0,0 +1,364 @@ +package sqlctemplate + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReplacer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct{} + + setup := func(t *testing.T) (*Replacer, *testBundle) { //nolint:unparam + t.Helper() + + return &Replacer{}, &testBundle{} + } + + t.Run("NoContainer", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT /* TEMPLATE: schema */river_job; + `, updatedSQL) + }) + + t.Run("NoTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{}, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT 1; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT 1; + `, updatedSQL) + }) + + t.Run("BasicTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + -- name: JobCountByState :one + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE state = @state; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + -- name: JobCountByState :one + SELECT count(*) + FROM test_schema.river_job + WHERE state = @state; + `, updatedSQL) + }) + + t.Run("BeginEndTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "order_by_clause": {Value: "kind, id"}, + "where_clause": {Value: "kind = 'no_op'"}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + -- name: JobList :many + SELECT * + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ + ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ + LIMIT @max::int; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + -- name: JobList :many + SELECT * + FROM river_job + WHERE kind = 'no_op' + ORDER BY kind, id + LIMIT @max::int; + `, updatedSQL) + }) + + t.Run("RepeatedTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job r1 + INNER JOIN /* TEMPLATE: schema */river_job r2 ON r1.id = r2.id; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job r1 + INNER JOIN test_schema.river_job r2 ON r1.id = r2.id; + `, updatedSQL) + }) + + t.Run("AllTemplatesStableCached", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job; + `, updatedSQL) + + require.Len(t, replacer.cache, 1) + + // Invoke again to make sure we get back the same result. + updatedSQL, args, err = replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job; + `, updatedSQL) + }) + + t.Run("AnyTemplateNotStableNotCached", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Value: "kind = 'no_op'"}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = 'no_op'; + `, updatedSQL) + + require.Empty(t, replacer.cache) + }) + + t.Run("NamedArgsNoInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Equal(t, []any{"no_op"}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = $1; + `, updatedSQL) + }) + + t.Run("NamedArgsWithInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ + AND status = $1; + `, []any{"succeeded"}) + require.NoError(t, err) + require.Equal(t, []any{"succeeded", "no_op"}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = $2 + AND status = $1; + `, updatedSQL) + }) + + t.Run("MultipleWithReplacementsOverrides", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + ctx = WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind AND status = @status"}, + }, map[string]any{ + "status": "succeeded", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Equal(t, []any{"no_op", "succeeded"}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = $1 AND status = $2; + `, updatedSQL) + }) + + t.Run("Stress", func(t *testing.T) { + t.Parallel() + + const ( + clearCacheIterations = 10 + numIterations = 50 + ) + + replacer, _ := setup(t) + + periodicallyClearCache := func(i int, replacer *Replacer) { + if i+1%clearCacheIterations == 0 { // +1 so we don't clear cache on i == 0 + replacer.cacheMu.Lock() + replacer.cache = nil + replacer.cacheMu.Unlock() + } + } + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + for i := range numIterations { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, _, err := replacer.RunSafely(ctx, ` + SELECT count(*) FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Equal(t, ` + SELECT count(*) FROM test_schema.river_job; + `, updatedSQL) + + periodicallyClearCache(i, replacer) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for i := range numIterations { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + updatedSQL, _, err := replacer.RunSafely(ctx, ` + SELECT distinct(kind) FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Equal(t, ` + SELECT distinct(kind) FROM test_schema.river_job; + `, updatedSQL) + + periodicallyClearCache(i, replacer) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for i := range numIterations { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + updatedSQL, _, err := replacer.RunSafely(ctx, ` + SELECT count(*) FROM /* TEMPLATE: schema */river_job WHERE status = 'succeeded'; + `, nil) + require.NoError(t, err) + require.Equal(t, ` + SELECT count(*) FROM test_schema.river_job WHERE status = 'succeeded'; + `, updatedSQL) + + periodicallyClearCache(i, replacer) + } + }() + + wg.Wait() + }) +}