AMATERASU-26 Pipeline tasks run as the "yarn" user instead of inheriting the user
[incubator-amaterasu.git] leader/src/main/scala/org/apache/amaterasu/leader/yarn/ApplicationMaster.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.amaterasu.leader.yarn

import java.io.{File, FileInputStream, IOException, InputStream}
import java.net.{InetAddress, ServerSocket, URLEncoder}
import java.nio.ByteBuffer
import java.util
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

import javax.jms.Session
import org.apache.activemq.ActiveMQConnectionFactory
import org.apache.activemq.broker.BrokerService
import org.apache.amaterasu.common.configuration.ClusterConfig
import org.apache.amaterasu.common.dataobjects.ActionData
import org.apache.amaterasu.common.logging.Logging
import org.apache.amaterasu.leader.execution.frameworks.FrameworkProvidersFactory
import org.apache.amaterasu.leader.execution.{JobLoader, JobManager}
import org.apache.amaterasu.leader.utilities.{ActiveReportListener, Args, DataLoader}
import org.apache.curator.framework.recipes.barriers.DistributedBarrier
import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory}
import org.apache.curator.retry.ExponentialBackoffRetry
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.DataOutputBuffer
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api.ApplicationConstants
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.client.api.async.impl.NMClientAsyncImpl
import org.apache.hadoop.yarn.client.api.async.{AMRMClientAsync, NMClientAsync}
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.exceptions.YarnException
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import org.apache.zookeeper.CreateMode

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.{concurrent, mutable}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.util.{Failure, Success}

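// The YARN ApplicationMaster of an Amaterasu job: registers with the
// ResourceManager, requests a container per action, and launches the
// executors that run the pipeline's actions.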
class ApplicationMaster extends AMRMClientAsync.CallbackHandler with Logging {

  var capability: Resource = _

  log.info("ApplicationMaster start")

  private var jobManager: JobManager = _
  private var client: CuratorFramework = _
  private var config: ClusterConfig = _
  private var env: String = _
  private var branch: String = _
  private var fs: FileSystem = _
  private var conf: YarnConfiguration = _
  private var propPath: String = ""
  private var props: InputStream = _
  private var jarPath: Path = _
  private var executorPath: Path = _
  private var executorJar: LocalResource = _
  private var propFile: LocalResource = _
  private var log4jPropFile: LocalResource = _
  private var nmClient: NMClientAsync = _
  private var allocListener: YarnRMCallbackHandler = _
  private var rmClient: AMRMClientAsync[ContainerRequest] = _
  private var address: String = _

  private val containersIdsToTask: concurrent.Map[Long, ActionData] = new ConcurrentHashMap[Long, ActionData].asScala
  private val completedContainersAndTaskIds: concurrent.Map[Long, String] = new ConcurrentHashMap[Long, String].asScala
  private val actionsBuffer: java.util.concurrent.ConcurrentLinkedQueue[ActionData] = new java.util.concurrent.ConcurrentLinkedQueue[ActionData]()
  private val host: String = InetAddress.getLocalHost.getHostName
  private val broker: BrokerService = new BrokerService()

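  // wraps a file already present on HDFS as a YARN LocalResource so it can
  // be localized into a container's working directory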
  def setLocalResourceFromPath(path: Path): LocalResource = {

    val stat = fs.getFileStatus(path)
    val fileResource = Records.newRecord(classOf[LocalResource])

    fileResource.setResource(ConverterUtils.getYarnUrlFromPath(path))
    fileResource.setSize(stat.getLen)
    fileResource.setTimestamp(stat.getModificationTime)
    fileResource.setType(LocalResourceType.FILE)
    fileResource.setVisibility(LocalResourceVisibility.PUBLIC)
    fileResource

  }

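  // the main entry point of the AM: initializes the job, publishes the
  // broker address via ZooKeeper, registers with the RM and requests a
  // container for every action in the job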
  def execute(arguments: Args): Unit = {

    log.info(s"started AM with args $arguments")

    propPath = System.getenv("PWD") + "/amaterasu.properties"
    props = new FileInputStream(new File(propPath))

    // no need for hdfs double check (nod to Aaron Rodgers)
    // jars on HDFS should have been verified by the YARN client
    conf = new YarnConfiguration()
    fs = FileSystem.get(conf)

    config = ClusterConfig(props)

    try {
      initJob(arguments)
    } catch {
      case e: Exception => log.error("error initializing job", e)
    }

    // now that the job was initiated, the curator client is started and we can
    // register the broker's address
    client.create().withMode(CreateMode.PERSISTENT).forPath(s"/${jobManager.jobId}/broker")
    client.setData().forPath(s"/${jobManager.jobId}/broker", address.getBytes)

    // once the broker is registered, we can remove the barrier so clients can connect
    log.info(s"/${jobManager.jobId}-report-barrier")
    val barrier = new DistributedBarrier(client, s"/${jobManager.jobId}-report-barrier")
    barrier.removeBarrier()

    setupMessaging(jobManager.jobId)

    log.info(s"Job ${jobManager.jobId} initiated with ${jobManager.registeredActions.size} actions")

    jarPath = new Path(config.YARN.hdfsJarsPath)

    // TODO: change this to read all dist folder and add to exec path
    executorPath = Path.mergePaths(jarPath, new Path(s"/dist/executor-${config.version}-all.jar"))
    log.info("Executor jar path is {}", executorPath)
    executorJar = setLocalResourceFromPath(executorPath)
    propFile = setLocalResourceFromPath(Path.mergePaths(jarPath, new Path("/amaterasu.properties")))
    log4jPropFile = setLocalResourceFromPath(Path.mergePaths(jarPath, new Path("/log4j.properties")))

    log.info("Started execute")

    nmClient = new NMClientAsyncImpl(new YarnNMCallbackHandler())

    // Initialize clients to ResourceManager and NodeManagers
    nmClient.init(conf)
    nmClient.start()

    // TODO: awsEnv currently set to empty string. should be changed to read values from (where?).
    allocListener = new YarnRMCallbackHandler(nmClient, jobManager, env, awsEnv = "", config, executorJar)

    rmClient = startRMClient()
    val registrationResponse = registerAppMaster("", 0, "")
    val maxMem = registrationResponse.getMaximumResourceCapability.getMemory
    log.info("Max mem capability of resources in this cluster " + maxMem)
    val maxVCores = registrationResponse.getMaximumResourceCapability.getVirtualCores
    log.info("Max vcores capability of resources in this cluster " + maxVCores)
    log.info(s"Created jobManager. jobManager.registeredActions.size: ${jobManager.registeredActions.size}")

    // Resource requirements for worker containers
    this.capability = Records.newRecord(classOf[Resource])
    val frameworkFactory = FrameworkProvidersFactory(env, config)

    while (!jobManager.outOfActions) {
      val actionData = jobManager.getNextActionData
      if (actionData != null) {

        val frameworkProvider = frameworkFactory.providers(actionData.groupId)
        val driverConfiguration = frameworkProvider.getDriverConfiguration

        var mem: Int = driverConfiguration.getMemory
        mem = Math.min(mem, maxMem)
        this.capability.setMemory(mem)

        var cpu = driverConfiguration.getCPUs
        cpu = Math.min(cpu, maxVCores)
        this.capability.setVirtualCores(cpu)

        askContainer(actionData)
      }
    }

    log.info("Finished asking for containers")
  }

  private def startRMClient(): AMRMClientAsync[ContainerRequest] = {
    val client = AMRMClientAsync.createAMRMClientAsync[ContainerRequest](1000, this)
    client.init(conf)
    client.start()
    client
  }

  private def registerAppMaster(host: String, port: Int, url: String) = {
    // Register with ResourceManager
    log.info("Registering application")
    val registrationResponse = rmClient.registerApplicationMaster(host, port, url)
    log.info("Registered application")
    registrationResponse
  }

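  // subscribes to the JOB.REPORT topic on the embedded ActiveMQ broker so
  // executor reports are forwarded to the ActiveReportListener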
  private def setupMessaging(jobId: String): Unit = {

    val cf = new ActiveMQConnectionFactory(address)
    val conn = cf.createConnection()
    conn.start()

    val session = conn.createSession(false, Session.AUTO_ACKNOWLEDGE)
    // TODO: move to a const in common
    val destination = session.createTopic("JOB.REPORT")

    val consumer = session.createConsumer(destination)
    consumer.setMessageListener(new ActiveReportListener)

  }

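  // queues the action and issues a matching container request to the RM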
  private def askContainer(actionData: ActionData): Unit = {

    actionsBuffer.add(actionData)
    log.info(s"About to ask container for action ${actionData.id}. Action buffer size is: ${actionsBuffer.size()}")

    // we have an action to schedule, let's request a container
    val priority: Priority = Records.newRecord(classOf[Priority])
    priority.setPriority(1)
    val containerReq = new ContainerRequest(capability, null, null, priority)
    rmClient.addContainerRequest(containerReq)
    log.info(s"Asked container for action ${actionData.id}")

  }

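  // for each allocated container, dequeues an action and asynchronously
  // builds and starts a ContainerLaunchContext that runs the executor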
  override def onContainersAllocated(containers: util.List[Container]): Unit = {

    log.info(s"${containers.size()} Containers allocated")
    for (container <- containers.asScala) { // launch each container by creating a ContainerLaunchContext
      if (actionsBuffer.isEmpty) {
        log.warn(s"actionsBuffer is empty but containers were allocated. Container ids: ${containers.map(c => c.getId.getContainerId)}")
        return
      }

      val actionData = actionsBuffer.poll()
      val containerTask = Future[ActionData] {

        val taskData = DataLoader.getTaskDataString(actionData, env)
        val execData = DataLoader.getExecutorDataString(env, config)

        val ctx = Records.newRecord(classOf[ContainerLaunchContext])
        val commands: List[String] = List(
          "/bin/bash ./miniconda.sh -b -p $PWD/miniconda && ",
          "/bin/bash spark/bin/load-spark-env.sh && ",
          s"java -cp spark/jars/*:executor.jar:spark/conf/:${config.YARN.hadoopHomeDir}/conf/ " +
            "-Xmx1G " +
            "-Dscala.usejavacp=true " +
            "-Dhdp.version=2.6.1.0-129 " +
            "org.apache.amaterasu.executor.yarn.executors.ActionsExecutorLauncher " +
            s"'${jobManager.jobId}' '${config.master}' '${actionData.name}' '${URLEncoder.encode(taskData, "UTF-8")}' '${URLEncoder.encode(execData, "UTF-8")}' '${actionData.id}-${container.getId.getContainerId}' '$address' " +
            s"1> ${ApplicationConstants.LOG_DIR_EXPANSION_VAR}/stdout " +
            s"2> ${ApplicationConstants.LOG_DIR_EXPANSION_VAR}/stderr "
        )

        log.info("Running container id {} with command '{}'", container.getId.getContainerId, commands.last)

        ctx.setCommands(commands)
        ctx.setTokens(allTokens)

        val resources = mutable.Map[String, LocalResource](
          "executor.jar" -> executorJar,
          "amaterasu.properties" -> propFile,
          // TODO: Nadav/Eyal all of these should move to the executor resource setup
          "miniconda.sh" -> setLocalResourceFromPath(Path.mergePaths(jarPath, new Path("/dist/Miniconda2-latest-Linux-x86_64.sh"))),
          "codegen.py" -> setLocalResourceFromPath(Path.mergePaths(jarPath, new Path("/dist/codegen.py"))),
          "runtime.py" -> setLocalResourceFromPath(Path.mergePaths(jarPath, new Path("/dist/runtime.py"))),
          "spark-version-info.properties" -> setLocalResourceFromPath(Path.mergePaths(jarPath, new Path("/dist/spark-version-info.properties"))),
          "spark_intp.py" -> setLocalResourceFromPath(Path.mergePaths(jarPath, new Path("/dist/spark_intp.py"))))

        val frameworkFactory = FrameworkProvidersFactory(env, config)
        val framework = frameworkFactory.getFramework(actionData.groupId)

        // adding the framework and executor resources
        setupResources(framework.getGroupIdentifier, resources, framework.getGroupIdentifier)
        setupResources(s"${framework.getGroupIdentifier}/${actionData.typeId}", resources, s"${framework.getGroupIdentifier}-${actionData.typeId}")

        ctx.setLocalResources(resources)

        ctx.setEnvironment(Map[String, String](
          "HADOOP_CONF_DIR" -> s"${config.YARN.hadoopHomeDir}/conf/",
          "YARN_CONF_DIR" -> s"${config.YARN.hadoopHomeDir}/conf/",
          "AMA_NODE" -> sys.env("AMA_NODE"),
          "HADOOP_USER_NAME" -> UserGroupInformation.getCurrentUser.getUserName
        ))

        log.info(s"hadoop conf dir is ${config.YARN.hadoopHomeDir}/conf/")
        nmClient.startContainerAsync(container, ctx)
        actionData
      }

      containerTask onComplete {
        case Failure(t) =>
          log.error("launching container failed", t)
          askContainer(actionData)

        case Success(requestedActionData) =>
          jobManager.actionStarted(requestedActionData.id)
          containersIdsToTask.put(container.getId.getContainerId, requestedActionData)
          log.info(s"launching container succeeded: ${container.getId.getContainerId}; task: ${requestedActionData.id}")

      }
    }
  }

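  // serializes the current user's credentials for the containers, dropping
  // the AM->RM token which the executors must not see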
  private def allTokens: ByteBuffer = {
    // creating the credentials for container execution
    val credentials = UserGroupInformation.getCurrentUser.getCredentials

    // removing the AM->RM token so that containers cannot access it;
    // this has to happen before the credentials are serialized
    val iter = credentials.getAllTokens.iterator
    log.info("Executing with tokens:")
    for (token <- iter) {
      log.info(token.toString)
      if (token.getKind == AMRMTokenIdentifier.KIND_NAME) iter.remove()
    }

    val dob = new DataOutputBuffer
    credentials.writeTokenStorageToStream(dob)
    ByteBuffer.wrap(dob.getData, 0, dob.getLength)
  }

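  // adds every file found under jarPath/<resourcesPath> on HDFS to the
  // container's local resources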
  private def setupResources(frameworkPath: String, containerResources: mutable.Map[String, LocalResource], resourcesPath: String): Unit = {

    val sourcePath = Path.mergePaths(jarPath, new Path(s"/$resourcesPath"))

    if (fs.exists(sourcePath)) {

      val files = fs.listFiles(sourcePath, true)

      while (files.hasNext) {
        val res = files.next()
        val containerPath = res.getPath.toUri.getPath.replace("/apps/amaterasu/", "")
        containerResources.put(containerPath, setLocalResourceFromPath(res.getPath))
      }
    }
  }

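  // unregisters from the RM with the given final status and stops the clients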
  def stopApplication(finalApplicationStatus: FinalApplicationStatus, appMessage: String): Unit = {
    try {
      rmClient.unregisterApplicationMaster(finalApplicationStatus, appMessage, null)
    } catch {
      case ex: YarnException =>
        log.error("Failed to unregister application", ex)
      case e: IOException =>
        log.error("Failed to unregister application", e)
    }
    rmClient.stop()
    nmClient.stop()
  }

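  // marks each completed container's action as complete or failed and, once
  // no actions remain, unregisters the application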
  override def onContainersCompleted(statuses: util.List[ContainerStatus]): Unit = {

    for (status <- statuses.asScala) {

      if (status.getState == ContainerState.COMPLETE) {

        val containerId = status.getContainerId.getContainerId
        val task = containersIdsToTask(containerId)
        rmClient.releaseAssignedContainer(status.getContainerId)

        if (status.getExitStatus == 0) {

          completedContainersAndTaskIds.put(containerId, task.id)
          jobManager.actionComplete(task.id)
          log.info(s"Container $containerId completed task ${task.id} successfully.")
        } else {
          // TODO: Check the getDiagnostics value and see if appropriate
          jobManager.actionFailed(task.id, status.getDiagnostics)
          log.warn(s"Container $containerId running task ${task.id} failed with exit status ${status.getExitStatus}")
        }
      }
    }

    if (jobManager.outOfActions) {
      log.info("Finished all tasks successfully! Wow!")
      jobManager.actionsCount()
      stopApplication(FinalApplicationStatus.SUCCEEDED, "SUCCESS")
    } else {
      log.info(s"jobManager.registeredActions.size: ${jobManager.registeredActions.size}; completedContainersAndTaskIds.size: ${completedContainersAndTaskIds.size}")
    }
  }

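  // reported to the RM: the fraction of registered actions whose containers
  // have completed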
  override def getProgress: Float = {
    completedContainersAndTaskIds.size.toFloat / jobManager.registeredActions.size
  }

  override def onNodesUpdated(updatedNodes: util.List[NodeReport]): Unit = {
    log.info("Nodes changed. Nothing to report.")
  }

  override def onShutdownRequest(): Unit = {
    log.error("Shutdown requested.")
    stopApplication(FinalApplicationStatus.KILLED, "Shutdown requested")
  }

  override def onError(e: Throwable): Unit = {
    log.error("Error on AM", e)
    stopApplication(FinalApplicationStatus.FAILED, "Error on AM")
  }

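  // connects to ZooKeeper and either resumes an existing job (when a jobId
  // was passed) or creates a new JobManager from the job's repository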
  def initJob(args: Args): Unit = {

    this.env = args.env
    this.branch = args.branch
    try {
      val retryPolicy = new ExponentialBackoffRetry(1000, 3)
      client = CuratorFrameworkFactory.newClient(config.zk, retryPolicy)
      client.start()
    } catch {
      case e: Exception =>
        log.error("Error connecting to zookeeper", e)
        throw e
    }

    if (args.jobId != null && !args.jobId.isEmpty) {
      log.info(s"resuming job ${args.jobId}")
      jobManager = JobLoader.reloadJob(
        args.jobId,
        client,
        config.Jobs.Tasks.attempts,
        new LinkedBlockingQueue[ActionData])
    } else {
      log.info("new job is being created")
      try {
        jobManager = JobLoader.loadJob(
          args.repo,
          args.branch,
          args.newJobId,
          client,
          config.Jobs.Tasks.attempts,
          new LinkedBlockingQueue[ActionData])
      } catch {
        case e: Exception =>
          log.error("Error creating JobManager.", e)
          throw e
      }
    }

    jobManager.start()
    log.info("started jobManager")
  }
}

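// entry point: parses the CLI arguments, starts the embedded ActiveMQ broker
// on a free port and hands control to the ApplicationMaster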
object ApplicationMaster extends App with Logging {

  val parser = Args.getParser
  parser.parse(args, Args()) match {

    case Some(arguments: Args) =>
      val appMaster = new ApplicationMaster()

      appMaster.address = s"tcp://${appMaster.host}:$generatePort"
      appMaster.broker.addConnector(appMaster.address)
      appMaster.broker.start()

      log.info(s"broker started with address ${appMaster.address}")
      appMaster.execute(arguments)

    case None =>
  }

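  // binds an ephemeral port and releases it immediately; note the port could
  // in principle be taken by another process before the broker binds it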
  private def generatePort: Int = {
    val socket = new ServerSocket(0)
    val port = socket.getLocalPort
    socket.close()
    port
  }
}