@@ -15,18 +15,15 @@

 package za.co.absa.hyperdrive.trigger.scheduler.executors.spark

-import org.apache.spark.launcher.{
-  InProcessLauncher,
-  NoBackendConnectionInProcessLauncher,
-  SparkAppHandle,
-  SparkLauncher
-}
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark.launcher.{InProcessLauncher, NoBackendConnectionInProcessLauncher, SparkAppHandle}
 import org.springframework.stereotype.Service
 import za.co.absa.hyperdrive.trigger.configuration.application.SparkConfig
 import za.co.absa.hyperdrive.trigger.models.enums.JobStatuses.{Lost, SubmissionTimeout, Submitting}
 import za.co.absa.hyperdrive.trigger.models.{JobInstance, SparkInstanceParameters}
 import za.co.absa.hyperdrive.trigger.api.rest.utils.Extensions._

+import java.security.PrivilegedExceptionAction
 import java.util.UUID.randomUUID
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 import javax.inject.Inject
@@ -38,6 +35,8 @@ class SparkYarnClusterServiceImpl @Inject() (
   executionContextProvider: SparkClusterServiceExecutionContextProvider
 ) extends SparkClusterService {
   private implicit val executionContext: ExecutionContext = executionContextProvider.get()
+  private val SparkYarnPrincipalProp = "spark.yarn.principal"
+  private val SparkYarnKeytabProp = "spark.yarn.keytab"

   override def submitJob(
     jobInstance: JobInstance,
@@ -49,17 +48,18 @@ class SparkYarnClusterServiceImpl @Inject() (
     updateJob(ji).map { _ =>
       val submitTimeout = sparkConfig.yarn.submitTimeout
       val latch = new CountDownLatch(1)
-      val sparkAppHandle =
-        getSparkLauncher(id, ji.jobName, jobParameters).startApplication(new SparkAppHandle.Listener {
-          import scala.math.Ordered.orderingToOrdered
-          override def stateChanged(handle: SparkAppHandle): Unit =
-            if (handle.getState >= SparkAppHandle.State.SUBMITTED) {
-              latch.countDown()
-            }
-          override def infoChanged(handle: SparkAppHandle): Unit = {
-            // do nothing
-          }
-        })
+      val sparkAppHandleListener = new SparkAppHandle.Listener {
+        import scala.math.Ordered.orderingToOrdered
+        override def stateChanged(handle: SparkAppHandle): Unit =
+          if (handle.getState >= SparkAppHandle.State.SUBMITTED) {
+            latch.countDown()
+          }
+        override def infoChanged(handle: SparkAppHandle): Unit = {
+          // do nothing
+        }
+      }
+      val sparkAppHandle =
+        startSparkJob(getSparkLauncher(id, ji.jobName, jobParameters), sparkAppHandleListener, jobParameters)
       latch.await(submitTimeout, TimeUnit.MILLISECONDS)
       sparkAppHandle.kill()
     }
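
For context, the submit-and-wait pattern this hunk extracts the listener from, as a standalone sketch (my own illustration, not code from this PR; the object name and the timeout constant are hypothetical, the real value comes from sparkConfig.yarn.submitTimeout):

import java.util.concurrent.{CountDownLatch, TimeUnit}
import org.apache.spark.launcher.SparkAppHandle

object SubmitWaitSketch {
  // Illustrative timeout; the PR reads sparkConfig.yarn.submitTimeout.
  val SubmitTimeoutMs = 60000L

  def awaitSubmission(start: SparkAppHandle.Listener => SparkAppHandle): Boolean = {
    val latch = new CountDownLatch(1)
    val listener = new SparkAppHandle.Listener {
      import scala.math.Ordered.orderingToOrdered
      override def stateChanged(handle: SparkAppHandle): Unit =
        // SUBMITTED and every later state release the waiting thread
        if (handle.getState >= SparkAppHandle.State.SUBMITTED) latch.countDown()
      override def infoChanged(handle: SparkAppHandle): Unit = () // not needed here
    }
    start(listener)
    // await returns false if the timeout elapsed before countDown() ran
    latch.await(SubmitTimeoutMs, TimeUnit.MILLISECONDS)
  }
}

A caller would pass the launcher's start method, e.g. awaitSubmission(launcher.startApplication(_)).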
@@ -103,6 +103,24 @@ class SparkYarnClusterServiceImpl @Inject() (
     sparkLauncher
   }

+  private def startSparkJob(inProcessLauncher: InProcessLauncher,
+                            sparkAppHandleListener: SparkAppHandle.Listener,
+                            jobParameters: SparkInstanceParameters
+  ): SparkAppHandle = {
+    val user = jobParameters.additionalSparkConfig.find(_.key == SparkYarnPrincipalProp).map(_.value)
+    val keytab = jobParameters.additionalSparkConfig.find(_.key == SparkYarnKeytabProp).map(_.value)
+    (user, keytab) match {
+      case (Some(u), Some(k)) =>
+        val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(u, k)
+        ugi.doAs(new PrivilegedExceptionAction[SparkAppHandle]() {
+          override def run(): SparkAppHandle = {
+            inProcessLauncher.startApplication(sparkAppHandleListener)
+          }
+        })
+      case _ => inProcessLauncher.startApplication(sparkAppHandleListener)
+    }
+  }
+
 /*
  Fix inspired by https://stackoverflow.com/questions/43040793/scala-via-spark-with-yarn-curly-brackets-string-missing
  See https://issues.apache.org/jira/browse/SPARK-17814
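
For reviewers unfamiliar with UserGroupInformation: the new startSparkJob wraps startApplication in the standard Hadoop keytab-login flow. A minimal standalone sketch of that pattern (the object name, principal, and keytab path below are hypothetical; the PR reads principal and keytab from the job's additionalSparkConfig):

import java.security.PrivilegedExceptionAction
import org.apache.hadoop.security.UserGroupInformation

object KerberosDoAsSketch {
  // Runs `action` as the given Kerberos principal, authenticated via keytab.
  def runAs[T](principal: String, keytabPath: String)(action: => T): T = {
    val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytabPath)
    ugi.doAs(new PrivilegedExceptionAction[T] {
      override def run(): T = action
    })
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical values, for illustration only.
    val who = runAs("svc-hyperdrive@EXAMPLE.COM", "/etc/security/keytabs/svc-hyperdrive.keytab") {
      UserGroupInformation.getCurrentUser.getUserName
    }
    println(who) // prints the principal the action ran as
  }
}

Note that loginUserFromKeytabAndReturnUGI performs the login without replacing the process-wide login user, which is why the method can safely fall back to a plain startApplication call when no principal/keytab pair is configured.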