1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 from pyspark import SparkContext
19 from pyspark.mllib._common import \
20 _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
21 _serialize_double_matrix, _deserialize_double_matrix, \
22 _serialize_double_vector, _deserialize_double_vector, \
23 _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
24 _serialize_tuple, RatingDeserializer
25 from pyspark.rdd import RDD
29 """A matrix factorisation model trained by regularized alternating
30 least-squares.
31
32 >>> r1 = (1, 1, 1.0)
33 >>> r2 = (1, 2, 2.0)
34 >>> r3 = (2, 1, 2.0)
35 >>> ratings = sc.parallelize([r1, r2, r3])
36 >>> model = ALS.trainImplicit(ratings, 1)
37 >>> model.predict(2,2) is not None
38 True
39 >>> testset = sc.parallelize([(1, 2), (1, 1)])
40 >>> model.predictAll(testset).count() == 2
41 True
42 """
43
45 self._context = sc
46 self._java_model = java_model
47
49 self._context._gateway.detach(self._java_model)
50
52 return self._java_model.predict(user, product)
53
55 usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
56 return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
57 self._context, RatingDeserializer())
58
59
class ALS(object):
    """Trainers that build a :class:`MatrixFactorizationModel` with
    regularized alternating least squares (ALS).

    Both trainers serialize the ``ratings`` RDD with ``_serialize_rating``,
    hand the resulting Java RDD to the JVM-side ``PythonMLLibAPI`` and wrap
    the Java model it returns in a :class:`MatrixFactorizationModel`.
    """

    @staticmethod
    def _prepare(ratings):
        """Return ``(sc, api, jrdd)`` for training on ``ratings``.

        ``sc`` is ``ratings.context``, ``api`` is a new JVM
        ``PythonMLLibAPI`` handle and ``jrdd`` is the Java RDD of the
        ratings serialized with ``_serialize_rating``.  The API handle is
        created before ``_jrdd`` is read, matching the order in which the
        trainers previously evaluated these expressions.
        """
        sc = ratings.context
        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
        return sc, sc._jvm.PythonMLLibAPI(), ratingBytes._jrdd

    @classmethod
    def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
        """Train a model via ``PythonMLLibAPI.trainALSModel``.

        :param ratings: RDD of ratings; each element is serialized with
            ``_serialize_rating`` (the model doctest uses
            ``(user, product, rating)`` tuples).
        :param rank: rank of the factorization, forwarded to the JVM.
        :param iterations: iteration count forwarded to the JVM (default 5).
        :param lambda_: regularization parameter (default 0.01); the
            trailing underscore avoids the ``lambda`` keyword.
        :param blocks: forwarded unchanged to the JVM trainer (default -1;
            presumably "let Spark pick the block count" -- confirm against
            the MLlib ALS docs).
        :return: a :class:`MatrixFactorizationModel` wrapping the Java model.
        """
        sc, api, jrdd = cls._prepare(ratings)
        mod = api.trainALSModel(jrdd, rank, iterations, lambda_, blocks)
        return MatrixFactorizationModel(sc, mod)

    @classmethod
    def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01,
                      blocks=-1, alpha=0.01):
        """Train a model via ``PythonMLLibAPI.trainImplicitALSModel``.

        Same parameters as :meth:`train`, plus:

        :param alpha: forwarded unchanged to the JVM trainer (default 0.01;
            presumably the implicit-feedback confidence weight -- confirm
            against the MLlib ALS docs).
        :return: a :class:`MatrixFactorizationModel` wrapping the Java model.
        """
        sc, api, jrdd = cls._prepare(ratings)
        mod = api.trainImplicitALSModel(
            jrdd, rank, iterations, lambda_, blocks, alpha)
        return MatrixFactorizationModel(sc, mod)
76
79 import doctest
80 globs = globals().copy()
81 globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
82 (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
83 globs['sc'].stop()
84 if failure_count:
85 exit(-1)
86
87
if __name__ == "__main__":
    # Only when executed as a script (not on import): invoke the module's
    # _test() routine.
    _test()
90