diff --git a/examples/java-examples/src/main/java/com/edgedb/examples/Main.java b/examples/java-examples/src/main/java/com/edgedb/examples/Main.java index d0e9cb38..0204e280 100644 --- a/examples/java-examples/src/main/java/com/edgedb/examples/Main.java +++ b/examples/java-examples/src/main/java/com/edgedb/examples/Main.java @@ -1,7 +1,6 @@ package com.edgedb.examples; -import com.edgedb.driver.EdgeDBClient; -import com.edgedb.driver.EdgeDBClientConfig; +import com.edgedb.driver.*; import com.edgedb.driver.exceptions.EdgeDBException; import com.edgedb.driver.namingstrategies.NamingStrategy; import org.slf4j.Logger; @@ -24,6 +23,8 @@ public static void main(String[] args) throws IOException, EdgeDBException { runJavaExamples(client); logger.info("Examples complete"); + + System.exit(0); } private static void runJavaExamples(EdgeDBClient client) { diff --git a/examples/java-examples/src/main/java/com/edgedb/examples/Transactions.java b/examples/java-examples/src/main/java/com/edgedb/examples/Transactions.java index 56d6bd86..3534f1d6 100644 --- a/examples/java-examples/src/main/java/com/edgedb/examples/Transactions.java +++ b/examples/java-examples/src/main/java/com/edgedb/examples/Transactions.java @@ -4,6 +4,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; public final class Transactions implements Example { @@ -11,6 +12,12 @@ public final class Transactions implements Example { @Override public CompletionStage run(EdgeDBClient client) { + // verify we can run transactions + if(!client.supportsTransactions()) { + logger.info("Skipping transactions, client type {} doesn't support it", client.getClientType()); + return CompletableFuture.completedFuture(null); + } + return client.transaction(tx -> { logger.info("In transaction"); return tx.queryRequiredSingle(String.class, "select 'Result from Transaction'"); diff --git a/examples/kotlin-examples/src/main/kotlin/com/edgedb/examples/Main.kt b/examples/kotlin-examples/src/main/kotlin/com/edgedb/examples/Main.kt index c47ae201..b5928bb8 100644 --- a/examples/kotlin-examples/src/main/kotlin/com/edgedb/examples/Main.kt +++ b/examples/kotlin-examples/src/main/kotlin/com/edgedb/examples/Main.kt @@ -2,9 +2,11 @@ package com.edgedb.examples import com.edgedb.driver.EdgeDBClient import com.edgedb.driver.EdgeDBClientConfig +import com.edgedb.driver.EdgeDBConnection import com.edgedb.driver.namingstrategies.NamingStrategy import kotlinx.coroutines.runBlocking import org.slf4j.LoggerFactory +import kotlin.system.exitProcess object Main { private val logger = LoggerFactory.getLogger(Main::class.java) @@ -37,5 +39,7 @@ object Main { } } } + + exitProcess(0) } } \ No newline at end of file diff --git a/examples/kotlin-examples/src/main/kotlin/com/edgedb/examples/Transactions.kt b/examples/kotlin-examples/src/main/kotlin/com/edgedb/examples/Transactions.kt index b51dcfe1..93269e53 100644 --- a/examples/kotlin-examples/src/main/kotlin/com/edgedb/examples/Transactions.kt +++ b/examples/kotlin-examples/src/main/kotlin/com/edgedb/examples/Transactions.kt @@ -10,6 +10,12 @@ class Transactions : Example { } override suspend fun runAsync(client: EdgeDBClient) { + // verify we can run transactions + if (!client.supportsTransactions()) { + logger.info("Skipping transactions, client type {} doesn't support it", client.clientType) + return + } + val transactionResult = client.transaction { tx -> tx.queryRequiredSingle(String::class.java, "SELECT 'Hello from transaction!'") }.await() diff --git a/examples/scala-examples/build.sbt b/examples/scala-examples/build.sbt index 732e7166..c5e02420 100644 --- a/examples/scala-examples/build.sbt +++ b/examples/scala-examples/build.sbt @@ -5,7 +5,7 @@ ThisBuild / scalaVersion := "3.1.3" //resolvers += Resolver.file("my-test-repo", file("test")) libraryDependencies ++= Seq( - "com.edgedb" % "driver" % "0.0.1" from "file:///" + System.getProperty("user.dir") + "/lib/com.edgedb.driver-0.0.1-SNAPSHOT.jar", + "com.edgedb" % "driver" % "0.1.1" from "file:///" + System.getProperty("user.dir") + "/lib/com.edgedb.driver-0.1.1-SNAPSHOT.jar", "ch.qos.logback" % "logback-classic" % "1.4.7", "ch.qos.logback" % "logback-core" % "1.4.7", "com.fasterxml.jackson.core" % "jackson-databind" % "2.15.1", diff --git a/examples/scala-examples/src/main/scala/Main.scala b/examples/scala-examples/src/main/scala/Main.scala index b2f546a7..4a4e8b5d 100644 --- a/examples/scala-examples/src/main/scala/Main.scala +++ b/examples/scala-examples/src/main/scala/Main.scala @@ -32,6 +32,8 @@ def main(): Unit = { Await.ready(runExample(logger, client, example), Duration.Inf) logger.info("Examples complete!") + + System.exit(0) } private def runExample(logger: Logger, client: EdgeDBClient, example: Example)(implicit context: ExecutionContext): Future[Unit] = { diff --git a/examples/scala-examples/src/main/scala/Transactions.scala b/examples/scala-examples/src/main/scala/Transactions.scala index 6eaa76ec..d6d9f20f 100644 --- a/examples/scala-examples/src/main/scala/Transactions.scala +++ b/examples/scala-examples/src/main/scala/Transactions.scala @@ -8,6 +8,12 @@ import scala.concurrent.{ExecutionContext, Future} class Transactions extends Example { private val logger = LoggerFactory.getLogger(classOf[Transactions]) override def run(client: EdgeDBClient)(implicit context: ExecutionContext): Future[Unit] = { + // verify we can run transactions + if (!client.supportsTransactions()) { + logger.info("Skipping transactions, client type {} doesn't support it", client.getClientType) + return Future.unit + } + client.transaction((tx: Transaction) => { logger.info("In transaction") tx.queryRequiredSingle(classOf[String], "select 'Result from Transaction'") diff --git a/src/driver/build.gradle b/src/driver/build.gradle index cfd2c7e7..754e4a50 100644 --- a/src/driver/build.gradle +++ b/src/driver/build.gradle @@ -20,7 +20,7 @@ dependencies { testRuntimeOnly "org.junit.jupiter:junit-jupiter-engine:$junit_version" testImplementation "org.assertj:assertj-core:$assertj_version" testImplementation 'org.burningwave:core:12.62.6' - testImplementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:$jackson_databind_version' + testImplementation "com.fasterxml.jackson.datatype:jackson-datatype-jsr310:$jackson_version" testImplementation "ch.qos.logback:logback-classic:$logback_version" testImplementation "ch.qos.logback:logback-core:$logback_version" } @@ -33,13 +33,24 @@ jar { } } -tasks.register('copyJarToBin') { +def deleteOldJar = tasks.register('deleteOldJarInBin') { + var path = Paths.get(project.rootDir.toString(), 'examples', 'scala-examples', 'lib') + var paths = path.toFile().listFiles((FileFilter) { File f -> f.name.startsWith('com.edgedb.driver') }) + delete files(paths) +} + +def copyJar = tasks.register('copyJarToBin') { copy { from jar into Paths.get(project.rootDir.toString(), 'examples', 'scala-examples', 'lib') } } +copyJar.configure { + dependsOn(deleteOldJar, compileJava) +} + + publishing { publications { mavenJava(MavenPublication) { diff --git a/src/driver/src/main/java/com/edgedb/driver/EdgeDBClient.java b/src/driver/src/main/java/com/edgedb/driver/EdgeDBClient.java index 765763e7..0f63694c 100644 --- a/src/driver/src/main/java/com/edgedb/driver/EdgeDBClient.java +++ b/src/driver/src/main/java/com/edgedb/driver/EdgeDBClient.java @@ -1,10 +1,7 @@ package com.edgedb.driver; import com.edgedb.driver.abstractions.ClientQueryDelegate; -import com.edgedb.driver.clients.BaseEdgeDBClient; -import com.edgedb.driver.clients.EdgeDBTCPClient; -import com.edgedb.driver.clients.StatefulClient; -import com.edgedb.driver.clients.TransactableClient; +import com.edgedb.driver.clients.*; import com.edgedb.driver.datatypes.Json; import com.edgedb.driver.exceptions.ConfigurationException; import com.edgedb.driver.exceptions.EdgeDBException; @@ -80,14 +77,6 @@ public EdgeDBClient(EdgeDBConnection connection, @NotNull EdgeDBClientConfig con this.clientAvailability = config.getClientAvailability(); } - private @NotNull ClientFactory createClientFactory() throws ConfigurationException { - if(config.getClientType() == ClientType.TCP) { - return EdgeDBTCPClient::new; - } - - throw new ConfigurationException(String.format("No such implementation for client type %s found", this.config.getClientType())); - } - /** * Constructs a new {@linkplain EdgeDBClient}. * @param connection The connection parameters used to connect this client to EdgeDB. @@ -126,6 +115,25 @@ private EdgeDBClient(@NotNull EdgeDBClient other, Session session) { this.clientAvailability = other.clientAvailability; } + private @NotNull ClientFactory createClientFactory() throws ConfigurationException { + if(config.getClientType() == ClientType.TCP) { + return EdgeDBTCPClient::new; + } else if (config.getClientType() == ClientType.HTTP) { + return EdgeDBHttpClient::new; + } + + throw new ConfigurationException(String.format("No such implementation for client type %s found", this.config.getClientType())); + } + + /** + * Gets the underlying client type for this client pool. + * @return The underlying client type, usually based on transport. + * @see ClientType + */ + public ClientType getClientType() { + return config.getClientType(); + } + /** * Gets whether this client supports transactions. * @return {@code true} if the client supports transactions; otherwise {@code false}. @@ -232,14 +240,23 @@ public CompletionStage transaction(@NotNull Function { - public final BaseEdgeDBClient client; - public final U result; + private final BaseEdgeDBClient client; + private final @Nullable U result; - private ExecutePair(BaseEdgeDBClient client, U result) { + private ExecutePair(BaseEdgeDBClient client, @Nullable U result) { this.client = client; this.result = result; } + + public @Nullable U getResult() { + return result; + } + + public BaseEdgeDBClient getClient() { + return client; + } } private CompletionStage executePooledQuery( @@ -254,21 +271,18 @@ private CompletionStage executePooledQuery( query, args, capabilities - ).handle((r, x) -> new ExecutePair<>(client, r)) + ).thenApply(r -> new ExecutePair<>(client, r)) ) - .handle((pair, exc) -> { - try { - pair.client.close(); - } catch (Exception e) { - throw new CompletionException(e); + .whenComplete((entry, exc) -> { + if(entry != null) { + try { + entry.getClient().close(); + } catch (Exception e) { + throw new CompletionException(e); + } } - - if(exc != null) { - throw new CompletionException(exc); - } - - return pair.result; - }); + }) + .thenApply(ExecutePair::getResult); } @Override @@ -374,7 +388,12 @@ private CompletionStage createClient() { return this.poolHolder.acquireContract() .thenApply(contract -> { logger.trace("Contract acquired, remaining handles: {}", this.poolHolder.remaining()); - var client = clientFactory.create(this.connection, this.config, contract); + BaseEdgeDBClient client; + try { + client = clientFactory.create(this.connection, this.config, contract); + } catch (EdgeDBException e) { + throw new CompletionException(e); + } contract.register(client, this::acceptClient); client.onReady(this::onClientReady); logger.debug("client instance created: {}", client); @@ -385,6 +404,7 @@ private CompletionStage createClient() { @FunctionalInterface private interface ClientFactory { - BaseEdgeDBClient create(EdgeDBConnection connection, EdgeDBClientConfig config, AutoCloseable poolHandle); + BaseEdgeDBClient create(EdgeDBConnection connection, EdgeDBClientConfig config, AutoCloseable poolHandle) + throws EdgeDBException; } } diff --git a/src/driver/src/main/java/com/edgedb/driver/EdgeDBConnection.java b/src/driver/src/main/java/com/edgedb/driver/EdgeDBConnection.java index 6015505c..6d188d5e 100644 --- a/src/driver/src/main/java/com/edgedb/driver/EdgeDBConnection.java +++ b/src/driver/src/main/java/com/edgedb/driver/EdgeDBConnection.java @@ -146,7 +146,7 @@ protected void setPassword(String value) { * @return The hostname part of the connection. */ public @NotNull String getHostname() { - return hostname == null ? "127.0.0.1" : hostname; + return hostname == null ? "localhost" : hostname; } /** diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/PacketReader.java b/src/driver/src/main/java/com/edgedb/driver/binary/PacketReader.java index 11bc120f..0e4be577 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/PacketReader.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/PacketReader.java @@ -61,6 +61,21 @@ public void skip(int count) { this.buffer.skipBytes(count); } + public void skip(long count) { + if(count >> 32 == 0) { + // can convert to int + skip((int)count); + return; + } + + var temp = count; + do { + this.buffer.skipBytes((int) temp); + temp -= Integer.MAX_VALUE; + } + while (temp >= Integer.MAX_VALUE); + } + public boolean isEmpty() { return this.buffer.readableBytes() == 0; } diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/PacketSerializer.java b/src/driver/src/main/java/com/edgedb/driver/binary/PacketSerializer.java index 43aeb4a7..36255bb5 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/PacketSerializer.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/PacketSerializer.java @@ -3,6 +3,8 @@ import com.edgedb.driver.binary.packets.ServerMessageType; import com.edgedb.driver.binary.packets.receivable.*; import com.edgedb.driver.binary.packets.sendables.Sendable; +import com.edgedb.driver.exceptions.ConnectionFailedException; +import com.edgedb.driver.exceptions.EdgeDBException; import com.edgedb.driver.util.HexUtils; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -16,7 +18,12 @@ import org.slf4j.LoggerFactory; import javax.naming.OperationNotSupportedException; +import java.net.http.HttpResponse; +import java.nio.ByteBuffer; import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Flow; import java.util.function.Function; import java.util.stream.Collectors; @@ -67,7 +74,7 @@ public static & BinaryEnum, U extends Number> T getEnumVal @Override protected void decode(@NotNull ChannelHandlerContext ctx, @NotNull ByteBuf msg, @NotNull List out) throws Exception { while (msg.readableBytes() > 5) { - var type = ServerMessageType.valueOf(msg.readByte()); + var type = getEnumValue(ServerMessageType.class, msg.readByte()); var length = msg.readUnsignedInt() - 4; // remove length of self. // can we read this packet? @@ -123,7 +130,7 @@ public PacketContract( public boolean tryComplete(@NotNull ByteBuf other) { if (messageType == null) { - messageType = pick(other, b -> ServerMessageType.valueOf(b.readByte()), BYTE_SIZE); + messageType = pick(other, b -> getEnumValue(ServerMessageType.class, b.readByte()), BYTE_SIZE); } if (length == null) { @@ -187,15 +194,28 @@ protected void encode(@NotNull ChannelHandlerContext ctx, @NotNull Sendable msg, public static @Nullable Receivable deserialize(ServerMessageType messageType, long length, @NotNull ByteBuf buffer) { var reader = new PacketReader(buffer); + return deserializeSingle(messageType, length, reader, true); + } + + public static @Nullable Receivable deserializeSingle(PacketReader reader) { + var messageType = reader.readEnum(ServerMessageType.class, Byte.TYPE); + var length = reader.readUInt32().longValue(); + + return deserializeSingle(messageType, length, reader, false); + } - if(!deserializerMap.containsKey(messageType)) { - logger.error("Unknown packet type {}", messageType); - reader.skip((int)length); + public static @Nullable Receivable deserializeSingle( + ServerMessageType type, long length, @NotNull PacketReader reader, + boolean verifyEmpty + ) { + if(!deserializerMap.containsKey(type)) { + logger.error("Unknown packet type {}", type); + reader.skip(length); return null; } try { - return deserializerMap.get(messageType).apply(reader); + return deserializerMap.get(type).apply(reader); } catch (Exception x) { logger.error("Failed to deserialize packet", x); @@ -203,8 +223,96 @@ protected void encode(@NotNull ChannelHandlerContext ctx, @NotNull Sendable msg, } finally { // ensure we read the entire packet - if(!reader.isEmpty()) { - logger.warn("Hanging data left inside packet reader of type {} with length {}", messageType, length); + if(verifyEmpty && !reader.isEmpty()) { + logger.warn("Hanging data left inside packet reader of type {} with length {}", type, length); + } + } + } + + public static HttpResponse.BodyHandler> PACKET_BODY_HANDLER = new PacketBodyHandler(); + + private static class PacketBodyHandler implements HttpResponse.BodyHandler> { + @Override + public HttpResponse.BodySubscriber> apply(HttpResponse.ResponseInfo responseInfo) { + // ensure success + var isSuccess = responseInfo.statusCode() / 100 == 2; + + return isSuccess + ? new PacketBodySubscriber() + : new PacketBodySubscriber(responseInfo.statusCode()); + } + + private static class PacketBodySubscriber implements HttpResponse.BodySubscriber> { + private final @Nullable List<@NotNull ByteBuf> buffers; + private final CompletableFuture> promise; + + public PacketBodySubscriber(int errorCode) { + buffers = null; + promise = CompletableFuture.failedFuture( + new ConnectionFailedException("Got HTTP error code " + errorCode) + ); + } + + public PacketBodySubscriber() { + promise = new CompletableFuture<>(); + buffers = new ArrayList<>(); + } + + @Override + public CompletionStage> getBody() { + return promise; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + if(buffers == null) { + return; // failed + } + + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(List items) { + if(buffers == null) { + return; // failed + } + + for(var item : items) { + buffers.add(Unpooled.wrappedBuffer(item)); + } + } + + @Override + public void onError(Throwable throwable) { + promise.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + if(buffers == null) { + return; // failed + } + + var completeBuffer = Unpooled.wrappedBuffer(buffers.toArray(new ByteBuf[0])); + + var reader = new PacketReader(completeBuffer); + var data = new ArrayList(); + + while(completeBuffer.readableBytes() > 0) { + var packet = deserializeSingle(reader); + + if(packet == null && completeBuffer.readableBytes() > 0) { + promise.completeExceptionally( + new EdgeDBException("Failed to deserialize packet, buffer had " + completeBuffer.readableBytes() + " bytes remaining") + ); + return; + } + + data.add(packet); + } + + promise.complete(data); } } } diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/duplexers/HttpDuplexer.java b/src/driver/src/main/java/com/edgedb/driver/binary/duplexers/HttpDuplexer.java new file mode 100644 index 00000000..b776d149 --- /dev/null +++ b/src/driver/src/main/java/com/edgedb/driver/binary/duplexers/HttpDuplexer.java @@ -0,0 +1,230 @@ +package com.edgedb.driver.binary.duplexers; + +import com.edgedb.driver.binary.PacketSerializer; +import com.edgedb.driver.binary.packets.receivable.Receivable; +import com.edgedb.driver.binary.packets.sendables.Sendable; +import com.edgedb.driver.clients.EdgeDBHttpClient; +import com.edgedb.driver.exceptions.ConnectionFailedException; +import com.edgedb.driver.exceptions.EdgeDBException; +import io.netty.buffer.ByteBufInputStream; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.naming.OperationNotSupportedException; +import java.net.http.HttpRequest; +import java.util.ArrayDeque; +import java.util.Objects; +import java.util.Queue; +import java.util.concurrent.*; + +import static com.edgedb.driver.util.ComposableUtil.composeWith; + +public class HttpDuplexer extends Duplexer { + private static final Logger logger = LoggerFactory.getLogger(HttpDuplexer.class); + private static final String HTTP_BINARY_CONTENT_TYPE = "application/x.edgedb.v_1_0.binary"; + + private final EdgeDBHttpClient client; + private final Semaphore lock; + private final Executor lockExecutor; + private final Queue<@NotNull Receivable> packetQueue; + private final Queue> readPromises; + + public HttpDuplexer(EdgeDBHttpClient client) { + this.client = client; + this.lock = new Semaphore(1); + this.lockExecutor = Executors.newSingleThreadExecutor(); + this.packetQueue = new ArrayDeque<>(); + this.readPromises = new ArrayDeque<>(); + } + + @Override + public void reset() { + packetQueue.clear(); + readPromises.clear(); + + client.clearToken(); + } + + @Override + public boolean isConnected() { + return client.getToken() != null; + } + + @Override + public CompletionStage disconnect() { + return CompletableFuture.runAsync(client::clearToken); + } + + @Override + public CompletionStage readNext() { + return acquireLock("READ") + .thenCompose((v) -> readNext0()) + .whenCompleteAsync((v,e) -> { + logger.debug("[READ]: Releasing lock"); + lock.release(); + }, lockExecutor); + } + + private CompletionStage readNext0() { + logger.debug("Preforming read, is authed?: {}", isConnected()); + if(!isConnected()) { + return CompletableFuture.failedFuture( + new EdgeDBException("Cannot preform read without authorization") + ); + } + + logger.debug("Packet queue empty?: {}", packetQueue.isEmpty()); + if(packetQueue.isEmpty()) { + var promise = new CompletableFuture(); + logger.debug("Enqueueing read promise {}...", promise.hashCode()); + readPromises.offer(promise); + promise.whenComplete((v,e) -> { + logger.debug("Read promise {} complete", promise.hashCode()); + }); + return promise; + } else { + logger.debug("Completing from polled packet"); + return CompletableFuture.completedFuture(packetQueue.poll()); + } + } + + @Override + public CompletionStage send(Sendable packet, @Nullable Sendable... packets) { + return acquireLock("WRITE") + .thenCompose((v) -> send0(packet, packets)) + .whenCompleteAsync((v,e) -> { + logger.debug("[WRITE]: Releasing lock"); + lock.release(); + }, lockExecutor); + } + + private CompletionStage send0(Sendable packet, @Nullable Sendable... packets) { + return verifyAuthenticated() + .thenApply((v) -> { + try { + logger.debug("Creating buffer stream"); + return new ByteBufInputStream(PacketSerializer.serialize(packet, packets), true); + } catch (OperationNotSupportedException e) { + logger.debug("Failed to create buffer stream", e); + throw new CompletionException(e); + } + }) + .thenApply(stream -> + HttpRequest.newBuilder() + .uri(client.getExecUri()) + .header("Authorization", "Bearer " + client.getToken()) + .header("Content-Type", HTTP_BINARY_CONTENT_TYPE) + .header("X-EdgeDB-User", client.getConnectionArguments().getUsername()) + .POST(HttpRequest.BodyPublishers.ofInputStream(() -> stream)) + .build() + ) + .thenCompose((request) -> { + logger.debug("Sending execution request..."); + return client.httpClient.sendAsync(request, PacketSerializer.PACKET_BODY_HANDLER); + }) + .thenCompose(EdgeDBHttpClient::ensureSuccess) + .thenAccept(response -> { + logger.debug("Enqueueing {} packets", response.body().size()); + for(var receivable : response.body()) { + packetQueue.offer(receivable); + } + }) + .thenCompose((v) -> processReadPromises()); + } + + private CompletionStage processReadPromises() { + logger.debug( + "Processing read promises. has promises?: {}, has data?: {}", + !readPromises.isEmpty(), !packetQueue.isEmpty() + ); + + if(!readPromises.isEmpty() && !packetQueue.isEmpty()) { + var promise = readPromises.poll(); + var receivable = Objects.requireNonNull(packetQueue.poll()); // closed by the 'composeWith' func + + logger.debug("Executing promise {} with {}", promise.hashCode(), receivable.getMessageType()); + + return composeWith(receivable, (v) -> promise.completeAsync(() -> v)) + .thenCompose((v) -> processReadPromises()); + } + + logger.debug("Completed read promise steps"); + return CompletableFuture.completedFuture(null); + } + + private CompletionStage verifyAuthenticated() { + return CompletableFuture + .runAsync(() -> { + logger.debug("Verifying authentication state... is authed?: {}", isConnected()); + if(!isConnected()) { + throw new CompletionException( + new ConnectionFailedException("Cannot send to an unauthorized connection") + ); + } + }); + } + + private CompletionStage acquireLock(String operation) { + return CompletableFuture + .runAsync(() -> { + try { + logger.debug("[{}]: Acquiring lock...", operation); + if(!lock.tryAcquire( + client.getConfig().getMessageTimeoutValue(), + client.getConfig().getMessageTimeoutUnit()) + ) { + logger.debug("[{}]: Lock timed out", operation); + throw new CompletionException( + new TimeoutException("A message read process passed the configured message timeout") + ); + } + + logger.debug("[{}]: Lock acquired", operation); + } catch (InterruptedException v) { + throw new CompletionException(v); + } + }, lockExecutor); + } + + + @Override + public CompletionStage duplex(DuplexCallback func, @NotNull Sendable packet, @Nullable Sendable... packets) { + return acquireLock("DUPLEX") + .thenCompose((v) -> duplex0(func, packet, packets)) + .whenCompleteAsync((v,e) -> { + logger.debug("[DUPLEX]: Releasing lock"); + lock.release(); + }, lockExecutor); + } + + private CompletionStage duplex0(DuplexCallback func, @NotNull Sendable packet, @Nullable Sendable... packets) { + var duplexPromise = new CompletableFuture(); + return send0(packet, packets) + .thenCompose((v) -> processDuplexStep(func, duplexPromise)); + } + + private CompletionStage processDuplexStep(DuplexCallback func, CompletableFuture promise) { + return readNext0() + .thenApply((packet) -> new DuplexResult(packet, promise)) + .thenCompose((state) -> { + try { + return func.process(state); + } catch (EdgeDBException | OperationNotSupportedException e) { + return CompletableFuture.failedFuture(e); + } + }) + .thenCompose((v) -> { + if(promise.isDone()) { + if(promise.isCompletedExceptionally() || promise.isCancelled()) { + return promise; + } + + return CompletableFuture.completedFuture(null); + } + + return processDuplexStep(func, promise); + }); + } +} diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/packets/ServerMessageType.java b/src/driver/src/main/java/com/edgedb/driver/binary/packets/ServerMessageType.java index 26a4fb24..2b14284f 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/packets/ServerMessageType.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/packets/ServerMessageType.java @@ -1,9 +1,8 @@ package com.edgedb.driver.binary.packets; -import java.util.HashMap; -import java.util.Map; +import com.edgedb.driver.binary.BinaryEnum; -public enum ServerMessageType { +public enum ServerMessageType implements BinaryEnum { AUTHENTICATION (0x52), COMMAND_COMPLETE (0x43), COMMAND_DATA_DESCRIPTION (0x54), @@ -20,23 +19,13 @@ public enum ServerMessageType { SERVER_KEY_DATA (0x4b); private final byte code; - private final static Map map = new HashMap<>(); ServerMessageType(int code) { this.code = (byte)code; } - static { - for (ServerMessageType v : ServerMessageType.values()) { - map.put(v.code, v); - } - } - - public static ServerMessageType valueOf(Byte raw) { - return map.get(raw); - } - - public byte getCode() { + @Override + public Byte getValue() { return code; } } diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/packets/receivable/Receivable.java b/src/driver/src/main/java/com/edgedb/driver/binary/packets/receivable/Receivable.java index 69ece2cf..36668057 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/packets/receivable/Receivable.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/packets/receivable/Receivable.java @@ -2,8 +2,12 @@ import com.edgedb.driver.binary.packets.ServerMessageType; import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public interface Receivable extends AutoCloseable { + Logger logger = LoggerFactory.getLogger(Receivable.class); + ServerMessageType getMessageType(); @SuppressWarnings("unchecked") @@ -18,5 +22,7 @@ default void release(T @NotNull [] closeable) throws E } @Override - default void close() throws Exception {} + default void close() throws Exception { + logger.debug("Closed {}:{}", this.hashCode(), getMessageType()); + } } diff --git a/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBBinaryClient.java b/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBBinaryClient.java index e087766c..70f5c0ba 100644 --- a/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBBinaryClient.java +++ b/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBBinaryClient.java @@ -55,8 +55,6 @@ public abstract class EdgeDBBinaryClient extends BaseEdgeDBClient { private UUID stateDescriptorId; private short connectionAttempts; - - protected Duplexer duplexer; private boolean isIdle; private final @NotNull Semaphore connectionSemaphore; private final @NotNull Semaphore querySemaphore; @@ -71,6 +69,8 @@ public EdgeDBBinaryClient(EdgeDBConnection connection, EdgeDBClientConfig config this.stateDescriptorId = CodecBuilder.INVALID_CODEC_ID; } + protected abstract Duplexer getDuplexer(); + @Override public @NotNull Optional getSuggestedPoolConcurrency() { return Optional.ofNullable(this.suggestedPoolConcurrency); @@ -88,10 +88,6 @@ protected void setIsIdle(boolean value) { isIdle = value; } - protected void setDuplexer(Duplexer duplexer) { - this.duplexer = duplexer; - } - private CompletionStage parse(@NotNull ExecutionArguments args) { return runWithAttempts( args, @@ -110,7 +106,7 @@ private CompletionStage parse0(@NotNull ExecutionArguments args) { logger.debug("Starting to parse... attempt {}/{}", args.parseAttempts + 1, MAX_PARSE_ATTEMPTS); - return duplexer.duplexAndSync(args.toParsePacket(), (result) -> { + return getDuplexer().duplexAndSync(args.toParsePacket(), (result) -> { logger.trace("parse duplex result: {}", result.packet.getMessageType()); switch (result.packet.getMessageType()) { case ERROR_RESPONSE: @@ -173,7 +169,7 @@ private CompletionStage parse0(@NotNull ExecutionArguments args) { case READY_FOR_COMMAND: var ready = result.packet.as(ReadyForCommand.class); setTransactionState(ready.transactionState); - args.completedParse = true; + args.completedParse = args.codecs != null; result.finishDuplexing(); } @@ -206,7 +202,7 @@ private CompletionStage execute0(@NotNull ExecutionArguments args) { } try { - return duplexer.duplexAndSync(args.toExecutePacket(), (result) -> { + return getDuplexer().duplexAndSync(args.toExecutePacket(), (result) -> { switch (result.packet.getMessageType()) { case DATA: var data = result.packet.as(Data.class); @@ -258,16 +254,14 @@ private CompletionStage runWithAttempts( private void handleCommandError(@NotNull ExecutionArguments args, Duplexer.@NotNull DuplexResult result, @NotNull ErrorResponse err) { if(err.errorCode == ErrorCode.STATE_MISMATCH_ERROR) { + logger.debug("Has updated state?: {}", args.stateUpdated); // should have new state if(!args.stateUpdated) { result.finishExceptionally( "Failed to properly encode state data, this is a bug", EdgeDBException::new ); - return; } - - result.finishDuplexing(); } else { result.finishExceptionally(err, args.query, ErrorResponse::toException); @@ -285,10 +279,6 @@ private void updateStateCodec(Duplexer.@NotNull DuplexResult result, @NotNull Ex stateDescriptor.typeDescriptorBuffer, Map.class ); - - if(codec == null) { - throw new MissingCodecException("Failed to build state codec"); - } } catch (EdgeDBException | OperationNotSupportedException e) { result.finishExceptionally("Failed to parse state codec", e, EdgeDBException::new); } @@ -308,8 +298,8 @@ private void updateStateCodec(Duplexer.@NotNull DuplexResult result, @NotNull Ex public final CompletionStage executeQuery( @NotNull ExecutionArguments args ) { - logger.debug("Execute request: is connected? {}", duplexer.isConnected()); - if(!duplexer.isConnected()) { + logger.debug("Execute request: is connected? {}", getDuplexer().isConnected()); + if(!getDuplexer().isConnected()) { // TODO: check for recursion return reconnect() .thenCompose(v -> executeQuery(args)); @@ -798,7 +788,7 @@ private CompletionStage startSASLAuthentication(@NotNull AuthenticationSta AtomicReference signature = new AtomicReference<>(new byte[0]); try{ - return this.duplexer.duplex(initialMessage, (state) -> { + return getDuplexer().duplex(initialMessage, (state) -> { logger.debug("Authentication duplex: M:{}", state.packet.getMessageType()); try { switch (state.packet.getMessageType()) { @@ -811,7 +801,7 @@ private CompletionStage startSASLAuthentication(@NotNull AuthenticationSta case AUTHENTICATION_SASL_CONTINUE: var result = scram.buildFinalMessage(auth, connection.getPassword()); signature.set(result.signature); - return this.duplexer.send(result.buildPacket()); + return getDuplexer().send(result.buildPacket()); case AUTHENTICATION_SASL_FINAL: var key = Scram.parseServerFinalMessage(auth); @@ -883,7 +873,7 @@ public CompletionStage connect() { } private CompletionStage doClientHandshake() { - return this.duplexer.readNext() + return getDuplexer().readNext() .thenCompose(packet -> { logger.debug("Processing handshake step with packet: {}", packet == null ? "NULL" : packet.getMessageType()); @@ -905,15 +895,15 @@ private CompletionStage doClientHandshake() { @Override public CompletionStage disconnect() { - return this.duplexer.disconnect(); + return getDuplexer().disconnect(); } private CompletionStage connectInternal() { - if(this.duplexer.isConnected()) { + if(getDuplexer().isConnected()) { return CompletableFuture.completedFuture(null); } - this.duplexer.reset(); + getDuplexer().reset(); return retryableConnect() .thenApply(v -> getConnectionArguments()) @@ -926,7 +916,7 @@ private CompletionStage connectInternal() { }, new ProtocolExtension[0] )) - .thenCompose(this.duplexer::send); + .thenCompose(getDuplexer()::send); } private CompletionStage retryableConnect() { @@ -966,7 +956,7 @@ private CompletionStage retryableConnect() { @Override public boolean isConnected() { - return this.duplexer.isConnected(); + return getDuplexer().isConnected(); } public final class ExecutionArguments { diff --git a/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBHttpClient.java b/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBHttpClient.java new file mode 100644 index 00000000..17197fb4 --- /dev/null +++ b/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBHttpClient.java @@ -0,0 +1,240 @@ +package com.edgedb.driver.clients; + +import com.edgedb.driver.EdgeDBClientConfig; +import com.edgedb.driver.EdgeDBConnection; +import com.edgedb.driver.TransactionState; +import com.edgedb.driver.binary.duplexers.Duplexer; +import com.edgedb.driver.binary.duplexers.HttpDuplexer; +import com.edgedb.driver.exceptions.ConnectionFailedException; +import com.edgedb.driver.exceptions.EdgeDBException; +import com.edgedb.driver.exceptions.ScramException; +import com.edgedb.driver.util.Scram; +import com.edgedb.driver.util.SslUtils; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.net.ProtocolException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.security.KeyManagementException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.time.Duration; +import java.util.Arrays; +import java.util.Base64; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.stream.Collectors; + +public final class EdgeDBHttpClient extends EdgeDBBinaryClient { + private static final Logger logger = LoggerFactory.getLogger(EdgeDBHttpClient.class); + private static final String HTTP_TOKEN_AUTH_METHOD = "SCRAM-SHA-256"; + private final HttpDuplexer duplexer; + public final HttpClient httpClient; + + private @Nullable String authToken; + private @Nullable URI baseUri; + private @Nullable URI authUri; + private @Nullable URI execUri; + + public EdgeDBHttpClient(EdgeDBConnection connection, EdgeDBClientConfig config, AutoCloseable poolHandle) throws EdgeDBException { + super(connection, config, poolHandle); + this.duplexer = new HttpDuplexer(this); + SSLContext context; + try { + context = SSLContext.getInstance("TLS"); + SslUtils.initContextWithConnectionDetails(context, getConnectionArguments()); + } catch (NoSuchAlgorithmException | CertificateException | KeyStoreException | IOException | + KeyManagementException e) { + throw new EdgeDBException("Failed to initialize SSL context", e); + } + + this.httpClient = HttpClient.newBuilder() + .sslContext(context) + .build(); + } + + public String getToken() { + return this.authToken; + } + + public void clearToken() { + this.authToken = null; + } + + private CompletionStage authenticate() { + return CompletableFuture + .supplyAsync(() -> { + logger.debug("Creating SCRAM and initial auth message..."); + var scram = new Scram(); + + var first = scram.buildInitialMessage(getConnectionArguments().getUsername()); + + var request = HttpRequest.newBuilder() + .uri(getAuthUri()) + .version(HttpClient.Version.HTTP_2) + .header( + "Authorization", + HTTP_TOKEN_AUTH_METHOD + + " data=" + + Base64.getEncoder().encodeToString(first.getBytes(StandardCharsets.UTF_8))) + .GET() + .timeout(Duration.of( + getConfig().getMessageTimeoutValue(), + getConfig().getMessageTimeoutUnit().toChronoUnit())) + .build(); + + return Map.entry(scram, request); + }) + .thenCompose(entry -> { + logger.debug("Executing initial auth request"); + + return httpClient.sendAsync(entry.getValue(), HttpResponse.BodyHandlers.ofByteArray()) + .thenApply(response -> Map.entry(entry.getKey(), response)); + }) + .thenCompose(entry -> { + var authenticate = entry.getValue().headers().firstValue("www-authenticate"); + + logger.debug( + "Verifying response authenticate, is match?: {}", + authenticate.isPresent() && authenticate.get().startsWith(HTTP_TOKEN_AUTH_METHOD) + ); + + if(authenticate.isEmpty()) { + return CompletableFuture.failedFuture( + new ProtocolException("The only supported auth method is " + HTTP_TOKEN_AUTH_METHOD) + ); + } + + var authenticateData = authenticate.get().substring(HTTP_TOKEN_AUTH_METHOD.length() + 1); + + var keys = parseKeys(authenticateData); + + Scram.SASLFinalMessage finalMsg; + + try { + logger.debug("Building final message..."); + finalMsg = entry.getKey().buildFinalMessage( + new String(Base64.getDecoder().decode(keys.get("data")), StandardCharsets.UTF_8), + getConnectionArguments().getPassword() + ); + } catch (ScramException e) { + logger.debug("Failed to build final message", e); + return CompletableFuture.failedFuture(e); + } + + String payload = "sid=" + + keys.get("sid") + + " data=" + + Base64.getEncoder().encodeToString(finalMsg.message.getBytes(StandardCharsets.UTF_8)); + + var request = HttpRequest.newBuilder() + .uri(getAuthUri()) + .header( + "Authorization", + HTTP_TOKEN_AUTH_METHOD + " " + payload + ) + .GET() + .build(); + + logger.debug("Sending final auth message..."); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()); + }) + .thenCompose(EdgeDBHttpClient::ensureSuccess) + .thenApply(HttpResponse::body); + } + + public static CompletionStage> ensureSuccess(HttpResponse response) { + logger.debug("Verifying success for code {}", response.statusCode()); + if(response.statusCode() / 100 != 2) { + return CompletableFuture.failedFuture( + new ConnectionFailedException( + "Could not authenticate: " + response.statusCode() + ) + ); + } + + return CompletableFuture.completedFuture(response); + } + + private static CompletionStage>> ensureSuccess( + Map.Entry> entry + ) { + return ensureSuccess(entry.getValue()) + .thenApply(v -> entry); + } + + private Map parseKeys(String s) { + return Arrays.stream(s.split(",")) + .map(v -> v.split("=")) + .collect(Collectors.toMap( + v -> v[0].trim(), + v -> v[1] + )); + } + + public synchronized URI getAuthUri() { + if(authUri != null) { + return authUri; + } + + return authUri = getBaseUri().resolve("/auth/token"); + } + + public synchronized URI getBaseUri() { + if(baseUri != null) { + return baseUri; + } + + return baseUri = URI.create( + "https://" + getConnectionArguments().getHostname() + ":" + getConnectionArguments().getPort() + ); + } + + public synchronized URI getExecUri() { + if(execUri != null) { + return execUri; + } + + return execUri = getBaseUri().resolve("/db/" + getConnectionArguments().getDatabase()); + } + + @Override + protected Duplexer getDuplexer() { + return this.duplexer; + } + + @Override + protected void setTransactionState(TransactionState state) { + // invalid for this client + } + + @Override + protected CompletionStage openConnection() { + return CompletableFuture.completedFuture(null); // not valid for this client. + } + + @Override + public CompletionStage connect() { + if(authToken == null) { + return authenticate() + .thenAccept(token -> this.authToken = token); + } + + return CompletableFuture.completedFuture(null); + } + + @Override + protected CompletionStage closeConnection() { + return CompletableFuture.completedFuture(null); + } +} diff --git a/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBTCPClient.java b/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBTCPClient.java index 842376a2..4fb55abe 100644 --- a/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBTCPClient.java +++ b/src/driver/src/main/java/com/edgedb/driver/clients/EdgeDBTCPClient.java @@ -42,7 +42,6 @@ public class EdgeDBTCPClient extends EdgeDBBinaryClient implements TransactableC public EdgeDBTCPClient(EdgeDBConnection connection, EdgeDBClientConfig config, AutoCloseable poolHandle) { super(connection, config, poolHandle); this.duplexer = new ChannelDuplexer(this); - setDuplexer(this.duplexer); this.bootstrap = new Bootstrap() .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) @@ -79,6 +78,12 @@ protected void initChannel(@NotNull SocketChannel ch) throws Exception { } }); } + + @Override + protected @NotNull ChannelDuplexer getDuplexer() { + return this.duplexer; + } + @Override protected void setTransactionState(TransactionState state) { this.transactionState = state; diff --git a/src/driver/src/main/java/com/edgedb/driver/util/SslUtils.java b/src/driver/src/main/java/com/edgedb/driver/util/SslUtils.java index bc32402a..c012eb21 100644 --- a/src/driver/src/main/java/com/edgedb/driver/util/SslUtils.java +++ b/src/driver/src/main/java/com/edgedb/driver/util/SslUtils.java @@ -5,40 +5,52 @@ import io.netty.handler.ssl.SslContextBuilder; import org.jetbrains.annotations.NotNull; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.security.GeneralSecurityException; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; +import java.security.*; import java.security.cert.CertificateException; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; public class SslUtils { + public static final X509TrustManager INSECURE_TRUST_MANAGER = new X509TrustManager() { + public java.security.cert.X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + public void checkClientTrusted( + java.security.cert.X509Certificate[] certs, String authType) { + } + public void checkServerTrusted( + java.security.cert.X509Certificate[] certs, String authType) { + } + }; + + public static void initContextWithConnectionDetails( + @NotNull SSLContext context, @NotNull EdgeDBConnection connection) + throws KeyManagementException, CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException { + if(connection.getTLSSecurity() == TLSSecurityMode.INSECURE) { + context.init(null, new TrustManager[] {INSECURE_TRUST_MANAGER}, null); + return; + } + + context.init(null, getTrustManagerFactory(connection).getTrustManagers(), null); + } + public static void applyTrustManager(@NotNull EdgeDBConnection connection, @NotNull SslContextBuilder builder) throws GeneralSecurityException, IOException { if(connection.getTLSSecurity() == TLSSecurityMode.INSECURE) { - builder.trustManager(new X509TrustManager() { - public java.security.cert.X509Certificate[] getAcceptedIssuers() { - return new X509Certificate[0]; - } - public void checkClientTrusted( - java.security.cert.X509Certificate[] certs, String authType) { - } - public void checkServerTrusted( - java.security.cert.X509Certificate[] certs, String authType) { - } - }); + builder.trustManager(INSECURE_TRUST_MANAGER); } else { builder.trustManager(getTrustManagerFactory(connection)); } } - private static @NotNull TrustManagerFactory getTrustManagerFactory(@NotNull EdgeDBConnection connection) throws NoSuchAlgorithmException, KeyStoreException, CertificateException, IOException { + public static @NotNull TrustManagerFactory getTrustManagerFactory(@NotNull EdgeDBConnection connection) throws NoSuchAlgorithmException, KeyStoreException, CertificateException, IOException { var authority = connection.getTLSCertificateAuthority(); TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); diff --git a/src/driver/src/main/java/module-info.java b/src/driver/src/main/java/module-info.java index 16bdd907..c4b15bbe 100644 --- a/src/driver/src/main/java/module-info.java +++ b/src/driver/src/main/java/module-info.java @@ -18,6 +18,7 @@ requires io.netty.handler; requires org.jooq.joou; requires org.reflections; + requires java.net.http; opens com.edgedb.driver; }