Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dbschema/migrations/00007.edgeql
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
CREATE MIGRATION m1tig3qk3mnrb2xpszwyodgurkdeyza6yt67zo7kfljc2icy3e7yma
ONTO m1vxu37wczr357ppyrbhfon2msem5oczk7mjszhx2xzp2qlxpazana
{
CREATE MODULE tests IF NOT EXISTS;
CREATE TYPE tests::TestDatastructure {
CREATE REQUIRED PROPERTY a: std::str;
CREATE REQUIRED PROPERTY b: std::str;
CREATE REQUIRED PROPERTY c: std::str;
};
};
7 changes: 7 additions & 0 deletions dbschema/tests.esdl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module tests {
type TestDatastructure {
required property a -> str;
required property b -> str;
required property c -> str;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import java.util.concurrent.CompletionStage;

public final class GlobalsAndConfig implements Example {
private static final Logger logger = LoggerFactory.getLogger(AbstractTypes.class);
private static final Logger logger = LoggerFactory.getLogger(GlobalsAndConfig.class);

@Override
public CompletionStage<Void> run(EdgeDBClient client) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.edgedb.driver.util.BinaryProtocolUtils.BYTE_SIZE;
import static com.edgedb.driver.util.BinaryProtocolUtils.INT_SIZE;

public class PacketSerializer {
private static final Logger logger = LoggerFactory.getLogger(PacketSerializer.class);
private static final @NotNull Map<ServerMessageType, Function<PacketReader, Receivable>> deserializerMap;
Expand Down Expand Up @@ -73,41 +70,58 @@ public static <T extends Enum<T> & BinaryEnum<U>, U extends Number> T getEnumVal

@Override
protected void decode(@NotNull ChannelHandlerContext ctx, @NotNull ByteBuf msg, @NotNull List<Object> out) throws Exception {
var fromContract = false;

if(contracts.containsKey(ctx.channel())){
var contract = contracts.get(ctx.channel());

logger.debug("Attempting to complete contract {}", contract);

if (contract.tryComplete(msg)) {
logger.debug("Contract completed of type {} with size {}", contract.messageType, contract.length);

out.add(contract.getPacket());
contracts.remove(ctx.channel());
fromContract = true;
msg = contract.data;
} else {
logger.debug("Contract pending [{}]: {}/{}", contract.messageType, contract.getSize(), contract.length);
return;
}
}

while (msg.readableBytes() > 5) {
var type = getEnumValue(ServerMessageType.class, msg.readByte());
var length = msg.readUnsignedInt() - 4; // remove length of self.

// can we read this packet?
if (msg.readableBytes() >= length) {
var packet = PacketSerializer.deserialize(type, length, msg.readSlice((int) length));

if(packet == null) {
logger.error("Got null result for packet type {}", type);
throw new EdgeDBException("Failed to read message type: malformed data");
}

logger.debug("S->C: T:{}", type);
out.add(packet);
continue;
}

if (contracts.containsKey(ctx.channel())) {
var contract = contracts.get(ctx.channel());

if (contract.tryComplete(msg)) {
out.add(contract.getPacket());
}

return;
} else {
contracts.put(ctx.channel(), new PacketContract(msg, type, length));
}
// if we cannot read the full packet, create a contract for it.
msg.retain();
contracts.put(ctx.channel(), new PacketContract(msg, type, length));
return;
}

if (msg.readableBytes() > 0) {
if (contracts.containsKey(ctx.channel())) {
var contract = contracts.get(ctx.channel());
if(msg.readableBytes() > 0){
msg.retain();
contracts.put(ctx.channel(), new PacketContract(msg, null, null));
return;
}

if (contract.tryComplete(msg)) {
out.add(contract.getPacket());
}
} else {
contracts.put(ctx.channel(), new PacketContract(msg, null, null));
}
if(fromContract){
msg.release();
}
}

Expand All @@ -118,6 +132,8 @@ class PacketContract {
private @Nullable ServerMessageType messageType;
private @Nullable Long length;

private final List<ByteBuf> components;

public PacketContract(
ByteBuf data,
@Nullable ServerMessageType messageType,
Expand All @@ -126,38 +142,47 @@ public PacketContract(
this.data = data;
this.length = length;
this.messageType = messageType;

this.components = new ArrayList<>() {{
add(data);
}};
}

public long getSize() {
long size = 0;

for (var component : components) {
size += component.readableBytes();
}

return size;
}

public boolean tryComplete(@NotNull ByteBuf other) {
var orig = data.slice();
data = Unpooled.wrappedBuffer(orig, other);

if (messageType == null) {
messageType = pick(other, b -> getEnumValue(ServerMessageType.class, b.readByte()), BYTE_SIZE);
messageType = getEnumValue(ServerMessageType.class, data.readByte());
}

if (length == null) {
length = pick(other, b -> b.readUnsignedInt() - 4, INT_SIZE);
length = data.readUnsignedInt() - 4;
}

data = Unpooled.wrappedBuffer(data, other);
other.retain();
components.add(other);

if (data.readableBytes() >= length) {
// read
packet = PacketSerializer.deserialize(messageType, length, data);
packet = PacketSerializer.deserialize(messageType, length, data, false);

return true;
}

return false;
}

private <T> T pick(@NotNull ByteBuf other, @NotNull Function<ByteBuf, T> map, long sz) {
if (data.readableBytes() > sz) {
return map.apply(data);
} else if (other.readableBytes() < sz) {
throw new IndexOutOfBoundsException();
}

return map.apply(other);
}

public @NotNull Receivable getPacket() throws OperationNotSupportedException {
if (packet == null) {
throw new OperationNotSupportedException("Packet contract was incomplete");
Expand Down Expand Up @@ -192,11 +217,20 @@ protected void encode(@NotNull ChannelHandlerContext ctx, @NotNull Sendable msg,
};
}

public static @Nullable Receivable deserialize(ServerMessageType messageType, long length, @NotNull ByteBuf buffer) {
public static @Nullable Receivable deserialize(
ServerMessageType messageType, long length, @NotNull ByteBuf buffer
) {
var reader = new PacketReader(buffer);
return deserializeSingle(messageType, length, reader, true);
}

public static @Nullable Receivable deserialize(
ServerMessageType messageType, long length, @NotNull ByteBuf buffer, boolean verifyEmpty
) {
var reader = new PacketReader(buffer);
return deserializeSingle(messageType, length, reader, verifyEmpty);
}

public static @Nullable Receivable deserializeSingle(PacketReader reader) {
var messageType = reader.readEnum(ServerMessageType.class, Byte.TYPE);
var length = reader.readUInt32().longValue();
Expand Down
103 changes: 103 additions & 0 deletions src/driver/src/test/java/ProtocolTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import com.edgedb.driver.EdgeDBClient;
import com.edgedb.driver.annotations.EdgeDBType;
import com.edgedb.driver.exceptions.EdgeDBException;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.ExecutionException;

import static org.assertj.core.api.Assertions.assertThat;

public class ProtocolTests {
private static final Logger logger = LoggerFactory.getLogger(ProtocolTests.class);

@EdgeDBType
public static class TestDatastructure {
public UUID id;

public String a;

public String b;

public String c;
}

/**
* The goal is to test the contract logic in {@linkplain com.edgedb.driver.binary.PacketSerializer}, specifically
* the decoder returned from the <b>createDecoder</b> function. To achieve this, we can query something that
* returns either multiple data packets amounting up to >16k bytes, or a single data packet that is >16k bytes.
*/
@Test
public void testPacketContract() throws EdgeDBException, IOException, ExecutionException, InterruptedException {
var client = new EdgeDBClient().withModule("tests");

// insert 1k items
logger.info("Removing old data structures...");
client.execute("DELETE TestDatastructure")
.toCompletableFuture().get();

var results = new HashMap<UUID, String[]>();

logger.info("Inserting 1000 items...");

for(int i = 0; i != 1000; i++){
var data = new String[] {
generateRandomString(),
generateRandomString(),
generateRandomString()
};

var result = client.queryRequiredSingle(TestDatastructure.class, "INSERT TestDatastructure { a := <str>$a, b := <str>$b, c := <str>$c }", new HashMap<>(){{
put("a", data[0]);
put("b", data[1]);
put("c", data[2]);
}}).toCompletableFuture().get();

results.put(result.id, data);
}

logger.info("Querying all items...");

// assert the data can be read via binary and json
var structures = client.query(TestDatastructure.class, "SELECT TestDatastructure { id, a, b, c }")
.toCompletableFuture().get();

var json = client.queryJson("SELECT TestDatastructure { id, a, b, c }")
.toCompletableFuture().get();

var structuresFromJson = List.of(new JsonMapper().readValue(json.getValue(), TestDatastructure[].class));

assertStructuresMatch(structures, results);
assertStructuresMatch(structuresFromJson, results);
}

private void assertStructuresMatch(List<TestDatastructure> source, Map<UUID, String[]> truth) {
for(var structure : source) {
assert structure != null;

var expected = truth.get(structure.id);

assertThat(structure.a).isEqualTo(expected[0]);
assertThat(structure.b).isEqualTo(expected[1]);
assertThat(structure.c).isEqualTo(expected[2]);

logger.info("{} passed [a: {}, b: {}, c: {}]", structure.id, structure.a, structure.b, structure.c);
}
}

private static String generateRandomString() {
final var chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";

Random rand =new Random();
StringBuilder res=new StringBuilder();
for (int i = 0; i < 17; i++) {
int randIndex=rand.nextInt(chars.length());
res.append(chars.charAt(randIndex));
}
return res.toString();
}
}