Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 75 additions & 25 deletions rivershared/sqlctemplate/sqlc_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ type Replacement struct {
// be initialized with a constructor. This lets it default to a usable instance
// on drivers that may themselves not be initialized.
type Replacer struct {
cache map[string]string
cache map[replacerCacheKey]string
cacheMu sync.RWMutex
}

Expand Down Expand Up @@ -131,22 +131,25 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin
return "", nil, errors.New("sqlctemplate found template(s) in SQL, but no context container; bug?")
}

r.cacheMu.RLock()
var (
cachedSQL string
cachedSQLOK bool
)
if r.cache != nil { // protect against map not initialized yet
cachedSQL, cachedSQLOK = r.cache[sql]
}
r.cacheMu.RUnlock()
cacheKey, cacheEligible := replacerCacheKeyFrom(sql, container)
if cacheEligible {
r.cacheMu.RLock()
var (
cachedSQL string
cachedSQLOK bool
)
if r.cache != nil { // protect against map not initialized yet
cachedSQL, cachedSQLOK = r.cache[cacheKey]
}
r.cacheMu.RUnlock()

// If all input templates were stable, the finished SQL will have been cached.
if cachedSQLOK {
if len(container.NamedArgs) > 0 {
args = append(args, maputil.Values(container.NamedArgs)...)
// If all input templates were stable, the finished SQL will have been cached.
if cachedSQLOK {
if len(container.NamedArgs) > 0 {
args = append(args, maputil.Values(container.NamedArgs)...)
}
return cachedSQL, args, nil
}
return cachedSQL, args, nil
}

var (
Expand Down Expand Up @@ -212,19 +215,15 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin
}
}

for _, replacement := range container.Replacements {
if !replacement.Stable {
return updatedSQL, args, nil
if cacheEligible {
r.cacheMu.Lock()
if r.cache == nil {
r.cache = make(map[replacerCacheKey]string)
}
r.cache[cacheKey] = updatedSQL
r.cacheMu.Unlock()
}

r.cacheMu.Lock()
if r.cache == nil {
r.cache = make(map[string]string)
}
r.cache[sql] = updatedSQL
r.cacheMu.Unlock()

return updatedSQL, args, nil
}

Expand Down Expand Up @@ -254,3 +253,54 @@ func WithReplacements(ctx context.Context, replacements map[string]Replacement,
Replacements: replacements,
})
}

// Comparable struct that's used as a key for template caching.
type replacerCacheKey struct {
namedArgs string // all arg names concatenated together
replacementValues string // all values concatenated together
sql string
}

// Builds a cache key for the given SQL and context container.
//
// A key is only built if the given SQL/templates are cacheable, which means all
// template values must be stable. The second return value is a boolean
// indicating whether a cache key was built or not. If false, the input is not
// eligible for caching, and no check against the cache should be made.
func replacerCacheKeyFrom(sql string, container *contextContainer) (replacerCacheKey, bool) {
// Only eligible for caching if all replacements are stable.
for _, replacement := range container.Replacements {
if !replacement.Stable {
return replacerCacheKey{}, false
}
}

var (
namedArgsBuilder strings.Builder

// Named args must be sorted for key stability because Go maps don't
// provide any ordering guarantees.
sortedNamedArgs = maputil.Keys(container.NamedArgs)
)
slices.Sort(sortedNamedArgs)
for _, namedArg := range sortedNamedArgs {
namedArgsBuilder.WriteRune('@') // useful as separator because not valid in the name of a named arg
namedArgsBuilder.WriteString(namedArg)
}

var (
replacementValuesBuilder strings.Builder
sortedReplacements = maputil.Keys(container.Replacements)
)
slices.Sort(sortedReplacements)
for _, template := range sortedReplacements {
replacementValuesBuilder.WriteRune('•') // use a separator that SQL would reject under most circumstances (this may be imperfect)
replacementValuesBuilder.WriteString(container.Replacements[template].Value)
}

return replacerCacheKey{
namedArgs: namedArgsBuilder.String(),
replacementValues: replacementValuesBuilder.String(),
sql: sql,
}, true
}
124 changes: 124 additions & 0 deletions rivershared/sqlctemplate/sqlc_template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,96 @@ func TestReplacer(t *testing.T) {
require.Empty(t, replacer.cache)
})

