AMATERASU-26 Pipeline tasks runs as "yarn" user instead of inheriting the user 18/head
authorArun Manivannan <arun@arunma.com>
Fri, 18 May 2018 12:54:36 +0000 (20:54 +0800)
committerArun Manivannan <arun@arunma.com>
Fri, 18 May 2018 12:54:36 +0000 (20:54 +0800)
leader/src/main/java/org/apache/amaterasu/leader/yarn/Client.java
leader/src/main/scala/org/apache/amaterasu/leader/yarn/ApplicationMaster.scala

index dc4f15e..e3c2812 100644 (file)
@@ -29,6 +29,7 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.yarn.api.ApplicationConstants;
 import org.apache.hadoop.yarn.api.records.*;
 import org.apache.hadoop.yarn.client.api.YarnClient;
@@ -110,8 +111,9 @@ public class Client {
 
 
         List<String> commands = Collections.singletonList(
-                "env AMA_NODE=" + System.getenv("AMA_NODE") + " " +
-                        "$JAVA_HOME/bin/java" +
+                "env AMA_NODE=" + System.getenv("AMA_NODE") +
+                        " env HADOOP_USER_NAME=" + UserGroupInformation.getCurrentUser().getUserName() +
+                        " $JAVA_HOME/bin/java" +
                         " -Dscala.usejavacp=false" +
                         " -Xmx1G" +
                         " org.apache.amaterasu.leader.yarn.ApplicationMaster " +
index a44202a..1828100 100644 (file)
@@ -21,8 +21,8 @@ import java.net.{InetAddress, ServerSocket, URLEncoder}
 import java.nio.ByteBuffer
 import java.util
 import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
-import javax.jms.Session
 
+import javax.jms.Session
 import org.apache.activemq.ActiveMQConnectionFactory
 import org.apache.activemq.broker.BrokerService
 import org.apache.amaterasu.common.configuration.ClusterConfig
@@ -153,21 +153,14 @@ class ApplicationMaster extends AMRMClientAsync.CallbackHandler with Logging {
     // TODO: awsEnv currently set to empty string. should be changed to read values from (where?).
     allocListener = new YarnRMCallbackHandler(nmClient, jobManager, env, awsEnv = "", config, executorJar)
 
-    rmClient = AMRMClientAsync.createAMRMClientAsync(1000, this)
-    rmClient.init(conf)
-    rmClient.start()
-
-    // Register with ResourceManager
-    log.info("Registering application")
-    val registrationResponse = rmClient.registerApplicationMaster("", 0, "")
-    log.info("Registered application")
+    rmClient = startRMClient()
+    val registrationResponse = registerAppMaster("", 0, "")
     val maxMem = registrationResponse.getMaximumResourceCapability.getMemory
     log.info("Max mem capability of resources in this cluster " + maxMem)
     val maxVCores = registrationResponse.getMaximumResourceCapability.getVirtualCores
     log.info("Max vcores capability of resources in this cluster " + maxVCores)
     log.info(s"Created jobManager. jobManager.registeredActions.size: ${jobManager.registeredActions.size}")
 
-
     // Resource requirements for worker containers
     this.capability = Records.newRecord(classOf[Resource])
     val frameworkFactory = FrameworkProvidersFactory.apply(env, config)
@@ -194,6 +187,21 @@ class ApplicationMaster extends AMRMClientAsync.CallbackHandler with Logging {
     log.info("Finished asking for containers")
   }
 
+  private def startRMClient(): AMRMClientAsync[ContainerRequest] = {
+    val client = AMRMClientAsync.createAMRMClientAsync[ContainerRequest](1000, this)
+    client.init(conf)
+    client.start()
+    client
+  }
+
+  private def registerAppMaster(host: String, port: Int, url: String) = {
+    // Register with ResourceManager
+    log.info("Registering application")
+    val registrationResponse = rmClient.registerApplicationMaster(host, port, url)
+    log.info("Registered application")
+    registrationResponse
+  }
+
   private def setupMessaging(jobId: String): Unit = {
 
     val cf = new ActiveMQConnectionFactory(address)
@@ -225,20 +233,6 @@ class ApplicationMaster extends AMRMClientAsync.CallbackHandler with Logging {
 
   override def onContainersAllocated(containers: util.List[Container]): Unit = {
 
-    // creating the credentials for container execution
-    val credentials = UserGroupInformation.getCurrentUser.getCredentials
-    val dob = new DataOutputBuffer
-    credentials.writeTokenStorageToStream(dob)
-
-    // removing the AM->RM token so that containers cannot access it.
-    val iter = credentials.getAllTokens.iterator
-    log.info("Executing with tokens:")
-    for (token <- iter) {
-      log.info(token.toString)
-      if (token.getKind == AMRMTokenIdentifier.KIND_NAME) iter.remove()
-    }
-    val allTokens = ByteBuffer.wrap(dob.getData, 0, dob.getLength)
-
     log.info(s"${containers.size()} Containers allocated")
     for (container <- containers.asScala) { // Launch container by create ContainerLaunchContext
       if (actionsBuffer.isEmpty) {
@@ -294,7 +288,8 @@ class ApplicationMaster extends AMRMClientAsync.CallbackHandler with Logging {
         ctx.setEnvironment(Map[String, String](
           "HADOOP_CONF_DIR" -> s"${config.YARN.hadoopHomeDir}/conf/",
           "YARN_CONF_DIR" -> s"${config.YARN.hadoopHomeDir}/conf/",
-          "AMA_NODE" -> sys.env("AMA_NODE")
+          "AMA_NODE" -> sys.env("AMA_NODE"),
+          "HADOOP_USER_NAME" -> UserGroupInformation.getCurrentUser.getUserName
         ))
 
         log.info(s"hadoop conf dir is ${config.YARN.hadoopHomeDir}/conf/")
@@ -316,6 +311,22 @@ class ApplicationMaster extends AMRMClientAsync.CallbackHandler with Logging {
     }
   }
 
+  private def allTokens: ByteBuffer = {
+    // creating the credentials for container execution
+    val credentials = UserGroupInformation.getCurrentUser.getCredentials
+    val dob = new DataOutputBuffer
+    credentials.writeTokenStorageToStream(dob)
+
+    // removing the AM->RM token so that containers cannot access it.
+    val iter = credentials.getAllTokens.iterator
+    log.info("Executing with tokens:")
+    for (token <- iter) {
+      log.info(token.toString)
+      if (token.getKind == AMRMTokenIdentifier.KIND_NAME) iter.remove()
+    }
+    ByteBuffer.wrap(dob.getData, 0, dob.getLength)
+  }
+
   private def setupResources(frameworkPath: String, countainerResources: mutable.Map[String, LocalResource], resourcesPath: String): Unit = {
 
     val sourcePath = Path.mergePaths(jarPath, new Path(s"/$resourcesPath"))