from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer

from py4j.protocol import Py4JError

__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
    """Main entry point for SparkSQL functionality.

    A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
    tables, execute SQL over tables, cache tables, and read Parquet files.
    """

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ...   "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ...   x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

        if sqlContext:
            self._scala_SQLContext = sqlContext

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM SparkSQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """Infer and apply a schema to an RDD of L{dict}s.

        We peek at the first row of the RDD to determine the field names
        and types, and then use that to extract all the dictionaries. Nested
        collections are supported, which include array, dict, list, set, and
        tuple.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ...                    {"field1" : 3, "field2": "row3"}]
        True

        >>> from array import array
        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        ...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
        True

        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        ...                    {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
        True
        """
        if (rdd.__class__ is SchemaRDD):
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of
        SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if (rdd.__class__ is SchemaRDD):
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonFile(self, path):
        """Loads a text file storing one JSON object per line,
        returning the result as a L{SchemaRDD}.
        It goes through the entire dataset once to determine the schema.

        >>> import tempfile, shutil
        >>> jsonFile = tempfile.mkdtemp()
        >>> shutil.rmtree(jsonFile)
        >>> ofn = open(jsonFile, 'w')
        >>> for json in jsonStrings:
        ...   print>>ofn, json
        >>> ofn.close()
        >>> srdd = sqlCtx.jsonFile(jsonFile)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql(
        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
        >>> srdd2.collect() == [
        ...   {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
        ...   {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
        ...   {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
        True
        """
        jschema_rdd = self._ssql_ctx.jsonFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonRDD(self, rdd):
        """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.
        It goes through the entire dataset once to determine the schema.

        >>> srdd = sqlCtx.jsonRDD(json)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql(
        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
        >>> srdd2.collect() == [
        ...   {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
        ...   {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
        ...   {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
        True
        """
        def func(split, iterator):
            for x in iterator:
                if not isinstance(x, basestring):
                    x = unicode(x)
                yield x.encode("utf-8")
        keyed = PipelinedRDD(rdd, func)
        keyed._bypass_serializer = True
        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
        jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """Return a L{SchemaRDD} representing the result of the given query.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ...                     {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)
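
    # Illustrative usage sketch (not a doctest; "people" is a hypothetical table name):
    # caching operates on names registered in this SQLContext's catalog, e.g. via
    # registerRDDAsTable above.
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   sqlCtx.registerRDDAsTable(srdd, "people")
    #   sqlCtx.cacheTable("people")        # later scans of "people" read the in-memory copy
    #   sqlCtx.sql("SELECT * FROM people").count()
    #   sqlCtx.uncacheTable("people")      # release the cached data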


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}."""
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}."""
        return self.hiveql(hqlQuery)
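
    # Illustrative usage sketch (not a doctest; assumes a Hive-enabled build and an
    # existing Hive table named "src", as in the LocalHiveContext example below):
    #
    #   hiveCtx = HiveContext(sc)
    #   rows = hiveCtx.hql("SELECT key, value FROM src")   # returns a SchemaRDD
    #   rows.registerAsTable("src_copy")
    #   hiveCtx.hql("SELECT COUNT(*) FROM src_copy").collect()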


class LocalHiveContext(HiveContext):
    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metadata database is created with data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):
    """A HiveContext that delegates to the JVM TestHiveContext used by Spark's test suites."""

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


class Row(dict):
    """A row in L{SchemaRDD}.

    An extended L{dict} that takes a L{dict} in its constructor, and
    exposes those items as fields.

    >>> r = Row({"hello" : "world", "foo" : "bar"})
    >>> r.hello
    'world'
    >>> r.foo
    'bar'
    """

    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)


class SchemaRDD(RDD):
    """An RDD of L{Row} objects that has an associated schema.

    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
    utilize the relational query API exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
    L{SchemaRDD} is not operated on directly, as its underlying
    implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can
    be done.
    """

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer
    @property
    def _jrdd(self):
        """Lazy evaluation of the PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        # Mirrors pyspark.rdd.RDD's _id attribute so that RDD.id() works on a SchemaRDD.
        return self._jrdd.id()

350 """Save the contents as a Parquet file, preserving the schema.
351
352 Files that are written out using this method can be read back in as
353 a SchemaRDD using the L{SQLContext.parquetFile} method.
354
355 >>> import tempfile, shutil
356 >>> parquetFile = tempfile.mkdtemp()
357 >>> shutil.rmtree(parquetFile)
358 >>> srdd = sqlCtx.inferSchema(rdd)
359 >>> srdd.saveAsParquetFile(parquetFile)
360 >>> srdd2 = sqlCtx.parquetFile(parquetFile)
361 >>> sorted(srdd2.collect()) == sorted(srdd.collect())
362 True
363 """
364 self._jschema_rdd.saveAsParquetFile(path)
365
367 """Registers this RDD as a temporary table using the given name.
368
369 The lifetime of this temporary table is tied to the L{SQLContext}
370 that was used to create this SchemaRDD.
371
372 >>> srdd = sqlCtx.inferSchema(rdd)
373 >>> srdd.registerAsTable("test")
374 >>> srdd2 = sqlCtx.sql("select * from test")
375 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
376 True
377 """
378 self._jschema_rdd.registerAsTable(name)
379
    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this SchemaRDD into the specified table,
        optionally overwriting any existing data.
        """
        self._jschema_rdd.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this SchemaRDD."""
        self._jschema_rdd.saveAsTable(tableName)

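    # Illustrative usage sketch (not a doctest; "people" and "people_backup" are
    # hypothetical table names, and both calls are typically used with a Hive-enabled
    # context):
    #
    #   srdd = hiveCtx.hql("SELECT * FROM people")
    #   srdd.saveAsTable("people_backup")        # create a new table from this SchemaRDD
    #   srdd.insertInto("people_backup")         # append the same rows again
    #   srdd.insertInto("people_backup", True)   # or overwrite the existing data
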
392 """Returns the output schema in the tree format."""
393 return self._jschema_rdd.schemaString()
394
396 """Prints out the schema in the tree format."""
397 print self.schemaString()
398
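    # Illustrative usage sketch (not a doctest): printSchema() prints the same tree that
    # schemaString() returns, which is handy when exploring an inferred schema.
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   tree = srdd.schemaString()   # the schema as a string
    #   srdd.printSchema()           # prints that tree to stdout
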
400 """Return the number of elements in this RDD.
401
402 Unlike the base RDD implementation of count, this implementation
403 leverages the query optimizer to compute the count on the SchemaRDD,
404 which supports features such as filter pushdown.
405
406 >>> srdd = sqlCtx.inferSchema(rdd)
407 >>> srdd.count()
408 3L
409 >>> srdd.count() == srdd.map(lambda x: x).count()
410 True
411 """
412 return self._jschema_rdd.count()

    def _toPython(self):
        # Assumes the JVM-side SchemaRDD exposes javaToPython(), which ships each row to
        # Python as a pickled dictionary; every dictionary is wrapped in a Row so that
        # its fields are accessible as attributes.
        jrdd = self._jschema_rdd.javaToPython()
        return RDD(jrdd, self._sc, BatchedSerializer(
            PickleSerializer())).map(lambda d: Row(d))

    # We override the default cache/persist/checkpoint behavior so that these operations
    # apply to the underlying SchemaRDD in the JVM, not to the derived PythonRDD.
    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None

    def coalesce(self, numPartitions, shuffle=False):
        # Delegate to the JVM SchemaRDD and re-wrap the result, mirroring the other
        # wrappers below.
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        if (other.__class__ is SchemaRDD):
            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only intersect with another SchemaRDD")

    def subtract(self, other, numPartitions=None):
        if (other.__class__ is SchemaRDD):
            if numPartitions is None:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
            else:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only subtract another SchemaRDD")


def _test():
    import doctest
    from array import array
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
        {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
    jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
        '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}',
        '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}']
    globs['jsonStrings'] = jsonStrings
    globs['json'] = sc.parallelize(jsonStrings)
    globs['nestedRdd1'] = sc.parallelize([
        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
    globs['nestedRdd2'] = sc.parallelize([
        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()