diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 146f754c48f92..b3f33fb21893a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -88,7 +88,7 @@ private[spark] case class PythonFunction( private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) /** Thrown for exceptions in user Python code. */ -private[spark] class PythonException(msg: String, cause: Exception) +private[spark] class PythonException(msg: String, cause: Throwable) extends RuntimeException(msg, cause) /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f7d1461368b9c..d88757bcc1aa0 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -23,6 +23,7 @@ import java.nio.charset.StandardCharsets import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.internal.Logging @@ -143,15 +144,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { - @volatile private var _exception: Exception = null + @volatile private var _exception: Throwable = null private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) setDaemon(true) - /** Contains the exception thrown while writing the parent iterator to the Python process. */ - def exception: Option[Exception] = Option(_exception) + /** Contains the throwable thrown while writing the parent iterator to the Python process. */ + def exception: Option[Throwable] = Option(_exception) /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ def shutdownOnTaskCompletion() { @@ -251,18 +252,21 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() } catch { - case e: Exception if context.isCompleted || context.isInterrupted => - logDebug("Exception thrown after task completion (likely due to cleanup)", e) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see org.apache.spark.executor.Executor). - _exception = e - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) + case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => + if (context.isCompleted || context.isInterrupted) { + logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + + "cleanup)", t) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } else { + // We must avoid throwing exceptions/NonFatals here, because the thread uncaught + // exception handler will kill the whole executor (see + // org.apache.spark.executor.Executor). + _exception = t + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } } } }