from pyspark.rdd import RDD

from py4j.protocol import Py4JError

__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
    """Main entry point for Spark SQL functionality.

    A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
    tables, execute SQL over tables, cache tables, and read Parquet files.
    """

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ... "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ... x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

        if sqlContext:
            self._scala_SQLContext = sqlContext

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM Spark SQL context.

        Subclasses can override this property to provide their own
        JVM contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """Infer and apply a schema to an RDD of L{dict}s.

        We peek at the first row of the RDD to determine the field names
        and types, and then use that schema to extract all the dictionaries.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ... {"field1" : 3, "field2": "row3"}]
        True
        """
        if rdd.__class__ is SchemaRDD:
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of
        SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if rdd.__class__ is SchemaRDD:
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> srdd.collect() == srdd2.collect()
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """Return a L{SchemaRDD} representing the result of the given query.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ... {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> srdd.collect() == srdd2.collect()
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in memory."""
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)
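
    # Illustrative sketch (not a doctest): caching a registered table lets
    # repeated queries reuse the in-memory data; "table1" matches the
    # doctests above.
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   sqlCtx.registerRDDAsTable(srdd, "table1")
    #   sqlCtx.cacheTable("table1")
    #   sqlCtx.sql("SELECT field1 FROM table1").collect()
    #   sqlCtx.uncacheTable("table1")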


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}."""
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.

        An alias for L{hiveql}.
        """
        return self.hiveql(hqlQuery)
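
    # Illustrative sketch (not a doctest): assuming a Hive table named "src"
    # exists (the LocalHiveContext doctest below creates one), HiveQL queries
    # return SchemaRDDs just like SQLContext.sql:
    #
    #   hiveCtx = HiveContext(sc)
    #   rows = hiveCtx.hql("SELECT key, value FROM src LIMIT 10")
    #   rows.collect()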


class LocalHiveContext(HiveContext):
    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metadata database is created with data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


class Row(dict):
    """A row in L{SchemaRDD}.

    An extended L{dict} that takes a L{dict} in its constructor, and
    exposes those items as fields.

    >>> r = Row({"hello" : "world", "foo" : "bar"})
    >>> r.hello
    'world'
    >>> r.foo
    'bar'
    """

    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)


247 """An RDD of L{Row} objects that has an associated schema.
248
249 The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
250 utilize the relational query api exposed by SparkSQL.
251
252 For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
253 L{SchemaRDD} is not operated on directly, as it's underlying
254 implementation is a RDD composed of Java objects. Instead it is
255 converted to a PythonRDD in the JVM, on which Python operations can
256 be done.
257 """

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer

    @property
    def _jrdd(self):
        """Lazy evaluation of PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        return self._jrdd.id()

285 """Save the contents as a Parquet file, preserving the schema.
286
287 Files that are written out using this method can be read back in as
288 a SchemaRDD using the L{SQLContext.parquetFile} method.
289
290 >>> import tempfile, shutil
291 >>> parquetFile = tempfile.mkdtemp()
292 >>> shutil.rmtree(parquetFile)
293 >>> srdd = sqlCtx.inferSchema(rdd)
294 >>> srdd.saveAsParquetFile(parquetFile)
295 >>> srdd2 = sqlCtx.parquetFile(parquetFile)
296 >>> srdd2.collect() == srdd.collect()
297 True
298 """
299 self._jschema_rdd.saveAsParquetFile(path)
300
302 """Registers this RDD as a temporary table using the given name.
303
304 The lifetime of this temporary table is tied to the L{SQLContext}
305 that was used to create this SchemaRDD.
306
307 >>> srdd = sqlCtx.inferSchema(rdd)
308 >>> srdd.registerAsTable("test")
309 >>> srdd2 = sqlCtx.sql("select * from test")
310 >>> srdd.collect() == srdd2.collect()
311 True
312 """
313 self._jschema_rdd.registerAsTable(name)
314
    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this SchemaRDD into the specified table,
        optionally overwriting any existing data.
        """
        self._jschema_rdd.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this SchemaRDD."""
        self._jschema_rdd.saveAsTable(tableName)
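
    # Illustrative sketch (not a doctest): both methods write to a table known
    # to this SchemaRDD's SQLContext (typically a Hive table when using a
    # HiveContext); the table name "backup" is arbitrary.
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   srdd.saveAsTable("backup")
    #   srdd.insertInto("backup", overwrite=True)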

    def count(self):
        """Return the number of elements in this RDD.

        Unlike the base RDD implementation of count, this implementation
        leverages the query optimizer to compute the count on the SchemaRDD,
        which supports features such as filter pushdown.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.count()
        3L
        >>> srdd.count() == srdd.map(lambda x: x).count()
        True
        """
        return self._jschema_rdd.count()

    def _toPython(self):
        # Import Row explicitly so the reference the pickler holds is
        # pyspark.sql.Row rather than __main__.Row.
        from pyspark.sql import Row
        jrdd = self._jschema_rdd.javaToPython()
        # Wrap each resulting dict in a Row so its fields are accessible as
        # attributes on the Python side.
        return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))

    # cache/persist/checkpoint are overridden so that the underlying SchemaRDD
    # object in the JVM is cached, rather than the derived PythonRDD that the
    # base RDD class would operate on.
    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None

    def coalesce(self, numPartitions, shuffle=False):
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        """Return a new SchemaRDD containing only rows that also appear in `other`."""
        if other.__class__ is SchemaRDD:
            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only intersect with another SchemaRDD")

    def subtract(self, other, numPartitions=None):
        """Return a new SchemaRDD containing rows in this one but not in `other`."""
        if other.__class__ is SchemaRDD:
            if numPartitions is None:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
            else:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only subtract another SchemaRDD")


def _test():
    import doctest
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
        {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()