Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions rivershared/sqlctemplate/sqlc_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,14 @@ import (
"github.com/riverqueue/river/rivershared/util/maputil"
)

// Context container added by WithReplacements.
type contextContainer struct {
// NamedArgs and their values to be replaced after templates in Replacements
// are rendered.
NamedArgs map[string]any
Templates map[string]Replacement

// Replacements maps template names to replacement values.
Replacements map[string]Replacement
}

type contextKey struct{}
Expand Down Expand Up @@ -136,7 +141,7 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin
}
r.cacheMu.RUnlock()

// If all input templates were stable, the finished SQL will have been
// If all input templates were stable, the finished SQL will have been cached.
if cachedSQLOK {
if len(container.NamedArgs) > 0 {
args = append(args, maputil.Values(container.NamedArgs)...)
Expand All @@ -145,7 +150,7 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin
}

var (
templatesExpected = maputil.Keys(container.Templates)
templatesExpected = maputil.Keys(container.Replacements)
templatesMissing []string // not preallocated because we don't expect any missing parameters in the common case
)

Expand All @@ -158,9 +163,9 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin

template := matches[1]

if tmpl, ok := container.Templates[template]; ok {
if replacement, ok := container.Replacements[template]; ok {
templatesExpected = slices.DeleteFunc(templatesExpected, func(p string) bool { return p == template })
return tmpl.Value
return replacement.Value
} else {
templatesMissing = append(templatesMissing, template)
}
Expand Down Expand Up @@ -207,8 +212,8 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin
}
}

for _, tmpl := range container.Templates {
if !tmpl.Stable {
for _, replacement := range container.Replacements {
if !replacement.Stable {
return updatedSQL, args, nil
}
}
Expand All @@ -229,13 +234,13 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin
//
// If sqlctemplate params are already present in context, the two sets are
// merged, with the new params taking precedent.
func WithReplacements(ctx context.Context, templates map[string]Replacement, namedArgs map[string]any) context.Context {
func WithReplacements(ctx context.Context, replacements map[string]Replacement, namedArgs map[string]any) context.Context {
if container, ok := ctx.Value(contextKey{}).(*contextContainer); ok {
for arg, val := range namedArgs {
container.NamedArgs[arg] = val
}
for template, tmpl := range templates {
container.Templates[template] = tmpl
for template, replacement := range replacements {
container.Replacements[template] = replacement
}
return ctx
}
Expand All @@ -245,7 +250,7 @@ func WithReplacements(ctx context.Context, templates map[string]Replacement, nam
}

return context.WithValue(ctx, contextKey{}, &contextContainer{
NamedArgs: namedArgs,
Templates: templates,
NamedArgs: namedArgs,
Replacements: replacements,
})
}
Loading