t.Run("CacheBasedOnInputValues", func(t *testing.T) {
t.Parallel()

replacer, _ := setup(t)

// SQL stays constant across all runs.
const sql = `
SELECT count(*)
FROM /* TEMPLATE: schema */river_job
WHERE kind = @kind
AND state = @state;
`

// Initially cached value
{
ctx := WithReplacements(ctx, map[string]Replacement{
"schema": {Stable: true, Value: "test_schema."},
}, nil)

_, _, err := replacer.RunSafely(ctx, sql, nil)
require.NoError(t, err)
}
require.Len(t, replacer.cache, 1)

// Same SQL, but new value.
{
ctx := WithReplacements(ctx, map[string]Replacement{
"schema": {Stable: true, Value: "other_schema."},
}, nil)

_, _, err := replacer.RunSafely(ctx, sql, nil)
require.NoError(t, err)
}
require.Len(t, replacer.cache, 2)

// Named arg added to the mix.
{
ctx := WithReplacements(ctx, map[string]Replacement{
"schema": {Stable: true, Value: "test_schema."},
}, map[string]any{
"kind": "kind_value",
})

_, _, err := replacer.RunSafely(ctx, sql, nil)
require.NoError(t, err)
}
require.Len(t, replacer.cache, 3)

// Different named arg _value_ (i.e. still same named arg) can still use
// the previous cached SQL.
{
ctx := WithReplacements(ctx, map[string]Replacement{
"schema": {Stable: true, Value: "test_schema."},
}, map[string]any{
"kind": "other_kind_value",
})

_, _, err := replacer.RunSafely(ctx, sql, nil)
require.NoError(t, err)
}
require.Len(t, replacer.cache, 3) // unchanged

// New named arg adds a new cache value.
{
ctx := WithReplacements(ctx, map[string]Replacement{
"schema": {Stable: true, Value: "test_schema."},
}, map[string]any{
"kind": "kind_value",
"state": "state_value",
})

_, _, err := replacer.RunSafely(ctx, sql, nil)
require.NoError(t, err)
}
require.Len(t, replacer.cache, 4)

// Different input SQL.
{
ctx := WithReplacements(ctx, map[string]Replacement{
"schema": {Stable: true, Value: "test_schema."},
}, nil)

_, _, err := replacer.RunSafely(ctx, `
SELECT /* TEMPLATE: schema */river_job;
`, nil)
require.NoError(t, err)
}
require.Len(t, replacer.cache, 5)
})

t.Run("NamedArgsNoInitialArgs", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -369,3 +459,37 @@ func TestReplacer(t *testing.T) {
wg.Wait()
})
}

func BenchmarkReplacer(b *testing.B) {
ctx := context.Background()

runReplacer := func(b *testing.B, replacer *Replacer, stable bool) {
b.Helper()

ctx := WithReplacements(ctx, map[string]Replacement{
"schema": {Stable: stable, Value: "test_schema."},
}, nil)

_, _, err := replacer.RunSafely(ctx, `
-- name: JobCountByState :one
SELECT count(*)
FROM /* TEMPLATE: schema */river_job
WHERE state = @state;
`, nil)
require.NoError(b, err)
}

b.Run("WithCache", func(b *testing.B) {
var replacer Replacer
for range b.N {
runReplacer(b, &replacer, true)
}
})

b.Run("WithoutCache", func(b *testing.B) {
var replacer Replacer
for range b.N {
runReplacer(b, &replacer, false)
}
})
}
Loading