[SPARK-25921][FOLLOW UP][PYSPARK] Fix barrier task run without BarrierTaskContext...
author     Yuanjian Li <xyliyuanjian@gmail.com>
           Fri, 11 Jan 2019 06:28:37 +0000 (14:28 +0800)
committer  Hyukjin Kwon <gurwls223@apache.org>
           Fri, 11 Jan 2019 06:28:37 +0000 (14:28 +0800)
## What changes were proposed in this pull request?

This is the follow-up PR for #22962; it contains the following changes (see the sketch after this list):
- Remove `__init__` in TaskContext and BarrierTaskContext.
- Add more comments to explain the fix.
- Rewrite the UT in a new class.
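
For context, a minimal self-contained sketch (simplified names; not the actual `pyspark/taskcontext.py` source) of the singleton pattern the fix leans on once `__init__` is removed: `_getOrCreate` alone manages the cached context, and the `isinstance` check in the barrier subclass lets a reused worker replace a previously cached plain `TaskContext` with a `BarrierTaskContext`:

```python
class TaskContext(object):
    _taskContext = None  # per-worker cached singleton

    @classmethod
    def _getOrCreate(cls):
        """Get or create the global TaskContext."""
        if cls._taskContext is None:
            cls._taskContext = object.__new__(cls)
        return cls._taskContext


class BarrierTaskContext(TaskContext):

    @classmethod
    def _getOrCreate(cls):
        """Get or create the global BarrierTaskContext.

        A reused worker may still cache a plain TaskContext from an earlier
        non-barrier job; the isinstance check replaces it, so barrier tasks
        always see a BarrierTaskContext (SPARK-25921).
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext


# simulate worker reuse: a normal task caches a plain TaskContext first...
assert type(TaskContext._getOrCreate()) is TaskContext
# ...then a barrier task on the same (reused) worker gets the right class
assert type(BarrierTaskContext._getOrCreate()) is BarrierTaskContext
```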

## How was this patch tested?

New UT in test_taskcontext.py

Closes #23435 from xuanyuanking/SPARK-25921-follow.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
python/pyspark/taskcontext.py
python/pyspark/tests/test_taskcontext.py

diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 98b505c..de4b6af 100644
@@ -48,10 +48,6 @@ class TaskContext(object):
         cls._taskContext = taskContext = object.__new__(cls)
         return taskContext
 
-    def __init__(self):
-        """Construct a TaskContext, use get instead"""
-        pass
-
     @classmethod
     def _getOrCreate(cls):
         """Internal function to get or create global TaskContext."""
@@ -140,13 +136,13 @@ class BarrierTaskContext(TaskContext):
     _port = None
     _secret = None
 
-    def __init__(self):
-        """Construct a BarrierTaskContext, use get instead"""
-        pass
-
     @classmethod
     def _getOrCreate(cls):
-        """Internal function to get or create global BarrierTaskContext."""
+        """
+        Internal function to get or create global BarrierTaskContext. We need to make sure
+        BarrierTaskContext is returned from here because it is needed in python worker reuse
+        scenario, see SPARK-25921 for more details.
+        """
         if not isinstance(cls._taskContext, BarrierTaskContext):
             cls._taskContext = object.__new__(cls)
         return cls._taskContext
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index b3a9674..fdb5c40 100644
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import random
 import sys
 import time
+import unittest
 
-from pyspark import SparkContext, TaskContext, BarrierTaskContext
+from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
 from pyspark.testing.utils import PySparkTestCase
 
 
@@ -118,21 +120,6 @@ class TaskContextTests(PySparkTestCase):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
-    def test_barrier_with_python_worker_reuse(self):
-        """
-        Verify that BarrierTaskContext.barrier() with reused python worker.
-        """
-        self.sc._conf.set("spark.python.work.reuse", "true")
-        rdd = self.sc.parallelize(range(4), 4)
-        # start a normal job first to start all worker
-        result = rdd.map(lambda x: x ** 2).collect()
-        self.assertEqual([0, 1, 4, 9], result)
-        # make sure `spark.python.work.reuse=true`
-        self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true")
-
-        # worker will be reused in this barrier job
-        self.test_barrier()
-
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
@@ -149,6 +136,44 @@ class TaskContextTests(PySparkTestCase):
         self.assertTrue(len(taskInfos[0]) == 4)
 
 
+class TaskContextTestsWithWorkerReuse(unittest.TestCase):
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.python.worker.reuse", "true")
+        self.sc = SparkContext('local[2]', class_name, conf=conf)
+
+    def test_barrier_with_python_worker_reuse(self):
+        """
+        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier()
+        works with reused python workers.
+        """
+        # run a normal job first to launch all the workers and collect their pids
+        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
+        # the workers will be reused in this barrier job
+        rdd = self.sc.parallelize(range(10), 2)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            tc.barrier()
+            return (time.time(), os.getpid())
+
+        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+        times = list(map(lambda x: x[0], result))
+        pids = list(map(lambda x: x[1], result))
+        # check both barrier and worker reuse effect
+        self.assertTrue(max(times) - min(times) < 1)
+        for pid in pids:
+            self.assertTrue(pid in worker_pids)
+
+    def tearDown(self):
+        self.sc.stop()
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_taskcontext import *