diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver.go b/riverdriver/riverdatabasesql/river_database_sql_driver.go index 84821f76..41e985a8 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver.go @@ -981,6 +981,7 @@ type templateReplaceWrapper struct { func (w templateReplaceWrapper) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { sql, args = w.replacer.Run(ctx, sql, args) + convertDurationToInterval(sql, args) return w.dbtx.ExecContext(ctx, sql, args...) } @@ -1107,3 +1108,17 @@ func schemaTemplateParam(ctx context.Context, schema string) context.Context { "schema": {Value: schema}, }, nil) } + +// convertDurationToInterval converts Go's time.Duration to PostgreSQL's +func convertDurationToInterval(sql string, args []any) { + if !strings.Contains(sql, "interval") { + return + } + + for i, arg := range args { + if d, ok := arg.(time.Duration); ok { + pgInterval := fmt.Sprintf("%d seconds", int(d.Seconds())) + args[i] = pgInterval + } + } +} diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver_test.go b/riverdriver/riverdatabasesql/river_database_sql_driver_test.go index 209b009b..5c24d754 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver_test.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "testing" + "time" "github.com/stretchr/testify/require" @@ -131,3 +132,34 @@ func TestSchemaTemplateParam(t *testing.T) { require.Equal(t, "SELECT 1 FROM custom_schema.river_job", updatedSQL) }) } + +func TestConvertDurationToInterval(t *testing.T) { + t.Parallel() + + testCases := []struct { + Desc string + InputSQL string + InputArgs []any + ExpectedArgs []any + }{ + { + Desc: "Convert duration to interval", + InputSQL: ` + INSERT INTO river_leader(leader_id, elected_At, expires_at) + VALUES($1, now(), now() + $2::interval) + ON CONFICT (name) + DO NOTHING + `, + InputArgs: []any{"river", 15 * time.Second}, + ExpectedArgs: []any{"river", "15 seconds"}, + }, + } + for _, tt := range testCases { + t.Run(tt.Desc, func(t *testing.T) { + t.Parallel() + + convertDurationToInterval(tt.InputSQL, tt.InputArgs) + require.Equal(t, tt.InputArgs, tt.ExpectedArgs) + }) + } +}