diff --git a/rivershared/sqlctemplate/sqlc_template.go b/rivershared/sqlctemplate/sqlc_template.go index 36986b24..c8417377 100644 --- a/rivershared/sqlctemplate/sqlc_template.go +++ b/rivershared/sqlctemplate/sqlc_template.go @@ -110,10 +110,20 @@ func (r *Replacer) Run(ctx context.Context, sql string, args []any) (string, []a // RunSafely is the same as Run, but returns an error in case of missing or // extra templates. func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (string, []any, error) { - // If nothing present in context, short circuit quickly. - container, containerOK := ctx.Value(contextKey{}).(*contextContainer) - if !containerOK { + var ( + container, containerOK = ctx.Value(contextKey{}).(*contextContainer) + sqlContainsTemplate = strings.Contains(sql, "/* TEMPLATE") + ) + switch { + case !containerOK && !sqlContainsTemplate: + // Neither context container or template in SQL; short circuit fast because there's no work to do. return sql, args, nil + + case containerOK && !sqlContainsTemplate: + return "", nil, errors.New("sqlctemplate found context container but SQL contains no templates; bug?") + + case !containerOK && sqlContainsTemplate: + return "", nil, errors.New("sqlctemplate found template(s) in SQL, but no context container; bug?") } r.cacheMu.RLock() @@ -134,10 +144,6 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin return cachedSQL, args, nil } - if !strings.Contains(sql, "/* TEMPLATE") { - return sql, args, nil - } - var ( templatesExpected = maputil.Keys(container.Templates) templatesMissing []string // not preallocated because we don't expect any missing parameters in the common case diff --git a/rivershared/sqlctemplate/sqlc_template_test.go b/rivershared/sqlctemplate/sqlc_template_test.go index f10fa30f..bbe47251 100644 --- a/rivershared/sqlctemplate/sqlc_template_test.go +++ b/rivershared/sqlctemplate/sqlc_template_test.go @@ -21,36 +21,43 @@ func TestReplacer(t *testing.T) { return &Replacer{}, &testBundle{} } - t.Run("NoContainer", func(t *testing.T) { + t.Run("NoContainerError", func(t *testing.T) { t.Parallel() replacer, _ := setup(t) - updatedSQL, args, err := replacer.RunSafely(ctx, ` + _, _, err := replacer.RunSafely(ctx, ` SELECT /* TEMPLATE: schema */river_job; `, nil) - require.NoError(t, err) - require.Nil(t, args) - require.Equal(t, ` - SELECT /* TEMPLATE: schema */river_job; - `, updatedSQL) + require.EqualError(t, err, "sqlctemplate found template(s) in SQL, but no context container; bug?") }) - t.Run("NoTemplate", func(t *testing.T) { + t.Run("NoTemplateError", func(t *testing.T) { t.Parallel() replacer, _ := setup(t) ctx := WithReplacements(ctx, map[string]Replacement{}, nil) + _, _, err := replacer.RunSafely(ctx, ` + SELECT 1; + `, nil) + require.EqualError(t, err, "sqlctemplate found context container but SQL contains no templates; bug?") + }) + + t.Run("NoContainerOrTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + updatedSQL, args, err := replacer.RunSafely(ctx, ` SELECT 1; `, nil) require.NoError(t, err) - require.Nil(t, args) require.Equal(t, ` SELECT 1; `, updatedSQL) + require.Nil(t, args) }) t.Run("BasicTemplate", func(t *testing.T) {