Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package builtin
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"reflect"
"sort"
Expand All @@ -16,6 +17,10 @@ import (
var (
Index map[string]int
Names []string

// MaxDepth limits the recursion depth for nested structures.
MaxDepth = 10000
ErrorMaxDepth = errors.New("recursion depth exceeded")
)

func init() {
Expand Down Expand Up @@ -377,7 +382,7 @@ var Builtins = []*Function{
{
Name: "max",
Func: func(args ...any) (any, error) {
return minMax("max", runtime.Less, args...)
return minMax("max", runtime.Less, 0, args...)
},
Validate: func(args []reflect.Type) (reflect.Type, error) {
return validateAggregateFunc("max", args)
Expand All @@ -386,7 +391,7 @@ var Builtins = []*Function{
{
Name: "min",
Func: func(args ...any) (any, error) {
return minMax("min", runtime.More, args...)
return minMax("min", runtime.More, 0, args...)
},
Validate: func(args []reflect.Type) (reflect.Type, error) {
return validateAggregateFunc("min", args)
Expand All @@ -395,7 +400,7 @@ var Builtins = []*Function{
{
Name: "mean",
Func: func(args ...any) (any, error) {
count, sum, err := mean(args...)
count, sum, err := mean(0, args...)
if err != nil {
return nil, err
}
Expand All @@ -411,7 +416,7 @@ var Builtins = []*Function{
{
Name: "median",
Func: func(args ...any) (any, error) {
values, err := median(args...)
values, err := median(0, args...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -940,7 +945,10 @@ var Builtins = []*Function{
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
return nil, size, fmt.Errorf("cannot flatten %s", v.Kind())
}
ret := flatten(v)
ret, err := flatten(v, 0)
if err != nil {
return nil, 0, err
}
size = uint(len(ret))
return ret, size, nil
},
Expand Down
97 changes: 97 additions & 0 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -722,3 +722,100 @@ func TestBuiltin_with_deref(t *testing.T) {
})
}
}

func TestBuiltin_flatten_recursion(t *testing.T) {
var s []any
s = append(s, &s) // s contains a pointer to itself

env := map[string]any{
"arr": s,
}

program, err := expr.Compile("flatten(arr)", expr.Env(env))
require.NoError(t, err)

_, err = expr.Run(program, env)
require.Error(t, err)
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
}

func TestBuiltin_flatten_recursion_slice(t *testing.T) {
s := make([]any, 1)
s[0] = s

env := map[string]any{
"arr": s,
}

program, err := expr.Compile("flatten(arr)", expr.Env(env))
require.NoError(t, err)

_, err = expr.Run(program, env)
require.Error(t, err)
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
}

func TestBuiltin_numerical_recursion(t *testing.T) {
s := make([]any, 1)
s[0] = s

env := map[string]any{
"arr": s,
}

tests := []string{
"max(arr)",
"min(arr)",
"mean(arr)",
"median(arr)",
}

for _, input := range tests {
t.Run(input, func(t *testing.T) {
program, err := expr.Compile(input, expr.Env(env))
require.NoError(t, err)

_, err = expr.Run(program, env)
require.Error(t, err)
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
})
}
}

func TestBuiltin_recursion_custom_max_depth(t *testing.T) {
originalMaxDepth := builtin.MaxDepth
defer func() {
builtin.MaxDepth = originalMaxDepth
}()

// Set a small depth limit
builtin.MaxDepth = 2

// Create a deeply nested array (depth 5)
// [1, [2, [3, [4, [5]]]]]
arr := []any{1, []any{2, []any{3, []any{4, []any{5}}}}}

env := map[string]any{
"arr": arr,
}

t.Run("flatten exceeds max depth", func(t *testing.T) {
program, err := expr.Compile("flatten(arr)", expr.Env(env))
require.NoError(t, err)

_, err = expr.Run(program, env)
require.Error(t, err)
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
})

t.Run("flatten within max depth", func(t *testing.T) {
// Depth 2: [1, [2]]
shallowArr := []any{1, []any{2}}
envShallow := map[string]any{"arr": shallowArr}
program, err := expr.Compile("flatten(arr)", expr.Env(envShallow))
require.NoError(t, err)

_, err = expr.Run(program, envShallow)
require.NoError(t, err)
})
}
33 changes: 24 additions & 9 deletions builtin/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,15 +253,18 @@ func String(arg any) any {
return fmt.Sprintf("%v", arg)
}

func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
func minMax(name string, fn func(any, any) bool, depth int, args ...any) (any, error) {
if depth > MaxDepth {
return nil, ErrorMaxDepth
}
var val any
for _, arg := range args {
rv := reflect.ValueOf(arg)
switch rv.Kind() {
case reflect.Array, reflect.Slice:
size := rv.Len()
for i := 0; i < size; i++ {
elemVal, err := minMax(name, fn, rv.Index(i).Interface())
elemVal, err := minMax(name, fn, depth+1, rv.Index(i).Interface())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -294,7 +297,10 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
return val, nil
}

func mean(args ...any) (int, float64, error) {
func mean(depth int, args ...any) (int, float64, error) {
if depth > MaxDepth {
return 0, 0, ErrorMaxDepth
}
var total float64
var count int

Expand All @@ -304,7 +310,7 @@ func mean(args ...any) (int, float64, error) {
case reflect.Array, reflect.Slice:
size := rv.Len()
for i := 0; i < size; i++ {
elemCount, elemSum, err := mean(rv.Index(i).Interface())
elemCount, elemSum, err := mean(depth+1, rv.Index(i).Interface())
if err != nil {
return 0, 0, err
}
Expand All @@ -327,7 +333,10 @@ func mean(args ...any) (int, float64, error) {
return count, total, nil
}

func median(args ...any) ([]float64, error) {
func median(depth int, args ...any) ([]float64, error) {
if depth > MaxDepth {
return nil, ErrorMaxDepth
}
var values []float64

for _, arg := range args {
Expand All @@ -336,7 +345,7 @@ func median(args ...any) ([]float64, error) {
case reflect.Array, reflect.Slice:
size := rv.Len()
for i := 0; i < size; i++ {
elems, err := median(rv.Index(i).Interface())
elems, err := median(depth+1, rv.Index(i).Interface())
if err != nil {
return nil, err
}
Expand All @@ -355,18 +364,24 @@ func median(args ...any) ([]float64, error) {
return values, nil
}

func flatten(arg reflect.Value) []any {
func flatten(arg reflect.Value, depth int) ([]any, error) {
if depth > MaxDepth {
return nil, ErrorMaxDepth
}
ret := []any{}
for i := 0; i < arg.Len(); i++ {
v := deref.Value(arg.Index(i))
if v.Kind() == reflect.Array || v.Kind() == reflect.Slice {
x := flatten(v)
x, err := flatten(v, depth+1)
if err != nil {
return nil, err
}
ret = append(ret, x...)
} else {
ret = append(ret, v.Interface())
}
}
return ret
return ret, nil
}

func get(params ...any) (out any, err error) {
Expand Down
Loading