1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 from pyspark import SparkContext
19 from pyspark.mllib._common import \
20 _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
21 _serialize_double_matrix, _deserialize_double_matrix, \
22 _serialize_double_vector, _deserialize_double_vector, \
23 _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
24 _serialize_tuple, RatingDeserializer
25 from pyspark.rdd import RDD
29
class MatrixFactorizationModel(object):
    """A matrix factorisation model trained by regularized alternating
    least-squares.

    Python-side handle around a JVM model object: every prediction is
    delegated to the wrapped Java model.

    >>> r1 = (1, 1, 1.0)
    >>> r2 = (1, 2, 2.0)
    >>> r3 = (2, 1, 2.0)
    >>> ratings = sc.parallelize([r1, r2, r3])
    >>> model = ALS.trainImplicit(ratings, 1)
    >>> model.predict(2,2) is not None
    True
    >>> testset = sc.parallelize([(1, 2), (1, 1)])
    >>> model.predictAll(testset).count() == 2
    True
    """

    def __init__(self, sc, java_model):
        """Remember the SparkContext and the Java model being wrapped."""
        self._context = sc
        self._java_model = java_model

    # NOTE(review): this method is assumed to be the finalizer (__del__);
    # confirm the name against the upstream file.
    def __del__(self):
        """Detach the wrapped Java model from the context's gateway."""
        self._context._gateway.detach(self._java_model)

    def predict(self, user, product):
        """Return the Java model's prediction for one (user, product) pair."""
        return self._java_model.predict(user, product)

    def predictAll(self, usersProducts):
        """Return an RDD of predictions for every entry of usersProducts.

        usersProducts is an RDD whose elements are serialized with
        _serialize_tuple (the doctest above uses (user, product) pairs).
        The JVM result is wrapped in a Python RDD that decodes its
        elements with RatingDeserializer.
        """
        usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
        return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
                   self._context, RatingDeserializer())
59
60
class ALS(object):
    """Entry points for training ALS matrix factorization models.

    Both trainers serialize the ratings RDD with _serialize_rating, hand it
    to the JVM-side PythonMLLibAPI, and wrap the Java model that comes back
    in a MatrixFactorizationModel.
    """

    @classmethod
    def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
        """Train a model through PythonMLLibAPI.trainALSModel.

        ratings is an RDD of ratings (this module's doctest uses
        3-tuples, presumably (user, product, rating)); its context is
        used to reach the JVM.  rank, iterations, lambda_ and blocks are
        forwarded to the JVM trainer unchanged.

        Returns a MatrixFactorizationModel wrapping the trained Java model.
        """
        sc = ratings.context
        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
        mod = sc._jvm.PythonMLLibAPI().trainALSModel(
            ratingBytes._jrdd, rank, iterations, lambda_, blocks)
        return MatrixFactorizationModel(sc, mod)

    @classmethod
    def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01,
                      blocks=-1, alpha=0.01):
        """Train a model through PythonMLLibAPI.trainImplicitALSModel.

        Takes the same arguments as train() plus alpha, which is likewise
        forwarded to the JVM trainer unchanged.

        Returns a MatrixFactorizationModel wrapping the trained Java model.
        """
        sc = ratings.context
        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
        mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(
            ratingBytes._jrdd, rank, iterations, lambda_, blocks, alpha)
        return MatrixFactorizationModel(sc, mod)
78
def _test():
    """Run this module's doctests against a local SparkContext.

    A copy of the module globals, extended with a context bound to the
    name ``sc`` that the doctests reference, is handed to doctest.testmod.
    Exits the process with status -1 if any doctest fails.
    """
    import doctest
    import sys

    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    try:
        failure_count, _ = doctest.testmod(
            globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop the context even when testmod itself raises, so a crashed
        # run does not leave it running.
        globs['sc'].stop()
    if failure_count:
        # sys.exit rather than the site-provided exit(), which is meant
        # for interactive use and may be absent (e.g. under python -S).
        sys.exit(-1)


if __name__ == "__main__":
    _test()
92