from py4j.java_collections import MapConverter

from pyspark import SparkContext, RDD
from pyspark.mllib._common import \
    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \
    _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \
    _deserialize_double
from pyspark.mllib.regression import LabeledPoint
from pyspark.serializers import NoOpSerializer


class DecisionTreeModel(object):
    """
    A decision tree model for classification or regression.

    EXPERIMENTAL: This is an experimental API.
                  It will probably be modified for Spark v1.2.
    """

    def __init__(self, sc, java_model):
        """
        :param sc: Spark context
        :param java_model: Handle to Java model object
        """
        self._sc = sc
        self._java_model = java_model

    def __del__(self):
        self._sc._gateway.detach(self._java_model)

    def predict(self, x):
50 """
51 Predict the label of one or more examples.
52 :param x: Data point (feature vector),
53 or an RDD of data points (feature vectors).
54 """
55 pythonAPI = self._sc._jvm.PythonMLLibAPI()
56 if isinstance(x, RDD):
57
58 if x.count() == 0:
59 return self._sc.parallelize([])
60 dataBytes = _get_unmangled_double_vector_rdd(x, cache=False)
61 jSerializedPreds = \
62 pythonAPI.predictDecisionTreeModel(self._java_model,
63 dataBytes._jrdd)
64 serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer())
65 return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes)))
66 else:
67
68 x_ = _serialize_double_vector(x)
69 return pythonAPI.predictDecisionTreeModel(self._java_model, x_)

    def numNodes(self):
        return self._java_model.numNodes()

    def depth(self):
        return self._java_model.depth()

    def __str__(self):
        return self._java_model.toString()
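

# Hedged usage sketch (illustrative only, not part of the original API):
# DecisionTreeModel.predict accepts either a single feature vector or an RDD
# of feature vectors. `sc` is assumed to be a live SparkContext and `model`
# a trained DecisionTreeModel.
def _example_predict_usage(sc, model):
    from numpy import array
    single = model.predict(array([1.0]))    # one data point -> a float label
    bulk = model.predict(sc.parallelize([array([0.0]), array([1.0])]))
    return single, bulk.collect()           # bulk prediction returns an RDD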


class DecisionTree(object):
    """
    Learning algorithm for a decision tree model
    for classification or regression.

    EXPERIMENTAL: This is an experimental API.
                  It will probably be modified for Spark v1.2.

    Example usage:
    >>> from numpy import array
    >>> import sys
    >>> from pyspark.mllib.regression import LabeledPoint
    >>> from pyspark.mllib.tree import DecisionTree
    >>> from pyspark.mllib.linalg import SparseVector
    >>>
    >>> data = [
    ...     LabeledPoint(0.0, [0.0]),
    ...     LabeledPoint(1.0, [1.0]),
    ...     LabeledPoint(1.0, [2.0]),
    ...     LabeledPoint(1.0, [3.0])
    ... ]
    >>> categoricalFeaturesInfo = {}  # no categorical features
    >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
    ...                                      categoricalFeaturesInfo=categoricalFeaturesInfo)
    >>> sys.stdout.write(str(model))
    DecisionTreeModel classifier
      If (feature 0 <= 0.5)
       Predict: 0.0
      Else (feature 0 > 0.5)
       Predict: 1.0
    >>> model.predict(array([1.0])) > 0
    True
    >>> model.predict(array([0.0])) == 0
    True
    >>> sparse_data = [
    ...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
    ...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
    ...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
    ...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
    ... ]
    >>>
    >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
    ...                                     categoricalFeaturesInfo=categoricalFeaturesInfo)
    >>> model.predict(array([0.0, 1.0])) == 1
    True
    >>> model.predict(array([0.0, 0.0])) == 0
    True
    >>> model.predict(SparseVector(2, {1: 1.0})) == 1
    True
    >>> model.predict(SparseVector(2, {1: 0.0})) == 0
    True
    """

    @staticmethod
    def trainClassifier(data, numClasses, categoricalFeaturesInfo,
                        impurity="gini", maxDepth=4, maxBins=100):
        """
        Train a DecisionTreeModel for classification.

        :param data: Training data: RDD of LabeledPoint.
                     Labels are integers {0,1,...,numClasses-1}.
        :param numClasses: Number of classes for classification.
        :param categoricalFeaturesInfo: Map from categorical feature index
                                        to number of categories.
                                        Any feature not in this map
                                        is treated as continuous.
        :param impurity: Supported values: "entropy" or "gini"
        :param maxDepth: Max depth of tree.
                         E.g., depth 0 means 1 leaf node.
                         Depth 1 means 1 internal node + 2 leaf nodes.
        :param maxBins: Number of bins used for finding splits at each node.
        :return: DecisionTreeModel
        """
        sc = data.context
        dataBytes = _get_unmangled_labeled_point_rdd(data)
        # Convert the Python dict to a Java map for the JVM API.
        categoricalFeaturesInfoJMap = \
            MapConverter().convert(categoricalFeaturesInfo,
                                   sc._gateway._gateway_client)
        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
            dataBytes._jrdd, "classification",
            numClasses, categoricalFeaturesInfoJMap,
            impurity, maxDepth, maxBins)
        dataBytes.unpersist()
        return DecisionTreeModel(sc, model)
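
    # Hedged sketch (hypothetical helper, not part of the original API):
    # trainClassifier with one categorical feature. Feature 0 is declared
    # categorical with two category values via categoricalFeaturesInfo;
    # feature 1 is absent from the map, so it is treated as continuous.
    @staticmethod
    def _exampleTrainClassifier(sc):
        data = sc.parallelize([
            LabeledPoint(0.0, [0.0, 1.5]),
            LabeledPoint(1.0, [1.0, 0.5]),
            LabeledPoint(0.0, [0.0, 2.5]),
            LabeledPoint(1.0, [1.0, 0.0])
        ])
        return DecisionTree.trainClassifier(data, numClasses=2,
                                            categoricalFeaturesInfo={0: 2},
                                            impurity="entropy", maxDepth=3)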

    @staticmethod
    def trainRegressor(data, categoricalFeaturesInfo,
                       impurity="variance", maxDepth=4, maxBins=100):
        """
        Train a DecisionTreeModel for regression.

        :param data: Training data: RDD of LabeledPoint.
                     Labels are real numbers.
        :param categoricalFeaturesInfo: Map from categorical feature index
                                        to number of categories.
                                        Any feature not in this map
                                        is treated as continuous.
        :param impurity: Supported values: "variance"
        :param maxDepth: Max depth of tree.
                         E.g., depth 0 means 1 leaf node.
                         Depth 1 means 1 internal node + 2 leaf nodes.
        :param maxBins: Number of bins used for finding splits at each node.
        :return: DecisionTreeModel
        """
        sc = data.context
        dataBytes = _get_unmangled_labeled_point_rdd(data)
        # Convert the Python dict to a Java map for the JVM API.
        categoricalFeaturesInfoJMap = \
            MapConverter().convert(categoricalFeaturesInfo,
                                   sc._gateway._gateway_client)
        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
            dataBytes._jrdd, "regression",
            0, categoricalFeaturesInfoJMap,  # numClasses is unused for regression
            impurity, maxDepth, maxBins)
        dataBytes.unpersist()
        return DecisionTreeModel(sc, model)
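
    # Hedged sketch (hypothetical helper, not part of the original API):
    # trainRegressor on purely continuous features; "variance" is the only
    # impurity measure supported for regression, so the default is kept.
    @staticmethod
    def _exampleTrainRegressor(sc):
        data = sc.parallelize([
            LabeledPoint(0.0, [0.0]),
            LabeledPoint(0.5, [1.0]),
            LabeledPoint(1.0, [2.0]),
            LabeledPoint(1.5, [3.0])
        ])
        return DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
                                           maxDepth=3)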


def _test():
    import doctest
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()