diff --git a/src/main/scala/za/co/absa/hyperdrive/trigger/scheduler/executors/spark/SparkYarnClusterServiceImpl.scala b/src/main/scala/za/co/absa/hyperdrive/trigger/scheduler/executors/spark/SparkYarnClusterServiceImpl.scala index 4e8062dfe..7c8e54c72 100644 --- a/src/main/scala/za/co/absa/hyperdrive/trigger/scheduler/executors/spark/SparkYarnClusterServiceImpl.scala +++ b/src/main/scala/za/co/absa/hyperdrive/trigger/scheduler/executors/spark/SparkYarnClusterServiceImpl.scala @@ -15,18 +15,15 @@ package za.co.absa.hyperdrive.trigger.scheduler.executors.spark -import org.apache.spark.launcher.{ - InProcessLauncher, - NoBackendConnectionInProcessLauncher, - SparkAppHandle, - SparkLauncher -} +import org.apache.hadoop.security.UserGroupInformation +import org.apache.spark.launcher.{InProcessLauncher, NoBackendConnectionInProcessLauncher, SparkAppHandle} import org.springframework.stereotype.Service import za.co.absa.hyperdrive.trigger.configuration.application.SparkConfig import za.co.absa.hyperdrive.trigger.models.enums.JobStatuses.{Lost, SubmissionTimeout, Submitting} import za.co.absa.hyperdrive.trigger.models.{JobInstance, SparkInstanceParameters} import za.co.absa.hyperdrive.trigger.api.rest.utils.Extensions._ +import java.security.PrivilegedExceptionAction import java.util.UUID.randomUUID import java.util.concurrent.{CountDownLatch, TimeUnit} import javax.inject.Inject @@ -38,6 +35,8 @@ class SparkYarnClusterServiceImpl @Inject() ( executionContextProvider: SparkClusterServiceExecutionContextProvider ) extends SparkClusterService { private implicit val executionContext: ExecutionContext = executionContextProvider.get() + private val SparkYarnPrincipalProp = "spark.yarn.principal" + private val SparkYarnKeytabProp = "spark.yarn.keytab" override def submitJob( jobInstance: JobInstance, @@ -49,17 +48,18 @@ class SparkYarnClusterServiceImpl @Inject() ( updateJob(ji).map { _ => val submitTimeout = sparkConfig.yarn.submitTimeout val latch = new CountDownLatch(1) - val sparkAppHandle = - getSparkLauncher(id, ji.jobName, jobParameters).startApplication(new SparkAppHandle.Listener { - import scala.math.Ordered.orderingToOrdered - override def stateChanged(handle: SparkAppHandle): Unit = - if (handle.getState >= SparkAppHandle.State.SUBMITTED) { - latch.countDown() - } - override def infoChanged(handle: SparkAppHandle): Unit = { - // do nothing + val sparkAppHandleListener = new SparkAppHandle.Listener { + import scala.math.Ordered.orderingToOrdered + override def stateChanged(handle: SparkAppHandle): Unit = + if (handle.getState >= SparkAppHandle.State.SUBMITTED) { + latch.countDown() } - }) + override def infoChanged(handle: SparkAppHandle): Unit = { + // do nothing + } + } + val sparkAppHandle = + startSparkJob(getSparkLauncher(id, ji.jobName, jobParameters), sparkAppHandleListener, jobParameters) latch.await(submitTimeout, TimeUnit.MILLISECONDS) sparkAppHandle.kill() } @@ -103,6 +103,24 @@ class SparkYarnClusterServiceImpl @Inject() ( sparkLauncher } + private def startSparkJob(inProcessLauncher: InProcessLauncher, + sparkAppHandleListener: SparkAppHandle.Listener, + jobParameters: SparkInstanceParameters + ): SparkAppHandle = { + val user = jobParameters.additionalSparkConfig.find(_.key == SparkYarnPrincipalProp).map(_.value) + val keytab = jobParameters.additionalSparkConfig.find(_.key == SparkYarnKeytabProp).map(_.value) + (user, keytab) match { + case (Some(u), Some(k)) => + val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(u, k) + ugi.doAs(new PrivilegedExceptionAction[SparkAppHandle]() { + override def run(): SparkAppHandle = { + inProcessLauncher.startApplication(sparkAppHandleListener) + } + }) + case _ => inProcessLauncher.startApplication(sparkAppHandleListener) + } + } + /* Fixed inspired by https://stackoverflow.com/questions/43040793/scala-via-spark-with-yarn-curly-brackets-string-missing See https://issues.apache.org/jira/browse/SPARK-17814