Skip to content

Commit 2a07330

Browse files
committed
fix(builtin): limit recursion depth
Add builtin.MaxDepth (default 10k) to prevent stack overflows when processing deeply nested or cyclic structures in builtin functions. The functions flatten, min, max, mean, and median now return a "recursion depth exceeded" error instead of crashing the runtime. Signed-off-by: Ville Vesilehto <[email protected]>
1 parent ad49544 commit 2a07330

File tree

3 files changed

+134
-14
lines changed

3 files changed

+134
-14
lines changed

builtin/builtin.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package builtin
33
import (
44
"encoding/base64"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"reflect"
89
"sort"
@@ -16,6 +17,10 @@ import (
1617
var (
1718
Index map[string]int
1819
Names []string
20+
21+
// MaxDepth limits the recursion depth for nested structures.
22+
MaxDepth = 10000
23+
ErrorMaxDepth = errors.New("recursion depth exceeded")
1924
)
2025

2126
func init() {
@@ -377,7 +382,7 @@ var Builtins = []*Function{
377382
{
378383
Name: "max",
379384
Func: func(args ...any) (any, error) {
380-
return minMax("max", runtime.Less, args...)
385+
return minMax("max", runtime.Less, 0, args...)
381386
},
382387
Validate: func(args []reflect.Type) (reflect.Type, error) {
383388
return validateAggregateFunc("max", args)
@@ -386,7 +391,7 @@ var Builtins = []*Function{
386391
{
387392
Name: "min",
388393
Func: func(args ...any) (any, error) {
389-
return minMax("min", runtime.More, args...)
394+
return minMax("min", runtime.More, 0, args...)
390395
},
391396
Validate: func(args []reflect.Type) (reflect.Type, error) {
392397
return validateAggregateFunc("min", args)
@@ -395,7 +400,7 @@ var Builtins = []*Function{
395400
{
396401
Name: "mean",
397402
Func: func(args ...any) (any, error) {
398-
count, sum, err := mean(args...)
403+
count, sum, err := mean(0, args...)
399404
if err != nil {
400405
return nil, err
401406
}
@@ -411,7 +416,7 @@ var Builtins = []*Function{
411416
{
412417
Name: "median",
413418
Func: func(args ...any) (any, error) {
414-
values, err := median(args...)
419+
values, err := median(0, args...)
415420
if err != nil {
416421
return nil, err
417422
}
@@ -940,7 +945,10 @@ var Builtins = []*Function{
940945
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
941946
return nil, size, fmt.Errorf("cannot flatten %s", v.Kind())
942947
}
943-
ret := flatten(v)
948+
ret, err := flatten(v, 0)
949+
if err != nil {
950+
return nil, 0, err
951+
}
944952
size = uint(len(ret))
945953
return ret, size, nil
946954
},

builtin/builtin_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,3 +722,100 @@ func TestBuiltin_with_deref(t *testing.T) {
722722
})
723723
}
724724
}
725+
726+
func TestBuiltin_flatten_recursion(t *testing.T) {
727+
var s []any
728+
s = append(s, &s) // s contains a pointer to itself
729+
730+
env := map[string]any{
731+
"arr": s,
732+
}
733+
734+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
735+
require.NoError(t, err)
736+
737+
_, err = expr.Run(program, env)
738+
require.Error(t, err)
739+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
740+
}
741+
742+
func TestBuiltin_flatten_recursion_slice(t *testing.T) {
743+
s := make([]any, 1)
744+
s[0] = s
745+
746+
env := map[string]any{
747+
"arr": s,
748+
}
749+
750+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
751+
require.NoError(t, err)
752+
753+
_, err = expr.Run(program, env)
754+
require.Error(t, err)
755+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
756+
}
757+
758+
func TestBuiltin_numerical_recursion(t *testing.T) {
759+
s := make([]any, 1)
760+
s[0] = s
761+
762+
env := map[string]any{
763+
"arr": s,
764+
}
765+
766+
tests := []string{
767+
"max(arr)",
768+
"min(arr)",
769+
"mean(arr)",
770+
"median(arr)",
771+
}
772+
773+
for _, input := range tests {
774+
t.Run(input, func(t *testing.T) {
775+
program, err := expr.Compile(input, expr.Env(env))
776+
require.NoError(t, err)
777+
778+
_, err = expr.Run(program, env)
779+
require.Error(t, err)
780+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
781+
})
782+
}
783+
}
784+
785+
func TestBuiltin_recursion_custom_max_depth(t *testing.T) {
786+
originalMaxDepth := builtin.MaxDepth
787+
defer func() {
788+
builtin.MaxDepth = originalMaxDepth
789+
}()
790+
791+
// Set a small depth limit
792+
builtin.MaxDepth = 2
793+
794+
// Create a deeply nested array (depth 5)
795+
// [1, [2, [3, [4, [5]]]]]
796+
arr := []any{1, []any{2, []any{3, []any{4, []any{5}}}}}
797+
798+
env := map[string]any{
799+
"arr": arr,
800+
}
801+
802+
t.Run("flatten exceeds max depth", func(t *testing.T) {
803+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
804+
require.NoError(t, err)
805+
806+
_, err = expr.Run(program, env)
807+
require.Error(t, err)
808+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
809+
})
810+
811+
t.Run("flatten within max depth", func(t *testing.T) {
812+
// Depth 2: [1, [2]]
813+
shallowArr := []any{1, []any{2}}
814+
envShallow := map[string]any{"arr": shallowArr}
815+
program, err := expr.Compile("flatten(arr)", expr.Env(envShallow))
816+
require.NoError(t, err)
817+
818+
_, err = expr.Run(program, envShallow)
819+
require.NoError(t, err)
820+
})
821+
}

builtin/lib.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,15 +253,18 @@ func String(arg any) any {
253253
return fmt.Sprintf("%v", arg)
254254
}
255255

256-
func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
256+
func minMax(name string, fn func(any, any) bool, depth int, args ...any) (any, error) {
257+
if depth > MaxDepth {
258+
return nil, ErrorMaxDepth
259+
}
257260
var val any
258261
for _, arg := range args {
259262
rv := reflect.ValueOf(arg)
260263
switch rv.Kind() {
261264
case reflect.Array, reflect.Slice:
262265
size := rv.Len()
263266
for i := 0; i < size; i++ {
264-
elemVal, err := minMax(name, fn, rv.Index(i).Interface())
267+
elemVal, err := minMax(name, fn, depth+1, rv.Index(i).Interface())
265268
if err != nil {
266269
return nil, err
267270
}
@@ -294,7 +297,10 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
294297
return val, nil
295298
}
296299

297-
func mean(args ...any) (int, float64, error) {
300+
func mean(depth int, args ...any) (int, float64, error) {
301+
if depth > MaxDepth {
302+
return 0, 0, ErrorMaxDepth
303+
}
298304
var total float64
299305
var count int
300306

@@ -304,7 +310,7 @@ func mean(args ...any) (int, float64, error) {
304310
case reflect.Array, reflect.Slice:
305311
size := rv.Len()
306312
for i := 0; i < size; i++ {
307-
elemCount, elemSum, err := mean(rv.Index(i).Interface())
313+
elemCount, elemSum, err := mean(depth+1, rv.Index(i).Interface())
308314
if err != nil {
309315
return 0, 0, err
310316
}
@@ -327,7 +333,10 @@ func mean(args ...any) (int, float64, error) {
327333
return count, total, nil
328334
}
329335

330-
func median(args ...any) ([]float64, error) {
336+
func median(depth int, args ...any) ([]float64, error) {
337+
if depth > MaxDepth {
338+
return nil, ErrorMaxDepth
339+
}
331340
var values []float64
332341

333342
for _, arg := range args {
@@ -336,7 +345,7 @@ func median(args ...any) ([]float64, error) {
336345
case reflect.Array, reflect.Slice:
337346
size := rv.Len()
338347
for i := 0; i < size; i++ {
339-
elems, err := median(rv.Index(i).Interface())
348+
elems, err := median(depth+1, rv.Index(i).Interface())
340349
if err != nil {
341350
return nil, err
342351
}
@@ -355,18 +364,24 @@ func median(args ...any) ([]float64, error) {
355364
return values, nil
356365
}
357366

358-
func flatten(arg reflect.Value) []any {
367+
func flatten(arg reflect.Value, depth int) ([]any, error) {
368+
if depth > MaxDepth {
369+
return nil, ErrorMaxDepth
370+
}
359371
ret := []any{}
360372
for i := 0; i < arg.Len(); i++ {
361373
v := deref.Value(arg.Index(i))
362374
if v.Kind() == reflect.Array || v.Kind() == reflect.Slice {
363-
x := flatten(v)
375+
x, err := flatten(v, depth+1)
376+
if err != nil {
377+
return nil, err
378+
}
364379
ret = append(ret, x...)
365380
} else {
366381
ret = append(ret, v.Interface())
367382
}
368383
}
369-
return ret
384+
return ret, nil
370385
}
371386

372387
func get(params ...any) (out any, err error) {

0 commit comments

Comments
 (0)