diff --git a/src/main/java/com/hubspot/jinjava/JinjavaConfig.java b/src/main/java/com/hubspot/jinjava/JinjavaConfig.java index ad527458c..ea0abea84 100644 --- a/src/main/java/com/hubspot/jinjava/JinjavaConfig.java +++ b/src/main/java/com/hubspot/jinjava/JinjavaConfig.java @@ -19,6 +19,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.hubspot.jinjava.el.JinjavaInterpreterResolver; +import com.hubspot.jinjava.el.JinjavaObjectUnwrapper; +import com.hubspot.jinjava.el.ObjectUnwrapper; import com.hubspot.jinjava.interpret.Context; import com.hubspot.jinjava.interpret.Context.Library; import com.hubspot.jinjava.interpret.InterpreterFactory; @@ -69,6 +71,8 @@ public class JinjavaConfig { private final boolean enablePreciseDivideFilter; private final ObjectMapper objectMapper; + private final ObjectUnwrapper objectUnwrapper; + public static Builder newBuilder() { return new Builder(); } @@ -123,6 +127,7 @@ private JinjavaConfig(Builder builder) { legacyOverrides = builder.legacyOverrides; enablePreciseDivideFilter = builder.enablePreciseDivideFilter; objectMapper = builder.objectMapper; + objectUnwrapper = builder.objectUnwrapper; } public Charset getCharset() { @@ -221,6 +226,10 @@ public ObjectMapper getObjectMapper() { return objectMapper; } + public ObjectUnwrapper getObjectUnwrapper() { + return objectUnwrapper; + } + /** * @deprecated Replaced by {@link LegacyOverrides#isIterateOverMapKeys()} */ @@ -272,6 +281,8 @@ public static class Builder { private boolean enablePreciseDivideFilter = false; private ObjectMapper objectMapper = new ObjectMapper(); + private ObjectUnwrapper objectUnwrapper = new JinjavaObjectUnwrapper(); + private Builder() {} public Builder withCharset(Charset charset) { @@ -427,6 +438,11 @@ public Builder withObjectMapper(ObjectMapper objectMapper) { return this; } + public Builder withObjectUnwrapper(ObjectUnwrapper objectUnwrapper) { + this.objectUnwrapper = objectUnwrapper; + return this; + } + public JinjavaConfig build() { return new JinjavaConfig(this); } diff --git a/src/main/java/com/hubspot/jinjava/el/ExpressionResolver.java b/src/main/java/com/hubspot/jinjava/el/ExpressionResolver.java index 3ddd44b7d..b60ab9efd 100644 --- a/src/main/java/com/hubspot/jinjava/el/ExpressionResolver.java +++ b/src/main/java/com/hubspot/jinjava/el/ExpressionResolver.java @@ -13,7 +13,6 @@ import com.hubspot.jinjava.interpret.InvalidArgumentException; import com.hubspot.jinjava.interpret.InvalidInputException; import com.hubspot.jinjava.interpret.JinjavaInterpreter; -import com.hubspot.jinjava.interpret.LazyExpression; import com.hubspot.jinjava.interpret.TemplateError; import com.hubspot.jinjava.interpret.TemplateError.ErrorItem; import com.hubspot.jinjava.interpret.TemplateError.ErrorReason; @@ -40,6 +39,7 @@ public class ExpressionResolver { private final ExpressionFactory expressionFactory; private final JinjavaInterpreterResolver resolver; private final JinjavaELContext elContext; + private final ObjectUnwrapper objectUnwrapper; private static final String EXPRESSION_START_TOKEN = "#{"; private static final String EXPRESSION_END_TOKEN = "}"; @@ -56,6 +56,7 @@ public ExpressionResolver(JinjavaInterpreter interpreter, Jinjava jinjava) { for (ELFunctionDefinition fn : jinjava.getGlobalContext().getAllFunctions()) { this.elContext.setFunction(fn.getNamespace(), fn.getLocalName(), fn.getMethod()); } + objectUnwrapper = interpreter.getConfig().getObjectUnwrapper(); } /** @@ -110,10 +111,7 @@ private Object resolveExpression(String expression, boolean addToResolvedExpress ); } - // resolve the LazyExpression supplier automatically - if (result instanceof LazyExpression) { - result = ((LazyExpression) result).get(); - } + result = objectUnwrapper.unwrapObject(result); validateResult(result); diff --git a/src/main/java/com/hubspot/jinjava/el/JinjavaInterpreterResolver.java b/src/main/java/com/hubspot/jinjava/el/JinjavaInterpreterResolver.java index 285043d28..3c81ca19a 100644 --- a/src/main/java/com/hubspot/jinjava/el/JinjavaInterpreterResolver.java +++ b/src/main/java/com/hubspot/jinjava/el/JinjavaInterpreterResolver.java @@ -13,7 +13,6 @@ import com.hubspot.jinjava.interpret.DeferredValueException; import com.hubspot.jinjava.interpret.DisabledException; import com.hubspot.jinjava.interpret.JinjavaInterpreter; -import com.hubspot.jinjava.interpret.LazyExpression; import com.hubspot.jinjava.interpret.TemplateError; import com.hubspot.jinjava.interpret.TemplateError.ErrorItem; import com.hubspot.jinjava.interpret.TemplateError.ErrorReason; @@ -40,7 +39,6 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; -import java.util.Optional; import javax.el.ArrayELResolver; import javax.el.CompositeELResolver; import javax.el.ELContext; @@ -74,10 +72,12 @@ public class JinjavaInterpreterResolver extends SimpleResolver { }; private final JinjavaInterpreter interpreter; + private final ObjectUnwrapper objectUnwrapper; public JinjavaInterpreterResolver(JinjavaInterpreter interpreter) { super(interpreter.getConfig().getElResolver()); this.interpreter = interpreter; + this.objectUnwrapper = interpreter.getConfig().getObjectUnwrapper(); } @Override @@ -203,20 +203,9 @@ private Object getValue( } else { // Get property of base object. try { - if (base instanceof Optional) { - Optional optBase = (Optional) base; - if (!optBase.isPresent()) { - return null; - } - - base = optBase.get(); - } - - if (base instanceof LazyExpression) { - base = ((LazyExpression) base).get(); - if (base == null) { - return null; - } + base = objectUnwrapper.unwrapObject(base); + if (base == null) { + return null; } // java doesn't natively support negative array indices, so the @@ -240,20 +229,9 @@ private Object getValue( value = super.getValue(context, base, propertyName); - if (value instanceof Optional) { - Optional optValue = (Optional) value; - if (!optValue.isPresent()) { - return null; - } - - value = optValue.get(); - } - - if (value instanceof LazyExpression) { - value = ((LazyExpression) value).get(); - if (value == null) { - return null; - } + value = objectUnwrapper.unwrapObject(value); + if (value == null) { + return null; } if (value instanceof DeferredValue) { @@ -309,11 +287,9 @@ Object wrap(Object value) { return value; } - if (value instanceof LazyExpression) { - value = ((LazyExpression) value).get(); - if (value == null) { - return null; - } + value = objectUnwrapper.unwrapObject(value); + if (value == null) { + return null; } if (value instanceof PyWrapper) { diff --git a/src/main/java/com/hubspot/jinjava/el/JinjavaObjectUnwrapper.java b/src/main/java/com/hubspot/jinjava/el/JinjavaObjectUnwrapper.java new file mode 100644 index 000000000..338dadeb7 --- /dev/null +++ b/src/main/java/com/hubspot/jinjava/el/JinjavaObjectUnwrapper.java @@ -0,0 +1,25 @@ +package com.hubspot.jinjava.el; + +import com.hubspot.jinjava.interpret.LazyExpression; +import java.util.Optional; + +public class JinjavaObjectUnwrapper implements ObjectUnwrapper { + + @Override + public Object unwrapObject(Object o) { + if (o instanceof LazyExpression) { + o = ((LazyExpression) o).get(); + } + + if (o instanceof Optional) { + Optional optValue = (Optional) o; + if (!optValue.isPresent()) { + return null; + } + + o = optValue.get(); + } + + return o; + } +} diff --git a/src/main/java/com/hubspot/jinjava/el/ObjectUnwrapper.java b/src/main/java/com/hubspot/jinjava/el/ObjectUnwrapper.java new file mode 100644 index 000000000..2d04ad6ec --- /dev/null +++ b/src/main/java/com/hubspot/jinjava/el/ObjectUnwrapper.java @@ -0,0 +1,5 @@ +package com.hubspot.jinjava.el; + +public interface ObjectUnwrapper { + Object unwrapObject(Object o); +} diff --git a/src/main/java/com/hubspot/jinjava/util/EagerContextWatcher.java b/src/main/java/com/hubspot/jinjava/util/EagerContextWatcher.java index a528d3be9..b3e7c02b5 100644 --- a/src/main/java/com/hubspot/jinjava/util/EagerContextWatcher.java +++ b/src/main/java/com/hubspot/jinjava/util/EagerContextWatcher.java @@ -368,6 +368,7 @@ private static Object getObjectOrHashCode(Object o) { if (o instanceof LazyExpression) { o = ((LazyExpression) o).get(); } + if (o instanceof PyList && !((PyList) o).toList().contains(o)) { return o.hashCode(); }