Skip to content
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ object SparkEnv extends Logging {

// NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
serializer, conf, memoryManager, mapOutputTracker, shuffleManager,
serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
blockTransferService, securityManager, numUsableCores)

val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.network

import scala.reflect.ClassTag

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{BlockId, StorageLevel}

Expand All @@ -35,7 +37,11 @@ trait BlockDataManager {
* Returns true if the block was stored and false if the put operation failed or the block
* already existed.
*/
def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Boolean
def putBlockData(
blockId: BlockId,
data: ManagedBuffer,
level: StorageLevel,
classTag: ClassTag[_]): Boolean

/**
* Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.nio.ByteBuffer

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
Expand Down Expand Up @@ -76,7 +77,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
level: StorageLevel,
classTag: ClassTag[_]): Future[Unit]

/**
* A special case of [[fetchBlocks]], as it fetches only one block and is blocking.
Expand Down Expand Up @@ -114,7 +116,9 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
level: StorageLevel,
classTag: ClassTag[_]): Unit = {
val future = uploadBlock(hostname, port, execId, blockId, blockData, level, classTag)
Await.result(future, Duration.Inf)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package org.apache.spark.network.netty
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.language.existentials
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
Expand Down Expand Up @@ -61,12 +63,16 @@ class NettyBlockRpcServer(
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

case uploadBlock: UploadBlock =>
// StorageLevel is serialized as bytes using our JavaSerializer.
val level: StorageLevel =
serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
// StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
val (level: StorageLevel, classTag: ClassTag[_]) = {
serializer
.newInstance()
.deserialize(ByteBuffer.wrap(uploadBlock.metadata))
.asInstanceOf[(StorageLevel, ClassTag[_])]
}
val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
val blockId = BlockId(uploadBlock.blockId)
blockManager.putBlockData(blockId, data, level)
blockManager.putBlockData(blockId, data, level, classTag)
responseContext.onSuccess(ByteBuffer.allocate(0))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
Expand Down Expand Up @@ -118,18 +119,19 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit] = {
level: StorageLevel,
classTag: ClassTag[_]): Future[Unit] = {
val result = Promise[Unit]()
val client = clientFactory.createClient(hostname, port)

// StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
// using our binary protocol.
val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level))
// StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
// Everything else is encoded using our binary protocol.
val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))

// Convert or copy nio buffer into array in order to serialize it.
val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())

client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer,
new RpcResponseCallback {
override def onSuccess(response: ByteBuffer): Unit = {
logTrace(s"Successfully uploaded block $blockId")
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ abstract class RDD[T: ClassTag](
val blockId = RDDBlockId(id, partition.index)
var readCachedBlock = true
// This method is called on executors, so we need call SparkEnv.get instead of sc.env.
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => {
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
readCachedBlock = false
computeOrReadCheckpoint(partition, context)
}) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag

import com.google.common.collect.ConcurrentHashMultiset

Expand All @@ -37,10 +38,14 @@ import org.apache.spark.internal.Logging
* @param level the block's storage level. This is the requested persistence level, not the
* effective storage level of the block (i.e. if this is MEMORY_AND_DISK, then this
* does not imply that the block is actually resident in memory).
* @param classTag the block's [[ClassTag]], used to select the serializer
* @param tellMaster whether state changes for this block should be reported to the master. This
* is true for most blocks, but is false for broadcast blocks.
*/
private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
private[storage] class BlockInfo(
val level: StorageLevel,
val classTag: ClassTag[_],
val tellMaster: Boolean) {

/**
* The size of the block (in bytes)
Expand Down
Loading