diff --git a/rivershared/sqlctemplate/sqlc_template.go b/rivershared/sqlctemplate/sqlc_template.go index 548eac82..e0a02d6a 100644 --- a/rivershared/sqlctemplate/sqlc_template.go +++ b/rivershared/sqlctemplate/sqlc_template.go @@ -88,7 +88,7 @@ type Replacement struct { // be initialized with a constructor. This lets it default to a usable instance // on drivers that may themselves not be initialized. type Replacer struct { - cache map[string]string + cache map[replacerCacheKey]string cacheMu sync.RWMutex } @@ -131,22 +131,25 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin return "", nil, errors.New("sqlctemplate found template(s) in SQL, but no context container; bug?") } - r.cacheMu.RLock() - var ( - cachedSQL string - cachedSQLOK bool - ) - if r.cache != nil { // protect against map not initialized yet - cachedSQL, cachedSQLOK = r.cache[sql] - } - r.cacheMu.RUnlock() + cacheKey, cacheEligible := replacerCacheKeyFrom(sql, container) + if cacheEligible { + r.cacheMu.RLock() + var ( + cachedSQL string + cachedSQLOK bool + ) + if r.cache != nil { // protect against map not initialized yet + cachedSQL, cachedSQLOK = r.cache[cacheKey] + } + r.cacheMu.RUnlock() - // If all input templates were stable, the finished SQL will have been cached. - if cachedSQLOK { - if len(container.NamedArgs) > 0 { - args = append(args, maputil.Values(container.NamedArgs)...) + // If all input templates were stable, the finished SQL will have been cached. + if cachedSQLOK { + if len(container.NamedArgs) > 0 { + args = append(args, maputil.Values(container.NamedArgs)...) + } + return cachedSQL, args, nil } - return cachedSQL, args, nil } var ( @@ -212,19 +215,15 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin } } - for _, replacement := range container.Replacements { - if !replacement.Stable { - return updatedSQL, args, nil + if cacheEligible { + r.cacheMu.Lock() + if r.cache == nil { + r.cache = make(map[replacerCacheKey]string) } + r.cache[cacheKey] = updatedSQL + r.cacheMu.Unlock() } - r.cacheMu.Lock() - if r.cache == nil { - r.cache = make(map[string]string) - } - r.cache[sql] = updatedSQL - r.cacheMu.Unlock() - return updatedSQL, args, nil } @@ -254,3 +253,54 @@ func WithReplacements(ctx context.Context, replacements map[string]Replacement, Replacements: replacements, }) } + +// Comparable struct that's used as a key for template caching. +type replacerCacheKey struct { + namedArgs string // all arg names concatenated together + replacementValues string // all values concatenated together + sql string +} + +// Builds a cache key for the given SQL and context container. +// +// A key is only built if the given SQL/templates are cacheable, which means all +// template values must be stable. The second return value is a boolean +// indicating whether a cache key was built or not. If false, the input is not +// eligible for caching, and no check against the cache should be made. +func replacerCacheKeyFrom(sql string, container *contextContainer) (replacerCacheKey, bool) { + // Only eligible for caching if all replacements are stable. + for _, replacement := range container.Replacements { + if !replacement.Stable { + return replacerCacheKey{}, false + } + } + + var ( + namedArgsBuilder strings.Builder + + // Named args must be sorted for key stability because Go maps don't + // provide any ordering guarantees. + sortedNamedArgs = maputil.Keys(container.NamedArgs) + ) + slices.Sort(sortedNamedArgs) + for _, namedArg := range sortedNamedArgs { + namedArgsBuilder.WriteRune('@') // useful as separator because not valid in the name of a named arg + namedArgsBuilder.WriteString(namedArg) + } + + var ( + replacementValuesBuilder strings.Builder + sortedReplacements = maputil.Keys(container.Replacements) + ) + slices.Sort(sortedReplacements) + for _, template := range sortedReplacements { + replacementValuesBuilder.WriteRune('•') // use a separator that SQL would reject under most circumstances (this may be imperfect) + replacementValuesBuilder.WriteString(container.Replacements[template].Value) + } + + return replacerCacheKey{ + namedArgs: namedArgsBuilder.String(), + replacementValues: replacementValuesBuilder.String(), + sql: sql, + }, true +} diff --git a/rivershared/sqlctemplate/sqlc_template_test.go b/rivershared/sqlctemplate/sqlc_template_test.go index bbe47251..5fd0c612 100644 --- a/rivershared/sqlctemplate/sqlc_template_test.go +++ b/rivershared/sqlctemplate/sqlc_template_test.go @@ -199,6 +199,96 @@ func TestReplacer(t *testing.T) { require.Empty(t, replacer.cache) }) + t.Run("CacheBasedOnInputValues", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + // SQL stays constant across all runs. + const sql = ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE kind = @kind + AND state = @state; + ` + + // Initially cached value + { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + _, _, err := replacer.RunSafely(ctx, sql, nil) + require.NoError(t, err) + } + require.Len(t, replacer.cache, 1) + + // Same SQL, but new value. + { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "other_schema."}, + }, nil) + + _, _, err := replacer.RunSafely(ctx, sql, nil) + require.NoError(t, err) + } + require.Len(t, replacer.cache, 2) + + // Named arg added to the mix. + { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, map[string]any{ + "kind": "kind_value", + }) + + _, _, err := replacer.RunSafely(ctx, sql, nil) + require.NoError(t, err) + } + require.Len(t, replacer.cache, 3) + + // Different named arg _value_ (i.e. still same named arg) can still use + // the previous cached SQL. + { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, map[string]any{ + "kind": "other_kind_value", + }) + + _, _, err := replacer.RunSafely(ctx, sql, nil) + require.NoError(t, err) + } + require.Len(t, replacer.cache, 3) // unchanged + + // New named arg adds a new cache value. + { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, map[string]any{ + "kind": "kind_value", + "state": "state_value", + }) + + _, _, err := replacer.RunSafely(ctx, sql, nil) + require.NoError(t, err) + } + require.Len(t, replacer.cache, 4) + + // Different input SQL. + { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + _, _, err := replacer.RunSafely(ctx, ` + SELECT /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + } + require.Len(t, replacer.cache, 5) + }) + t.Run("NamedArgsNoInitialArgs", func(t *testing.T) { t.Parallel() @@ -369,3 +459,37 @@ func TestReplacer(t *testing.T) { wg.Wait() }) } + +func BenchmarkReplacer(b *testing.B) { + ctx := context.Background() + + runReplacer := func(b *testing.B, replacer *Replacer, stable bool) { + b.Helper() + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: stable, Value: "test_schema."}, + }, nil) + + _, _, err := replacer.RunSafely(ctx, ` + -- name: JobCountByState :one + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE state = @state; + `, nil) + require.NoError(b, err) + } + + b.Run("WithCache", func(b *testing.B) { + var replacer Replacer + for range b.N { + runReplacer(b, &replacer, true) + } + }) + + b.Run("WithoutCache", func(b *testing.B) { + var replacer Replacer + for range b.N { + runReplacer(b, &replacer, false) + } + }) +}