diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 24ed449f2a7d1..fe836bf4c1b3c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1177,11 +1177,67 @@ setMethod("dim", setMethod("collect", signature(x = "SparkDataFrame"), function(x, stringsAsFactors = FALSE) { + connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) + useArrow <- FALSE + arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true" + if (arrowEnabled) { + useArrow <- tryCatch({ + requireNamespace1 <- requireNamespace + if (!requireNamespace1("arrow", quietly = TRUE)) { + stop("'arrow' package should be installed.") + } + # Currently Arrow optimization does not support raw. + # Also, it does not support explicit float type set by users. + if (inherits(schema(x), "structType")) { + if (any(sapply(schema(x)$fields(), + function(x) x$dataType.toString() == "FloatType"))) { + stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ", + "DataFrame does not support FloatType yet.")) + } + if (any(sapply(schema(x)$fields(), + function(x) x$dataType.toString() == "BinaryType"))) { + stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ", + "DataFrame does not support BinaryType yet.")) + } + } + TRUE + }, error = function(e) { + warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ", + "with Arrow optimization because ", + "'spark.sql.execution.arrow.enabled' is set to true; however, ", + "failed, attempting non-optimization. 
Reason: ", + e)) + FALSE + }) + } + dtypes <- dtypes(x) ncol <- length(dtypes) if (ncol <= 0) { # empty data.frame with 0 columns and 0 rows data.frame() + } else if (useArrow) { + requireNamespace1 <- requireNamespace + if (requireNamespace1("arrow", quietly = TRUE)) { + read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE) + as_tibble <- get("as_tibble", envir = asNamespace("arrow")) + + portAuth <- callJMethod(x@sdf, "collectAsArrowToR") + port <- portAuth[[1]] + authSecret <- portAuth[[2]] + conn <- socketConnection( + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) + output <- tryCatch({ + doServerAuth(conn, authSecret) + arrowTable <- read_arrow(readRaw(conn)) + as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors) + }, finally = { + close(conn) + }) + return(output) + } else { + stop("'arrow' package should be installed.") + } } else { # listCols is a list of columns listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 9dc699c09a1e4..21eaa32f0011c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -307,7 +307,7 @@ test_that("create DataFrame from RDD", { unsetHiveContext() }) -test_that("createDataFrame Arrow optimization", { +test_that("createDataFrame/collect Arrow optimization", { skip_if_not_installed("arrow") conf <- callJMethod(sparkSession, "conf") @@ -332,7 +332,24 @@ test_that("createDataFrame Arrow optimization", { }) }) -test_that("createDataFrame Arrow optimization - type specification", { +test_that("createDataFrame/collect Arrow optimization - many partitions (partition order test)", { + skip_if_not_installed("arrow") + + conf <- callJMethod(sparkSession, "conf") + arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] + + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true") 
+ tryCatch({ + expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)), + collect(createDataFrame(mtcars, numPartitions = 1))) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled) + }) +}) + +test_that("createDataFrame/collect Arrow optimization - type specification", { skip_if_not_installed("arrow") rdf <- data.frame(list(list(a = 1, b = "a", diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 14ea289e5f908..0937a63dad19b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -430,6 +430,12 @@ private[spark] object PythonRDD extends Logging { */ private[spark] def serveToStream( threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { + serveToStream(threadName, authHelper)(writeFunc) + } + + private[spark] def serveToStream( + threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit) + : Array[Any] = { val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s => val out = new BufferedOutputStream(s.getOutputStream()) Utils.tryWithSafeFinally { diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 1dc61c7eef33c..04fc6e18c1e5c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.r -import java.io.{DataInputStream, File} +import java.io.{DataInputStream, File, OutputStream} import java.net.Socket import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Map => JMap} @@ -104,7 +104,7 @@ private class StringRRDD[T: ClassTag]( lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } -private[r] object RRDD { 
+private[spark] object RRDD { def createSparkContext( master: String, appName: String, @@ -165,6 +165,11 @@ private[r] object RRDD { JavaRDD[Array[Byte]] = { PythonRDD.readRDDFromFile(jsc, fileName, parallelism) } + + private[spark] def serveToStream( + threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { + PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8a26152271a83..bd1ae509cf54b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.io.{CharArrayWriter, DataOutputStream} +import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.api.r.RRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.QueryPlanningTracker @@ -3198,9 +3199,66 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as Arrow batches and serve stream to PySpark. + * Collect a Dataset as Arrow batches and serve stream to SparkR. It sends + * Arrow batches in an ordered manner with buffering. This is unavoidable + * due to the missing R API that reads batches from a socket directly. See ARROW-4512. + * Eventually, this code should be deduplicated with `collectAsArrowToPython`. 
*/ - private[sql] def collectAsArrowToPython(): Array[Any] = { + private[sql] def collectAsArrowToR(): Array[Any] = { + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + + withAction("collectAsArrowToR", queryExecution) { plan => + RRDD.serveToStream("serve-Arrow") { outputStream => + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(outputStream) + val batchWriter = new ArrowBatchStreamWriter(schema, buffer, timeZoneId) + val arrowBatchRdd = toArrowBatchRdd(plan) + val numPartitions = arrowBatchRdd.partitions.length + + // Store collection results for worst case of 1 to N-1 partitions (empty for N <= 1) + val results = new Array[Array[Array[Byte]]](math.max(0, numPartitions - 1)) + var lastIndex = -1 // index of last partition written + + // Handler to eagerly write partitions to R in order + def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { + batchWriter.writeBatches(arrowBatches.iterator) + lastIndex += 1 + // Write stored partitions that come next in order + while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 + } + // After last batch, end the stream + if (lastIndex == results.length) { + batchWriter.end() + val batches = buffer.toByteArray + out.writeInt(batches.length) + out.write(batches) + } + } else { + // Store partitions received out of order + results(index - 1) = arrowBatches + } + } + + sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until numPartitions, + handlePartitionBatches) + } + } + } + + /** + * Collect a Dataset as Arrow batches and serve stream to PySpark. It sends + * arrow batches in an un-ordered manner without buffering, and then batch order + * information at the end. The batches should be reordered at Python side. 
+ */ + private[sql] def collectAsArrowToPython: Array[Any] = { val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => @@ -3211,7 +3269,7 @@ class Dataset[T] private[sql]( val numPartitions = arrowBatchRdd.partitions.length // Batches ordered by (index of partition, batch index in that partition) tuple - val batchOrder = new ArrayBuffer[(Int, Int)]() + val batchOrder = ArrayBuffer.empty[(Int, Int)] var partitionCount = 0 // Handler to eagerly write batches to Python as they arrive, un-ordered @@ -3220,7 +3278,7 @@ class Dataset[T] private[sql]( // Write all batches (can be more than 1) in the partition, store the batch order tuple batchWriter.writeBatches(arrowBatches.iterator) arrowBatches.indices.foreach { - partition_batch_index => batchOrder.append((index, partition_batch_index)) + partitionBatchIndex => batchOrder.append((index, partitionBatchIndex)) } } partitionCount += 1 @@ -3232,8 +3290,8 @@ class Dataset[T] private[sql]( // Sort by (index of partition, batch index in that partition) tuple to get the // overall_batch_index from 0 to N-1 batches, which can be used to put the // transferred batches in the correct order - batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) => - out.writeInt(overall_batch_index) + batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) => + out.writeInt(overallBatchIndex) } out.flush() }