#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

18 """
19 >>> from pyspark.context import SparkContext
20 >>> sc = SparkContext('local', 'test')
21 >>> a = sc.accumulator(1)
22 >>> a.value
23 1
24 >>> a.value = 2
25 >>> a.value
26 2
27 >>> a += 5
28 >>> a.value
29 7
30
31 >>> sc.accumulator(1.0).value
32 1.0
33
34 >>> sc.accumulator(1j).value
35 1j
36
37 >>> rdd = sc.parallelize([1,2,3])
38 >>> def f(x):
39 ... global a
40 ... a += x
41 >>> rdd.foreach(f)
42 >>> a.value
43 13
44
45 >>> b = sc.accumulator(0)
46 >>> def g(x):
47 ... b.add(x)
48 >>> rdd.foreach(g)
49 >>> b.value
50 6
51
52 >>> from pyspark.accumulators import AccumulatorParam
53 >>> class VectorAccumulatorParam(AccumulatorParam):
54 ... def zero(self, value):
55 ... return [0.0] * len(value)
56 ... def addInPlace(self, val1, val2):
57 ... for i in xrange(len(val1)):
58 ... val1[i] += val2[i]
59 ... return val1
60 >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
61 >>> va.value
62 [1.0, 2.0, 3.0]
63 >>> def g(x):
64 ... global va
65 ... va += [x] * 3
66 >>> rdd.foreach(g)
67 >>> va.value
68 [7.0, 8.0, 9.0]
69
70 >>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
71 Traceback (most recent call last):
72 ...
73 Py4JJavaError:...
74
75 >>> def h(x):
76 ... global a
77 ... a.value = 7
78 >>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
79 Traceback (most recent call last):
80 ...
81 Py4JJavaError:...
82
83 >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
84 Traceback (most recent call last):
85 ...
86 Exception:...
87 """

import select
import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, PickleSerializer


pickleSer = PickleSerializer()


# Holds accumulators registered on the current machine, keyed by ID
_accumulatorRegistry = {}


def _deserialize_accumulator(aid, zero_value, accum_param):
    from pyspark.accumulators import _accumulatorRegistry
    accum = Accumulator(aid, zero_value, accum_param)
    accum._deserialized = True
    _accumulatorRegistry[aid] = accum
    return accum


class Accumulator(object):

114 """
115 A shared variable that can be accumulated, i.e., has a commutative and associative "add"
116 operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
117 operator, but only the driver program is allowed to access its value, using C{value}.
118 Updates from the workers get propagated automatically to the driver program.
119
120 While C{SparkContext} supports accumulators for primitive data types like C{int} and
121 C{float}, users can also define accumulators for custom types by providing a custom
122 L{AccumulatorParam} object. Refer to the doctest of this module for an example.
123 """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Sets the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        self._value = value

    def add(self, term):
        """Adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)

158 """The += operator; adds a term to this accumulator's value"""
159 self.add(term)
160 return self
161
    def __str__(self):
        return str(self._value)

166 return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
167
170
171 """
172 Helper object that defines how to accumulate values of a given type.
173 """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided C{value} (e.g., a zero vector)
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update C{value1} in place and return it.
        """
        raise NotImplementedError

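
# Illustrative sketch, not part of the original module: a hypothetical custom
# AccumulatorParam that merges dicts by summing values per key. It follows the
# same zero()/addInPlace() contract as VectorAccumulatorParam in the module
# doctest; usage would be e.g. sc.accumulator({}, _ExampleDictSumParam()).
class _ExampleDictSumParam(AccumulatorParam):

    def zero(self, value):
        # An empty dict is the identity for per-key sums
        return {}

    def addInPlace(self, value1, value2):
        # Merge value2 into value1 in place and return it
        for k, v in value2.iteritems():
            value1[k] = value1.get(k, 0) + v
        return value1
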

class AddingAccumulatorParam(AccumulatorParam):

    """
    An AccumulatorParam that uses the + operator to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        self.zero_value = zero_value

    def zero(self, value):
        return self.zero_value

    def addInPlace(self, value1, value2):
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
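
# A usage sketch, assuming (as the module doctest suggests) that
# SparkContext.accumulator picks one of the singletons above from the type of
# the initial value when no AccumulatorParam is passed explicitly:
#
#     sc.accumulator(1)    # int     -> INT_ACCUMULATOR_PARAM
#     sc.accumulator(1.0)  # float   -> FLOAT_ACCUMULATOR_PARAM
#     sc.accumulator(1j)   # complex -> COMPLEX_ACCUMULATOR_PARAM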


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

    """
    This handler will keep polling updates from the same socket until the
    server is shut down.
    """

    def handle(self):
        while not self.server.server_shutdown:
            # Poll every 1 second for new data -- don't block in case of shutdown.
            r, _, _ = select.select([self.rfile], [], [], 1)
            if self.rfile in r:
                # Read the number of updates, then each pickled (aid, update)
                # pair, merging every update into the matching local accumulator
                num_updates = read_int(self.rfile)
                for _ in range(num_updates):
                    (aid, update) = pickleSer._read_with_length(self.rfile)
                    _accumulatorRegistry[aid] += update
                # Write a byte in acknowledgement
                self.wfile.write(struct.pack("!b", 1))


class AccumulatorServer(SocketServer.TCPServer):

    """
    A simple TCP server that intercepts shutdown() in order to interrupt
    our continuous polling on the handler.
    """
    server_shutdown = False

    def shutdown(self):
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)


def _start_update_server():
    """Start a TCP server to receive accumulator updates in a daemon thread, and return it"""
    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    return server

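
# A minimal driver-side sketch of the update protocol implied by
# _UpdateRequestHandler.handle() above. This is an illustration only: the real
# client lives on the JVM side, and _send_updates_sketch is a hypothetical
# helper. One exchange writes a count, then that many pickled (aid, update)
# pairs, then reads back a one-byte acknowledgement:
#
#     import socket
#     from pyspark.serializers import write_int
#
#     def _send_updates_sketch(port, updates):
#         sock = socket.create_connection(("localhost", port))
#         out = sock.makefile("wb")
#         write_int(len(updates), out)  # number of updates to follow
#         for aid, update in updates:
#             pickleSer._write_with_length((aid, update), out)
#         out.flush()
#         ack = sock.recv(1)  # handler acks with struct.pack("!b", 1)
#         assert struct.unpack("!b", ack)[0] == 1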