diff --git a/src/main/java/com/hubspot/jinjava/objects/collections/PyMap.java b/src/main/java/com/hubspot/jinjava/objects/collections/PyMap.java index cf867bfed..926505319 100644 --- a/src/main/java/com/hubspot/jinjava/objects/collections/PyMap.java +++ b/src/main/java/com/hubspot/jinjava/objects/collections/PyMap.java @@ -2,6 +2,10 @@ import com.google.common.collect.ForwardingMap; import com.hubspot.jinjava.objects.PyWrapper; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Set; @@ -62,9 +66,39 @@ public void putAll(Map m) { @Override public int hashCode() { int h = 0; - for (Entry entry : map.entrySet()) { - if (entry.getValue() != map && entry.getValue() != this) { - h += entry.hashCode(); + List valueList = new ArrayList<>(map.entrySet()); + ListIterator valueIterator = valueList.listIterator(); + Set visited = new HashSet<>(); + while (valueIterator.hasNext()) { + Object next = valueIterator.next(); + int code = System.identityHashCode(next); + if (visited.contains(code)) { + continue; + } else { + visited.add(code); + } + if (next instanceof Entry) { + Entry nextEntry = (Entry) next; + if (nextEntry.getKey() != null) { + h += nextEntry.getKey().hashCode(); + } + valueIterator.add(nextEntry.getValue()); + valueIterator.previous(); + } else if (next instanceof Iterable) { + for (Object o : (Iterable) next) { + valueIterator.add(o); + valueIterator.previous(); + } + } else if (next instanceof Map) { + ((Map) next).entrySet() + .forEach( + e -> { + valueIterator.add(e); + valueIterator.previous(); + } + ); + } else if (next != null) { + h += next.hashCode(); } } return h; diff --git a/src/test/java/com/hubspot/jinjava/objects/collections/PyMapTest.java b/src/test/java/com/hubspot/jinjava/objects/collections/PyMapTest.java index 3ec10fbe8..ea4cb8c15 100644 --- a/src/test/java/com/hubspot/jinjava/objects/collections/PyMapTest.java +++ b/src/test/java/com/hubspot/jinjava/objects/collections/PyMapTest.java @@ -3,6 +3,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import com.google.common.collect.ImmutableList; import com.hubspot.jinjava.BaseJinjavaTest; import com.hubspot.jinjava.Jinjava; import com.hubspot.jinjava.JinjavaConfig; @@ -10,6 +11,7 @@ import com.hubspot.jinjava.interpret.IndexOutOfRangeException; import com.hubspot.jinjava.interpret.RenderResult; import com.hubspot.jinjava.interpret.TemplateError; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import org.junit.Test; @@ -367,4 +369,61 @@ public void itUpdatesKeysWithVariableName() { ) .isEqualTo("value2"); } + + @Test + public void itComputesHashCodeWhenContainedWithinItself() { + PyMap map = new PyMap(new HashMap<>()); + map.put("map1key1", "value1"); + + PyMap map2 = new PyMap(new HashMap<>()); + map2.put("map2key1", map); + + map.put("map1key2", map2); + + assertThat(map.hashCode()).isEqualTo(-413943561); + } + + @Test + public void itComputesHashCodeWhenContainedWithinItselfWithFurtherEntries() { + PyMap map = new PyMap(new HashMap<>()); + map.put("map1key1", "value1"); + + PyMap map2 = new PyMap(new HashMap<>()); + map2.put("map2key1", map); + + map.put("map1key2", map2); + + int originalHashCode = map.hashCode(); + map2.put("newKey", "newValue"); + int newHashCode = map.hashCode(); + assertThat(originalHashCode).isNotEqualTo(newHashCode); + } + + @Test + public void itComputesHashCodeWhenContainedWithinItselfInsideList() { + PyMap map = new PyMap(new HashMap<>()); + map.put("map1key1", "value1"); + + PyMap map2 = new PyMap(new HashMap<>()); + map2.put("map2key1", map); + + map.put("map1key2", new PyList(ImmutableList.of((map2)))); + + assertThat(map.hashCode()).isEqualTo(-413943561); + } + + @Test + public void itComputesHashCodeWithNullKeysAndValues() { + PyMap map = new PyMap(new HashMap<>()); + map.put(null, "value1"); + + PyMap map2 = new PyMap(new HashMap<>()); + map2.put("map2key1", map); + + PyList list = new PyList(new ArrayList<>()); + list.add(null); + map.put("map1key2", new PyList(list)); + + assertThat(map.hashCode()).isEqualTo(-687497624); + } }