#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
from collections import namedtuple

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
    PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeToFile = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
                 gateway=None):
        """
        Create a new SparkContext. At least the master and app name should be set,
        either through the named parameters here or through C{conf}.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param appName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH. These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object. Set 1 to disable batching or -1 to use an
               unlimited batch size.
        @param serializer: The serializer for RDDs.
        @param conf: A L{SparkConf} object setting Spark properties.
        @param gateway: Use an existing gateway and JVM; otherwise a new JVM
               will be instantiated.

        >>> from pyspark.context import SparkContext
        >>> sc = SparkContext('local', 'test')

        >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...
        """
        if rdd._extract_concise_traceback() is not None:
            self._callsite = rdd._extract_concise_traceback()
        else:
            tempNamedTuple = namedtuple("Callsite", "function file linenum")
            self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
        SparkContext._ensure_initialized(self, gateway=gateway)

        self.environment = environment or {}
        self._conf = conf or SparkConf(_jvm=self._jvm)
        self._batchSize = batchSize  # -1 represents an unlimited batch size
        self._unbatched_serializer = serializer
        if batchSize == 1:
            self.serializer = self._unbatched_serializer
        else:
            self.serializer = BatchedSerializer(self._unbatched_serializer,
                                                batchSize)

        # Set any parameters passed directly to us on the conf
        if master:
            self._conf.setMaster(master)
        if appName:
            self._conf.setAppName(appName)
        if sparkHome:
            self._conf.setSparkHome(sparkHome)
        if environment:
            for key, value in environment.iteritems():
                self._conf.setExecutorEnv(key, value)

        # Check that we have at least the required parameters
        if not self._conf.contains("spark.master"):
            raise Exception("A master URL must be set in your configuration")
        if not self._conf.contains("spark.app.name"):
            raise Exception("An application name must be set in your configuration")

        # Read back our properties from the conf in case we loaded some of them
        # from the classpath or an external config file
        self.master = self._conf.get("spark.master")
        self.appName = self._conf.get("spark.app.name")
        self.sparkHome = self._conf.get("spark.home", None)
        for (k, v) in self._conf.getAll():
            if k.startswith("spark.executorEnv."):
                varName = k[len("spark.executorEnv."):]
                self.environment[varName] = v

        # Create the Java SparkContext through Py4J
        self._jsc = self._initialize_context(self._conf._jconf)

        # Create a single Accumulator in Java that we'll send all our updates
        # through; they will be passed back to us through a TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')

        # Broadcast's __reduce__ method stores Broadcast instances in this list.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies specified in the constructor
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Deploy code dependencies set by spark-submit; these will already have
        # been added with SparkContext.addFile, so we just need to add them to
        # the PYTHONPATH
        for path in self._conf.get("spark.submit.pyFiles", "").split(","):
            if path != "":
                (dirname, filename) = os.path.split(path)
                self._python_includes.append(filename)
                sys.path.append(path)
                if dirname not in sys.path:
                    sys.path.append(dirname)

        # Create a temporary directory inside spark.local.dir:
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
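
    # Usage sketch (illustrative comment only, not executed): rather than passing
    # the master and app name positionally, a context is commonly configured
    # through a SparkConf; the app name "my_app" below is just an example.
    #
    #     from pyspark import SparkConf, SparkContext
    #     conf = SparkConf().setMaster("local[2]").setAppName("my_app")
    #     sc = SparkContext(conf=conf)
    #     try:
    #         print sc.parallelize([1, 2, 3]).count()
    #     finally:
    #         sc.stop()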

    def _initialize_context(self, jconf):
        """
        Initialize SparkContext in function to allow subclass specific initialization
        """
        return self._jvm.JavaSparkContext(jconf)

    @classmethod
    def _ensure_initialized(cls, instance=None, gateway=None):
        """
        Checks whether a SparkContext is initialized or not.
        Throws error if a SparkContext is already running.
        """
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = gateway or launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

            if instance:
                if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
                    currentMaster = SparkContext._active_spark_context.master
                    currentAppName = SparkContext._active_spark_context.appName
                    callsite = SparkContext._active_spark_context._callsite

                    # Raise error if there is already a running Spark context
                    raise ValueError(
                        "Cannot run multiple SparkContexts at once; existing "
                        "SparkContext(app=%s, master=%s) created by %s at %s:%s "
                        % (currentAppName, currentMaster, callsite.function,
                           callsite.file, callsite.linenum))
                else:
                    SparkContext._active_spark_context = instance

    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a Java system property, such as spark.executor.memory. This must
        be invoked before instantiating SparkContext.
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)
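
    # Usage sketch (illustrative comment only, not executed): system properties
    # must be set before the context exists; the value "2g" is just an example.
    #
    #     SparkContext.setSystemProperty("spark.executor.memory", "2g")
    #     sc = SparkContext("local", "test")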

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    @property
    def defaultMinPartitions(self):
        """
        Default min number of partitions for Hadoop RDDs when not given by user
        """
        return self._jsc.sc().defaultMinPartitions()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands. As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(c):
            c = list(c)    # make it a list so we can compute its length
        batchSize = min(len(c) // numSlices, self._batchSize)
        if batchSize > 1:
            serializer = BatchedSerializer(self._unbatched_serializer,
                                           batchSize)
        else:
            serializer = self._unbatched_serializer
        serializer.dump_stream(c, tempFile)
        tempFile.close()
        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self, serializer)

    def textFile(self, name, minPartitions=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.

        >>> path = os.path.join(tempdir, "sample-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello world!")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello world!']
        """
        minPartitions = minPartitions or min(self.defaultParallelism, 2)
        return RDD(self._jsc.textFile(name, minPartitions), self,
                   UTF8Deserializer())

    def wholeTextFiles(self, path, minPartitions=None):
        """
        Read a directory of text files from HDFS, a local file system
        (available on all nodes), or any Hadoop-supported file system
        URI. Each file is read as a single record and returned in a
        key-value pair, where the key is the path of each file, the
        value is the content of each file.

        For example, if you have the following files::

          hdfs://a-hdfs-path/part-00000
          hdfs://a-hdfs-path/part-00001
          ...
          hdfs://a-hdfs-path/part-nnnnn

        Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
        then C{rdd} contains::

          (a-hdfs-path/part-00000, its content)
          (a-hdfs-path/part-00001, its content)
          ...
          (a-hdfs-path/part-nnnnn, its content)

        NOTE: Small files are preferred, as each file will be loaded
        fully in memory.

        >>> dirPath = os.path.join(tempdir, "files")
        >>> os.mkdir(dirPath)
        >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
        ...    file1.write("1")
        >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
        ...    file2.write("2")
        >>> textFiles = sc.wholeTextFiles(dirPath)
        >>> sorted(textFiles.collect())
        [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
        """
        minPartitions = minPartitions or self.defaultMinPartitions
        return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
                   PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))

    def _checkpointFile(self, name, input_deserializer):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self, input_deserializer)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.

        This supports unions() of RDDs with different serialized formats,
        although this forces them to be reserialized using the default
        serializer:

        >>> path = os.path.join(tempdir, "union-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello']
        >>> parallelized = sc.parallelize(["World!"])
        >>> sorted(sc.union([textFile, parallelized]).collect())
        [u'Hello', 'World!']
        """
        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
            rdds = [x._reserialize() for x in rdds]
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self,
                   rdds[0]._jrdd_deserializer)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a
        L{Broadcast<pyspark.broadcast.Broadcast>}
        object for reading it in distributed functions. The variable will be
        sent to each node only once.
        """
        pickleSer = PickleSerializer()
        pickled = pickleSer.dumps(value)
        jbroadcast = self._jsc.broadcast(bytearray(pickled))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)
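
    # Usage sketch (illustrative comment only, not executed): a broadcast value is
    # read on workers through its .value attribute; the lookup table is an example.
    #
    #     lookup = sc.broadcast({"a": 1, "b": 2})
    #     rdd = sc.parallelize(["a", "b", "a"])
    #     counts = rdd.map(lambda k: lookup.value.get(k, 0)).collect()
    #     # counts == [1, 2, 1]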

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
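
    # Usage sketch (illustrative comment only, not executed): tasks may only add
    # to an accumulator; the driver reads the total via .value after an action.
    #
    #     acc = sc.accumulator(0)
    #     sc.parallelize([1, 2, 3, 4]).foreach(lambda x: acc.add(x))
    #     # acc.value == 10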

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * fileVal for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future. The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)

        if filename.endswith(('.zip', '.ZIP', '.egg')):
            self._python_includes.append(filename)
            # also make the archive importable locally (e.g. in local mode)
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
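
    # Usage sketch (illustrative comment only, not executed): the archive and
    # module names below are hypothetical; after addPyFile the dependency can be
    # imported inside tasks.
    #
    #     sc.addPyFile("/path/to/my_deps.zip")
    #     def use_dep(x):
    #         import my_module        # provided by my_deps.zip
    #         return my_module.transform(x)
    #     sc.parallelize([1, 2, 3]).map(use_dep).collect()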

    def setCheckpointDir(self, dirName):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.
        """
        self._jsc.sc().setCheckpointDir(dirName)
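
    # Usage sketch (illustrative comment only, not executed): the directory is an
    # example path; checkpoint() itself is an RDD method.
    #
    #     sc.setCheckpointDir("/tmp/spark-checkpoints")
    #     rdd = sc.parallelize(range(100)).map(lambda x: x * x)
    #     rdd.checkpoint()
    #     rdd.count()    # materializes the RDD and writes the checkpoint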

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk,
                               storageLevel.useMemory,
                               storageLevel.useOffHeap,
                               storageLevel.deserialized,
                               storageLevel.replication)
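
    # Illustrative note (a sketch, not executed): callers such as RDD.persist()
    # are expected to pass one of the pyspark.StorageLevel constants, e.g.
    #
    #     from pyspark import StorageLevel
    #     rdd = sc.parallelize(range(10))
    #     rdd.persist(StorageLevel.MEMORY_AND_DISK)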

    def setJobGroup(self, groupId, description, interruptOnCancel=False):
        """
        Assigns a group ID to all the jobs started by this thread until the group ID is set to a
        different value or cleared.

        Often, a unit of execution in an application consists of multiple Spark actions or jobs.
        Application programmers can use this method to group all those jobs together and give a
        group description. Once set, the Spark web UI will associate such jobs with this group.

        The application can use L{SparkContext.cancelJobGroup} to cancel all
        running jobs in this group.

        >>> import thread, threading
        >>> from time import sleep
        >>> result = "Not Set"
        >>> lock = threading.Lock()
        >>> def map_func(x):
        ...     sleep(100)
        ...     raise Exception("Task should have been cancelled")
        >>> def start_job(x):
        ...     global result
        ...     try:
        ...         sc.setJobGroup("job_to_cancel", "some description")
        ...         result = sc.parallelize(range(x)).map(map_func).collect()
        ...     except Exception as e:
        ...         result = "Cancelled"
        ...     lock.release()
        >>> def stop_job():
        ...     sleep(5)
        ...     sc.cancelJobGroup("job_to_cancel")
        >>> suppress = lock.acquire()
        >>> suppress = thread.start_new_thread(start_job, (10,))
        >>> suppress = thread.start_new_thread(stop_job, tuple())
        >>> suppress = lock.acquire()
        >>> print result
        Cancelled

        If interruptOnCancel is set to true for the job group, then job cancellation will result
        in Thread.interrupt() being called on the job's executor threads. This is useful to help
        ensure that the tasks are actually stopped in a timely manner, but is off by default due
        to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead.
        """
        self._jsc.setJobGroup(groupId, description, interruptOnCancel)

    def setLocalProperty(self, key, value):
        """
        Set a local property that affects jobs submitted from this thread, such as the
        Spark fair scheduler pool.
        """
        self._jsc.setLocalProperty(key, value)
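
    # Usage sketch (illustrative comment only, not executed): "spark.scheduler.pool"
    # is the property consulted by the fair scheduler; the pool name is an example.
    #
    #     sc.setLocalProperty("spark.scheduler.pool", "production")
    #     sc.parallelize(range(10)).count()    # runs in the "production" pool
    #     sc.setLocalProperty("spark.scheduler.pool", None)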

    def getLocalProperty(self, key):
        """
        Get a local property set in this thread, or None if it is missing. See
        L{setLocalProperty}.
        """
        return self._jsc.getLocalProperty(key)

    def sparkUser(self):
        """
        Get SPARK_USER for user who is running SparkContext.
        """
        return self._jsc.sc().sparkUser()

    def cancelJobGroup(self, groupId):
        """
        Cancel active jobs for the specified group. See L{SparkContext.setJobGroup}
        for more information.
        """
        self._jsc.sc().cancelJobGroup(groupId)

    def cancelAllJobs(self):
        """
        Cancel all jobs that have been scheduled or are running.
        """
        self._jsc.sc().cancelAllJobs()

def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()