Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.MapEntry;
import com.google.protobuf.WireFormat;
import io.odpf.dagger.common.serde.typehandler.TypeHandler;
import io.odpf.dagger.common.serde.typehandler.RowFactory;
import io.odpf.dagger.common.serde.typehandler.TypeHandler;
import io.odpf.dagger.common.serde.typehandler.TypeHandlerFactory;
import io.odpf.dagger.common.serde.typehandler.TypeInformationFactory;
import io.odpf.dagger.common.serde.typehandler.repeated.RepeatedMessageHandler;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.types.Row;
Expand All @@ -27,6 +27,7 @@
public class MapHandler implements TypeHandler {

private Descriptors.FieldDescriptor fieldDescriptor;
private TypeHandler repeatedMessageHandler;

/**
* Instantiates a new Map proto handler.
Expand All @@ -35,6 +36,7 @@ public class MapHandler implements TypeHandler {
*/
public MapHandler(Descriptors.FieldDescriptor fieldDescriptor) {
this.fieldDescriptor = fieldDescriptor;
this.repeatedMessageHandler = new RepeatedMessageHandler(fieldDescriptor);
}

@Override
Expand All @@ -47,38 +49,44 @@ public DynamicMessage.Builder transformToProtoBuilder(DynamicMessage.Builder bui
if (!canHandle() || field == null) {
return builder;
}

if (field instanceof Map) {
convertFromMap(builder, (Map<String, String>) field);
}

if (field instanceof Object[]) {
convertFromRow(builder, (Object[]) field);
Map<?, ?> mapField = (Map<?, ?>) field;
ArrayList<Row> rows = new ArrayList<>();
for (Entry<?, ?> entry : mapField.entrySet()) {
rows.add(Row.of(entry.getKey(), entry.getValue()));
}
return repeatedMessageHandler.transformToProtoBuilder(builder, rows.toArray());
}

return builder;
return repeatedMessageHandler.transformToProtoBuilder(builder, field);
}

@Override
public Object transformFromPostProcessor(Object field) {
ArrayList<Row> rows = new ArrayList<>();
if (field != null) {
Map<String, String> mapField = (Map<String, String>) field;
for (Entry<String, String> entry : mapField.entrySet()) {
rows.add(getRowFromMap(entry));
if (field == null) {
return rows.toArray();
}
if (field instanceof Map) {
Map<String, ?> mapField = (Map<String, ?>) field;
for (Entry<String, ?> entry : mapField.entrySet()) {
Descriptors.FieldDescriptor keyDescriptor = fieldDescriptor.getMessageType().findFieldByName("key");
Descriptors.FieldDescriptor valueDescriptor = fieldDescriptor.getMessageType().findFieldByName("value");
TypeHandler handler = TypeHandlerFactory.getTypeHandler(keyDescriptor);
Object key = handler.transformFromPostProcessor(entry.getKey());
Object value = TypeHandlerFactory.getTypeHandler(valueDescriptor).transformFromPostProcessor(entry.getValue());
rows.add(Row.of(key, value));
}
return rows.toArray();
}
if (field instanceof List) {
return repeatedMessageHandler.transformFromPostProcessor(field);
}
return rows.toArray();
}

@Override
public Object transformFromProto(Object field) {
ArrayList<Row> rows = new ArrayList<>();
if (field != null) {
List<DynamicMessage> protos = (List<DynamicMessage>) field;
protos.forEach(proto -> rows.add(getRowFromMap(proto)));
}
return rows.toArray();
return repeatedMessageHandler.transformFromProto(field);
}

@Override
Expand Down Expand Up @@ -127,53 +135,4 @@ public Object transformToJson(Object field) {
public TypeInformation getTypeInformation() {
return Types.OBJECT_ARRAY(TypeInformationFactory.getRowType(fieldDescriptor.getMessageType()));
}

private Row getRowFromMap(Entry<String, String> entry) {
Row row = new Row(2);
row.setField(0, entry.getKey());
row.setField(1, entry.getValue());
return row;
}

private Row getRowFromMap(DynamicMessage proto) {
Row row = new Row(2);
row.setField(0, parse(proto, "key"));
row.setField(1, parse(proto, "value"));
return row;
}

private Object parse(DynamicMessage proto, String fieldName) {
Object field = proto.getField(proto.getDescriptorForType().findFieldByName(fieldName));
if (DynamicMessage.class.equals(field.getClass())) {
field = RowFactory.createRow((DynamicMessage) field);
}
return field;
}

private void convertFromRow(DynamicMessage.Builder builder, Object[] field) {
for (Object inputValue : field) {
Row inputRow = (Row) inputValue;
if (inputRow.getArity() != 2) {
throw new IllegalArgumentException("Row: " + inputRow.toString() + " of size: " + inputRow.getArity() + " cannot be converted to map");
}
MapEntry<String, String> mapEntry = MapEntry
.newDefaultInstance(fieldDescriptor.getMessageType(), WireFormat.FieldType.STRING, "", WireFormat.FieldType.STRING, "");
builder.addRepeatedField(fieldDescriptor,
mapEntry.toBuilder()
.setKey((String) inputRow.getField(0))
.setValue((String) inputRow.getField(1))
.buildPartial());
}
}

private void convertFromMap(DynamicMessage.Builder builder, Map<String, String> field) {
for (Entry<String, String> entry : field.entrySet()) {
MapEntry<String, String> mapEntry = MapEntry.newDefaultInstance(fieldDescriptor.getMessageType(), WireFormat.FieldType.STRING, "", WireFormat.FieldType.STRING, "");
builder.addRepeatedField(fieldDescriptor,
mapEntry.toBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.buildPartial());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;
import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
Expand Down Expand Up @@ -81,13 +80,11 @@ public void shouldSetMapFieldIfStringMapPassed() {
inputMap.put("b", "456");

DynamicMessage.Builder returnedBuilder = mapHandler.transformToProtoBuilder(builder, inputMap);
List<MapEntry> entries = (List<MapEntry>) returnedBuilder.getField(mapFieldDescriptor);
List<DynamicMessage> entries = (List<DynamicMessage>) returnedBuilder.getField(mapFieldDescriptor);

assertEquals(2, entries.size());
assertEquals("a", entries.get(0).getAllFields().values().toArray()[0]);
assertEquals("123", entries.get(0).getAllFields().values().toArray()[1]);
assertEquals("b", entries.get(1).getAllFields().values().toArray()[0]);
assertEquals("456", entries.get(1).getAllFields().values().toArray()[1]);
assertArrayEquals(Arrays.asList("a", "123").toArray(), entries.get(0).getAllFields().values().toArray());
assertArrayEquals(Arrays.asList("b", "456").toArray(), entries.get(1).getAllFields().values().toArray());
}

@Test
Expand All @@ -111,28 +108,29 @@ public void shouldSetMapFieldIfArrayofObjectsHavingRowsWithStringFieldsPassed()
inputRows.add(inputRow2);

DynamicMessage.Builder returnedBuilder = mapHandler.transformToProtoBuilder(builder, inputRows.toArray());
List<MapEntry> entries = (List<MapEntry>) returnedBuilder.getField(mapFieldDescriptor);
List<DynamicMessage> entries = (List<DynamicMessage>) returnedBuilder.getField(mapFieldDescriptor);

assertEquals(2, entries.size());
assertEquals("a", entries.get(0).getAllFields().values().toArray()[0]);
assertEquals("123", entries.get(0).getAllFields().values().toArray()[1]);
assertEquals("b", entries.get(1).getAllFields().values().toArray()[0]);
assertEquals("456", entries.get(1).getAllFields().values().toArray()[1]);
assertArrayEquals(Arrays.asList("a", "123").toArray(), entries.get(0).getAllFields().values().toArray());
assertArrayEquals(Arrays.asList("b", "456").toArray(), entries.get(1).getAllFields().values().toArray());
}

@Test
public void shouldThrowExceptionIfRowsPassedAreNotOfArityTwo() {
Descriptors.FieldDescriptor mapFieldDescriptor = TestBookingLogMessage.getDescriptor().findFieldByName("metadata");
MapHandler mapHandler = new MapHandler(mapFieldDescriptor);
DynamicMessage.Builder builder = DynamicMessage.newBuilder(mapFieldDescriptor.getContainingType());
public void shouldHandleComplexTypeValuesForSerialization() throws InvalidProtocolBufferException {
Row inputValue1 = Row.of("12345", Row.of(Arrays.asList("a", "b")));
Row inputValue2 = Row.of(1234123, Row.of(Arrays.asList("d", "e")));
Object input = Arrays.asList(inputValue1, inputValue2).toArray();

ArrayList<Row> inputRows = new ArrayList<>();
Descriptors.FieldDescriptor intMessageDescriptor = TestComplexMap.getDescriptor().findFieldByName("int_message");
DynamicMessage.Builder builder = DynamicMessage.newBuilder(TestComplexMap.getDescriptor());

Row inputRow = new Row(3);
inputRows.add(inputRow);
IllegalArgumentException exception = Assert.assertThrows(IllegalArgumentException.class,
() -> mapHandler.transformToProtoBuilder(builder, inputRows.toArray()));
assertEquals("Row: +I[null, null, null] of size: 3 cannot be converted to map", exception.getMessage());
byte[] data = new MapHandler(intMessageDescriptor).transformToProtoBuilder(builder, input).build().toByteArray();
TestComplexMap actualMsg = TestComplexMap.parseFrom(data);
assertArrayEquals(Arrays.asList(12345L, 1234123L).toArray(), actualMsg.getIntMessageMap().keySet().toArray());
TestComplexMap.IdMessage idMessage = (TestComplexMap.IdMessage) actualMsg.getIntMessageMap().values().toArray()[0];
assertTrue(idMessage.getIdsList().containsAll(Arrays.asList("a", "b")));
idMessage = (TestComplexMap.IdMessage) actualMsg.getIntMessageMap().values().toArray()[1];
assertTrue(idMessage.getIdsList().containsAll(Arrays.asList("d", "e")));
}

@Test
Expand All @@ -158,12 +156,8 @@ public void shouldReturnArrayOfRowHavingFieldsSetAsInputMapAndOfSizeTwoForTransf

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromPostProcessor(inputMap));

assertEquals("a", ((Row) outputValues.get(0)).getField(0));
assertEquals("123", ((Row) outputValues.get(0)).getField(1));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
assertEquals("b", ((Row) outputValues.get(1)).getField(0));
assertEquals("456", ((Row) outputValues.get(1)).getField(1));
assertEquals(2, ((Row) outputValues.get(1)).getArity());
assertEquals(Row.of("a", "123"), outputValues.get(0));
assertEquals(Row.of("b", "456"), outputValues.get(1));
}

@Test
Expand Down Expand Up @@ -210,12 +204,8 @@ public void shouldReturnArrayOfRowHavingFieldsSetAsInputMapAndOfSizeTwoForTransf

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals("a", ((Row) outputValues.get(0)).getField(0));
assertEquals("123", ((Row) outputValues.get(0)).getField(1));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
assertEquals("b", ((Row) outputValues.get(1)).getField(0));
assertEquals("456", ((Row) outputValues.get(1)).getField(1));
assertEquals(2, ((Row) outputValues.get(1)).getArity());
assertEquals(Row.of("a", "123"), outputValues.get(0));
assertEquals(Row.of("b", "456"), outputValues.get(1));
}

@Test
Expand Down Expand Up @@ -247,16 +237,11 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(1, ((Row) outputValues.get(0)).getField(0));
assertEquals("123", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("abc", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
assertEquals(2, ((Row) outputValues.get(1)).getField(0));
assertEquals("456", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(1));
assertEquals("efg", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(1)).getArity());
Row mapEntry1 = Row.of(1, Row.of("123", "", "abc"));
Row mapEntry2 = Row.of(2, Row.of("456", "", "efg"));

assertEquals(mapEntry1, outputValues.get(0));
assertEquals(mapEntry2, outputValues.get(1));
}

@Test
Expand All @@ -271,11 +256,8 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(0, ((Row) outputValues.get(0)).getField(0));
assertEquals("123", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("abc", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
Row expected = Row.of(0, Row.of("123", "", "abc"));
assertEquals(expected, outputValues.get(0));
}

@Test
Expand All @@ -290,11 +272,9 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(1, ((Row) outputValues.get(0)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
Row expected = Row.of(1, Row.of("", "", ""));

assertEquals(expected, outputValues.get(0));
}

@Test
Expand All @@ -309,11 +289,9 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(0, ((Row) outputValues.get(0)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
Row expected = Row.of(0, Row.of("", "", ""));

assertEquals(expected, outputValues.get(0));
}

@Test
Expand All @@ -328,11 +306,8 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(0, ((Row) outputValues.get(0)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
Row expected = Row.of(0, Row.of("", "", ""));
assertEquals(expected, outputValues.get(0));
}

@Test
Expand Down
5 changes: 5 additions & 0 deletions dagger-common/src/test/proto/TestMessage.proto
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ message TestEnumMessage {
}

message TestComplexMap {
message IdMessage {
repeated string ids = 1;
}
map<int32, TestMessage> complex_map = 1;
map<int64, IdMessage> int_message = 2;
map<string, IdMessage> string_message = 3;
}

message TestRepeatedPrimitiveMessage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ public void shouldGetCorrectJsonPayloadForComplexFields() throws InvalidProtocol
DynamicMessage dynamicMessage = DynamicMessage.parseFrom(complexMapMessage.getDescriptor(), complexMapMessage.toByteArray());
RowManager rowManager = getRowManagerForMessage(dynamicMessage);

String expectedJsonPayload = "{\"complex_map\":[{\"key\":1,\"value\":{\"order_number\":\"order-number-123\",\"order_url\":\"https://order-url\",\"order_details\":\"pickup\"}}]}";
String expectedJsonPayload = "{\"complex_map\":[{\"key\":1,\"value\":{\"order_number\":\"order-number-123\",\"order_url\":\"https://order-url\",\"order_details\":\"pickup\"}}],\"int_message\":[],\"string_message\":[]}";
String actualJsonPayload = (String) jsonPayloadFunction.getResult(rowManager);

assertEquals(expectedJsonPayload, actualJsonPayload);
Expand Down