diff --git a/dagger-common/src/main/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandler.java b/dagger-common/src/main/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandler.java index b48fbe80b..6b5d05ae6 100644 --- a/dagger-common/src/main/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandler.java +++ b/dagger-common/src/main/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandler.java @@ -2,11 +2,11 @@ import com.google.protobuf.Descriptors; import com.google.protobuf.DynamicMessage; -import com.google.protobuf.MapEntry; -import com.google.protobuf.WireFormat; -import io.odpf.dagger.common.serde.typehandler.TypeHandler; import io.odpf.dagger.common.serde.typehandler.RowFactory; +import io.odpf.dagger.common.serde.typehandler.TypeHandler; +import io.odpf.dagger.common.serde.typehandler.TypeHandlerFactory; import io.odpf.dagger.common.serde.typehandler.TypeInformationFactory; +import io.odpf.dagger.common.serde.typehandler.repeated.RepeatedMessageHandler; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.types.Row; @@ -27,6 +27,7 @@ public class MapHandler implements TypeHandler { private Descriptors.FieldDescriptor fieldDescriptor; + private TypeHandler repeatedMessageHandler; /** * Instantiates a new Map proto handler. @@ -35,6 +36,7 @@ public class MapHandler implements TypeHandler { */ public MapHandler(Descriptors.FieldDescriptor fieldDescriptor) { this.fieldDescriptor = fieldDescriptor; + this.repeatedMessageHandler = new RepeatedMessageHandler(fieldDescriptor); } @Override @@ -47,38 +49,44 @@ public DynamicMessage.Builder transformToProtoBuilder(DynamicMessage.Builder bui if (!canHandle() || field == null) { return builder; } - if (field instanceof Map) { - convertFromMap(builder, (Map) field); - } - - if (field instanceof Object[]) { - convertFromRow(builder, (Object[]) field); + Map mapField = (Map) field; + ArrayList rows = new ArrayList<>(); + for (Entry entry : mapField.entrySet()) { + rows.add(Row.of(entry.getKey(), entry.getValue())); + } + return repeatedMessageHandler.transformToProtoBuilder(builder, rows.toArray()); } - - return builder; + return repeatedMessageHandler.transformToProtoBuilder(builder, field); } @Override public Object transformFromPostProcessor(Object field) { ArrayList rows = new ArrayList<>(); - if (field != null) { - Map mapField = (Map) field; - for (Entry entry : mapField.entrySet()) { - rows.add(getRowFromMap(entry)); + if (field == null) { + return rows.toArray(); + } + if (field instanceof Map) { + Map mapField = (Map) field; + for (Entry entry : mapField.entrySet()) { + Descriptors.FieldDescriptor keyDescriptor = fieldDescriptor.getMessageType().findFieldByName("key"); + Descriptors.FieldDescriptor valueDescriptor = fieldDescriptor.getMessageType().findFieldByName("value"); + TypeHandler handler = TypeHandlerFactory.getTypeHandler(keyDescriptor); + Object key = handler.transformFromPostProcessor(entry.getKey()); + Object value = TypeHandlerFactory.getTypeHandler(valueDescriptor).transformFromPostProcessor(entry.getValue()); + rows.add(Row.of(key, value)); } + return rows.toArray(); + } + if (field instanceof List) { + return repeatedMessageHandler.transformFromPostProcessor(field); } return rows.toArray(); } @Override public Object transformFromProto(Object field) { - ArrayList rows = new ArrayList<>(); - if (field != null) { - List protos = (List) field; - protos.forEach(proto -> rows.add(getRowFromMap(proto))); - } - return rows.toArray(); + return repeatedMessageHandler.transformFromProto(field); } @Override @@ -127,53 +135,4 @@ public Object transformToJson(Object field) { public TypeInformation getTypeInformation() { return Types.OBJECT_ARRAY(TypeInformationFactory.getRowType(fieldDescriptor.getMessageType())); } - - private Row getRowFromMap(Entry entry) { - Row row = new Row(2); - row.setField(0, entry.getKey()); - row.setField(1, entry.getValue()); - return row; - } - - private Row getRowFromMap(DynamicMessage proto) { - Row row = new Row(2); - row.setField(0, parse(proto, "key")); - row.setField(1, parse(proto, "value")); - return row; - } - - private Object parse(DynamicMessage proto, String fieldName) { - Object field = proto.getField(proto.getDescriptorForType().findFieldByName(fieldName)); - if (DynamicMessage.class.equals(field.getClass())) { - field = RowFactory.createRow((DynamicMessage) field); - } - return field; - } - - private void convertFromRow(DynamicMessage.Builder builder, Object[] field) { - for (Object inputValue : field) { - Row inputRow = (Row) inputValue; - if (inputRow.getArity() != 2) { - throw new IllegalArgumentException("Row: " + inputRow.toString() + " of size: " + inputRow.getArity() + " cannot be converted to map"); - } - MapEntry mapEntry = MapEntry - .newDefaultInstance(fieldDescriptor.getMessageType(), WireFormat.FieldType.STRING, "", WireFormat.FieldType.STRING, ""); - builder.addRepeatedField(fieldDescriptor, - mapEntry.toBuilder() - .setKey((String) inputRow.getField(0)) - .setValue((String) inputRow.getField(1)) - .buildPartial()); - } - } - - private void convertFromMap(DynamicMessage.Builder builder, Map field) { - for (Entry entry : field.entrySet()) { - MapEntry mapEntry = MapEntry.newDefaultInstance(fieldDescriptor.getMessageType(), WireFormat.FieldType.STRING, "", WireFormat.FieldType.STRING, ""); - builder.addRepeatedField(fieldDescriptor, - mapEntry.toBuilder() - .setKey(entry.getKey()) - .setValue(entry.getValue()) - .buildPartial()); - } - } } diff --git a/dagger-common/src/test/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandlerTest.java b/dagger-common/src/test/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandlerTest.java index c522a4459..5a9c018e2 100644 --- a/dagger-common/src/test/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandlerTest.java +++ b/dagger-common/src/test/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandlerTest.java @@ -16,7 +16,6 @@ import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; -import org.junit.Assert; import org.junit.Test; import java.util.ArrayList; @@ -81,13 +80,11 @@ public void shouldSetMapFieldIfStringMapPassed() { inputMap.put("b", "456"); DynamicMessage.Builder returnedBuilder = mapHandler.transformToProtoBuilder(builder, inputMap); - List entries = (List) returnedBuilder.getField(mapFieldDescriptor); + List entries = (List) returnedBuilder.getField(mapFieldDescriptor); assertEquals(2, entries.size()); - assertEquals("a", entries.get(0).getAllFields().values().toArray()[0]); - assertEquals("123", entries.get(0).getAllFields().values().toArray()[1]); - assertEquals("b", entries.get(1).getAllFields().values().toArray()[0]); - assertEquals("456", entries.get(1).getAllFields().values().toArray()[1]); + assertArrayEquals(Arrays.asList("a", "123").toArray(), entries.get(0).getAllFields().values().toArray()); + assertArrayEquals(Arrays.asList("b", "456").toArray(), entries.get(1).getAllFields().values().toArray()); } @Test @@ -111,28 +108,29 @@ public void shouldSetMapFieldIfArrayofObjectsHavingRowsWithStringFieldsPassed() inputRows.add(inputRow2); DynamicMessage.Builder returnedBuilder = mapHandler.transformToProtoBuilder(builder, inputRows.toArray()); - List entries = (List) returnedBuilder.getField(mapFieldDescriptor); + List entries = (List) returnedBuilder.getField(mapFieldDescriptor); assertEquals(2, entries.size()); - assertEquals("a", entries.get(0).getAllFields().values().toArray()[0]); - assertEquals("123", entries.get(0).getAllFields().values().toArray()[1]); - assertEquals("b", entries.get(1).getAllFields().values().toArray()[0]); - assertEquals("456", entries.get(1).getAllFields().values().toArray()[1]); + assertArrayEquals(Arrays.asList("a", "123").toArray(), entries.get(0).getAllFields().values().toArray()); + assertArrayEquals(Arrays.asList("b", "456").toArray(), entries.get(1).getAllFields().values().toArray()); } @Test - public void shouldThrowExceptionIfRowsPassedAreNotOfArityTwo() { - Descriptors.FieldDescriptor mapFieldDescriptor = TestBookingLogMessage.getDescriptor().findFieldByName("metadata"); - MapHandler mapHandler = new MapHandler(mapFieldDescriptor); - DynamicMessage.Builder builder = DynamicMessage.newBuilder(mapFieldDescriptor.getContainingType()); + public void shouldHandleComplexTypeValuesForSerialization() throws InvalidProtocolBufferException { + Row inputValue1 = Row.of("12345", Row.of(Arrays.asList("a", "b"))); + Row inputValue2 = Row.of(1234123, Row.of(Arrays.asList("d", "e"))); + Object input = Arrays.asList(inputValue1, inputValue2).toArray(); - ArrayList inputRows = new ArrayList<>(); + Descriptors.FieldDescriptor intMessageDescriptor = TestComplexMap.getDescriptor().findFieldByName("int_message"); + DynamicMessage.Builder builder = DynamicMessage.newBuilder(TestComplexMap.getDescriptor()); - Row inputRow = new Row(3); - inputRows.add(inputRow); - IllegalArgumentException exception = Assert.assertThrows(IllegalArgumentException.class, - () -> mapHandler.transformToProtoBuilder(builder, inputRows.toArray())); - assertEquals("Row: +I[null, null, null] of size: 3 cannot be converted to map", exception.getMessage()); + byte[] data = new MapHandler(intMessageDescriptor).transformToProtoBuilder(builder, input).build().toByteArray(); + TestComplexMap actualMsg = TestComplexMap.parseFrom(data); + assertArrayEquals(Arrays.asList(12345L, 1234123L).toArray(), actualMsg.getIntMessageMap().keySet().toArray()); + TestComplexMap.IdMessage idMessage = (TestComplexMap.IdMessage) actualMsg.getIntMessageMap().values().toArray()[0]; + assertTrue(idMessage.getIdsList().containsAll(Arrays.asList("a", "b"))); + idMessage = (TestComplexMap.IdMessage) actualMsg.getIntMessageMap().values().toArray()[1]; + assertTrue(idMessage.getIdsList().containsAll(Arrays.asList("d", "e"))); } @Test @@ -158,12 +156,8 @@ public void shouldReturnArrayOfRowHavingFieldsSetAsInputMapAndOfSizeTwoForTransf List outputValues = Arrays.asList((Object[]) mapHandler.transformFromPostProcessor(inputMap)); - assertEquals("a", ((Row) outputValues.get(0)).getField(0)); - assertEquals("123", ((Row) outputValues.get(0)).getField(1)); - assertEquals(2, ((Row) outputValues.get(0)).getArity()); - assertEquals("b", ((Row) outputValues.get(1)).getField(0)); - assertEquals("456", ((Row) outputValues.get(1)).getField(1)); - assertEquals(2, ((Row) outputValues.get(1)).getArity()); + assertEquals(Row.of("a", "123"), outputValues.get(0)); + assertEquals(Row.of("b", "456"), outputValues.get(1)); } @Test @@ -210,12 +204,8 @@ public void shouldReturnArrayOfRowHavingFieldsSetAsInputMapAndOfSizeTwoForTransf List outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor))); - assertEquals("a", ((Row) outputValues.get(0)).getField(0)); - assertEquals("123", ((Row) outputValues.get(0)).getField(1)); - assertEquals(2, ((Row) outputValues.get(0)).getArity()); - assertEquals("b", ((Row) outputValues.get(1)).getField(0)); - assertEquals("456", ((Row) outputValues.get(1)).getField(1)); - assertEquals(2, ((Row) outputValues.get(1)).getArity()); + assertEquals(Row.of("a", "123"), outputValues.get(0)); + assertEquals(Row.of("b", "456"), outputValues.get(1)); } @Test @@ -247,16 +237,11 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie List outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor))); - assertEquals(1, ((Row) outputValues.get(0)).getField(0)); - assertEquals("123", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1)); - assertEquals("abc", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2)); - assertEquals(2, ((Row) outputValues.get(0)).getArity()); - assertEquals(2, ((Row) outputValues.get(1)).getField(0)); - assertEquals("456", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(1)); - assertEquals("efg", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(2)); - assertEquals(2, ((Row) outputValues.get(1)).getArity()); + Row mapEntry1 = Row.of(1, Row.of("123", "", "abc")); + Row mapEntry2 = Row.of(2, Row.of("456", "", "efg")); + + assertEquals(mapEntry1, outputValues.get(0)); + assertEquals(mapEntry2, outputValues.get(1)); } @Test @@ -271,11 +256,8 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie List outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor))); - assertEquals(0, ((Row) outputValues.get(0)).getField(0)); - assertEquals("123", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1)); - assertEquals("abc", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2)); - assertEquals(2, ((Row) outputValues.get(0)).getArity()); + Row expected = Row.of(0, Row.of("123", "", "abc")); + assertEquals(expected, outputValues.get(0)); } @Test @@ -290,11 +272,9 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie List outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor))); - assertEquals(1, ((Row) outputValues.get(0)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2)); - assertEquals(2, ((Row) outputValues.get(0)).getArity()); + Row expected = Row.of(1, Row.of("", "", "")); + + assertEquals(expected, outputValues.get(0)); } @Test @@ -309,11 +289,9 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie List outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor))); - assertEquals(0, ((Row) outputValues.get(0)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2)); - assertEquals(2, ((Row) outputValues.get(0)).getArity()); + Row expected = Row.of(0, Row.of("", "", "")); + + assertEquals(expected, outputValues.get(0)); } @Test @@ -328,11 +306,8 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie List outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor))); - assertEquals(0, ((Row) outputValues.get(0)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1)); - assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2)); - assertEquals(2, ((Row) outputValues.get(0)).getArity()); + Row expected = Row.of(0, Row.of("", "", "")); + assertEquals(expected, outputValues.get(0)); } @Test diff --git a/dagger-common/src/test/proto/TestMessage.proto b/dagger-common/src/test/proto/TestMessage.proto index 5f88433ca..64075a06f 100644 --- a/dagger-common/src/test/proto/TestMessage.proto +++ b/dagger-common/src/test/proto/TestMessage.proto @@ -53,7 +53,12 @@ message TestEnumMessage { } message TestComplexMap { + message IdMessage { + repeated string ids = 1; + } map complex_map = 1; + map int_message = 2; + map string_message = 3; } message TestRepeatedPrimitiveMessage { diff --git a/dagger-core/src/test/java/io/odpf/dagger/core/processors/internal/processor/function/functions/JsonPayloadFunctionTest.java b/dagger-core/src/test/java/io/odpf/dagger/core/processors/internal/processor/function/functions/JsonPayloadFunctionTest.java index b2b618b5b..a13da4371 100644 --- a/dagger-core/src/test/java/io/odpf/dagger/core/processors/internal/processor/function/functions/JsonPayloadFunctionTest.java +++ b/dagger-core/src/test/java/io/odpf/dagger/core/processors/internal/processor/function/functions/JsonPayloadFunctionTest.java @@ -269,7 +269,7 @@ public void shouldGetCorrectJsonPayloadForComplexFields() throws InvalidProtocol DynamicMessage dynamicMessage = DynamicMessage.parseFrom(complexMapMessage.getDescriptor(), complexMapMessage.toByteArray()); RowManager rowManager = getRowManagerForMessage(dynamicMessage); - String expectedJsonPayload = "{\"complex_map\":[{\"key\":1,\"value\":{\"order_number\":\"order-number-123\",\"order_url\":\"https://order-url\",\"order_details\":\"pickup\"}}]}"; + String expectedJsonPayload = "{\"complex_map\":[{\"key\":1,\"value\":{\"order_number\":\"order-number-123\",\"order_url\":\"https://order-url\",\"order_details\":\"pickup\"}}],\"int_message\":[],\"string_message\":[]}"; String actualJsonPayload = (String) jsonPayloadFunction.getResult(rowManager); assertEquals(expectedJsonPayload, actualJsonPayload);