Source code for pyspark.sql.dataframe

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import warnings
import random

if sys.version >= '3':
    basestring = unicode = str
    long = int
    from functools import reduce
else:
    from itertools import imap as map

from pyspark import copy_func, since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string
from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter
from pyspark.sql.types import *

__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]


[docs]class DataFrame(object): """A distributed collection of data grouped into named columns. A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: people = sqlContext.read.parquet("...") Once created, it can be manipulated using the various domain-specific-language (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. To select a column from the data frame, use the apply method:: ageCol = people.age A more concrete example:: # To create DataFrame using SQLContext people = sqlContext.read.parquet("...") department = sqlContext.read.parquet("...") people.filter(people.age > 30).join(department, people.deptId == department.id)\ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) .. versionadded:: 1.3 """ def __init__(self, jdf, sql_ctx): self._jdf = jdf self.sql_ctx = sql_ctx self._sc = sql_ctx and sql_ctx._sc self.is_cached = False self._schema = None # initialized lazily self._lazy_rdd = None @property @since(1.3) def rdd(self): """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) return self._lazy_rdd @property @since("1.3.1") def na(self): """Returns a :class:`DataFrameNaFunctions` for handling missing values. """ return DataFrameNaFunctions(self) @property @since(1.4) def stat(self): """Returns a :class:`DataFrameStatFunctions` for statistic functions. """ return DataFrameStatFunctions(self) @ignore_unicode_prefix @since(1.3)
[docs] def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. Each row is turned into a JSON document as one element in the returned RDD. >>> df.toJSON().first() u'{"age":2,"name":"Alice"}' """ rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
@since(1.3)
[docs] def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the :class:`SQLContext` that was used to create this :class:`DataFrame`. >>> df.registerTempTable("people") >>> df2 = spark.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True >>> spark.catalog.dropTempView("people") .. note:: Deprecated in 2.0, use createOrReplaceTempView instead. """ self._jdf.createOrReplaceTempView(name)
@since(2.0)
[docs] def createTempView(self, name): """Creates a temporary view with this DataFrame. The lifetime of this temporary table is tied to the :class:`SparkSession` that was used to create this :class:`DataFrame`. throws :class:`TempTableAlreadyExistsException`, if the view name already exists in the catalog. >>> df.createTempView("people") >>> df2 = spark.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True >>> df.createTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... Py4JJavaError: ... : org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException... >>> spark.catalog.dropTempView("people") """ self._jdf.createTempView(name)
@since(2.0)
[docs] def createOrReplaceTempView(self, name): """Creates or replaces a temporary view with this DataFrame. The lifetime of this temporary table is tied to the :class:`SparkSession` that was used to create this :class:`DataFrame`. >>> df.createOrReplaceTempView("people") >>> df2 = df.filter(df.age > 3) >>> df2.createOrReplaceTempView("people") >>> df3 = spark.sql("select * from people") >>> sorted(df3.collect()) == sorted(df2.collect()) True >>> spark.catalog.dropTempView("people") """ self._jdf.createOrReplaceTempView(name)
@property @since(1.4) def write(self): """ Interface for saving the content of the :class:`DataFrame` out into external storage. :return: :class:`DataFrameWriter` """ return DataFrameWriter(self) @property @since(1.3) def schema(self): """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. >>> df.schema StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) """ if self._schema is None: try: self._schema = _parse_datatype_json_string(self._jdf.schema().json()) except AttributeError as e: raise Exception( "Unable to parse datatype from schema. %s" % e) return self._schema @since(1.3)
[docs] def printSchema(self): """Prints out the schema in the tree format. >>> df.printSchema() root |-- age: integer (nullable = true) |-- name: string (nullable = true) <BLANKLINE> """ print(self._jdf.schema().treeString())
@since(1.3)
[docs] def explain(self, extended=False): """Prints the (logical and physical) plans to the console for debugging purpose. :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() == Physical Plan == Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == ... == Analyzed Logical Plan == ... == Optimized Logical Plan == ... == Physical Plan == ... """ if extended: print(self._jdf.queryExecution().toString()) else: print(self._jdf.queryExecution().simpleString())
@since(1.3)
[docs] def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally (without any Spark executors). """ return self._jdf.isLocal()
@property @since(2.0) def isStreaming(self): """Returns true if this :class:`Dataset` contains one or more sources that continuously return data as it arrives. A :class:`Dataset` that reads data from a streaming source must be executed as a :class:`ContinuousQuery` using the :func:`startStream` method in :class:`DataFrameWriter`. Methods that return a single answer, (e.g., :func:`count` or :func:`collect`) will throw an :class:`AnalysisException` when there is a streaming source present. .. note:: Experimental """ return self._jdf.isStreaming() @since(1.3)
[docs] def show(self, n=20, truncate=True): """Prints the first ``n`` rows to the console. :param n: Number of rows to show. :param truncate: Whether truncate long strings and align cells right. >>> df DataFrame[age: int, name: string] >>> df.show() +---+-----+ |age| name| +---+-----+ | 2|Alice| | 5| Bob| +---+-----+ """ print(self._jdf.showString(n, truncate))
def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @since(1.3)
[docs] def count(self): """Returns the number of rows in this :class:`DataFrame`. >>> df.count() 2 """ return int(self._jdf.count())
@ignore_unicode_prefix @since(1.3)
[docs] def collect(self): """Returns all the records as a list of :class:`Row`. >>> df.collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
@ignore_unicode_prefix @since(2.0)
[docs] def toLocalIterator(self): """ Returns an iterator that contains all of the rows in this :class:`DataFrame`. The iterator will consume as much memory as the largest partition in this DataFrame. >>> list(df.toLocalIterator()) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: port = self._jdf.toPythonIterator() return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
@ignore_unicode_prefix @since(1.3)
[docs] def limit(self, num): """Limits the result count to the number specified. >>> df.limit(1).collect() [Row(age=2, name=u'Alice')] >>> df.limit(0).collect() [] """ jdf = self._jdf.limit(num) return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix @since(1.3)
[docs] def take(self, num): """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe( self._jdf, num) return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
@since(1.3)
[docs] def foreach(self, f): """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. This is a shorthand for ``df.rdd.foreach()``. >>> def f(person): ... print(person.name) >>> df.foreach(f) """ self.rdd.foreach(f)
@since(1.3)
[docs] def foreachPartition(self, f): """Applies the ``f`` function to each partition of this :class:`DataFrame`. This a shorthand for ``df.rdd.foreachPartition()``. >>> def f(people): ... for person in people: ... print(person.name) >>> df.foreachPartition(f) """ self.rdd.foreachPartition(f)
@since(1.3)
[docs] def cache(self): """ Persists with the default storage level (C{MEMORY_ONLY}). """ self.is_cached = True self._jdf.cache() return self
@since(1.3)
[docs] def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. If no storage level is specified defaults to (C{MEMORY_ONLY}). """ self.is_cached = True javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) self._jdf.persist(javaStorageLevel) return self
@since(1.3)
[docs] def unpersist(self, blocking=False): """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from memory and disk. .. note:: `blocking` default has changed to False to match Scala in 2.0. """ self.is_cached = False self._jdf.unpersist(blocking) return self
@since(1.4)
[docs] def coalesce(self, numPartitions): """ Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. Similar to coalesce defined on an :class:`RDD`, this operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions. >>> df.coalesce(1).rdd.getNumPartitions() 1 """ return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
@since(1.3)
[docs] def repartition(self, numPartitions, *cols): """ Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned. ``numPartitions`` can be an int to specify the target number of partitions or a Column. If it is a Column, it will be used as the first partitioning column. If not specified, the default number of partitions is used. .. versionchanged:: 1.6 Added optional arguments to specify the partitioning columns. Also made numPartitions optional if partitioning columns are specified. >>> df.repartition(10).rdd.getNumPartitions() 10 >>> data = df.union(df).repartition("age") >>> data.show() +---+-----+ |age| name| +---+-----+ | 5| Bob| | 5| Bob| | 2|Alice| | 2|Alice| +---+-----+ >>> data = data.repartition(7, "age") >>> data.show() +---+-----+ |age| name| +---+-----+ | 5| Bob| | 5| Bob| | 2|Alice| | 2|Alice| +---+-----+ >>> data.rdd.getNumPartitions() 7 >>> data = data.repartition("name", "age") >>> data.show() +---+-----+ |age| name| +---+-----+ | 5| Bob| | 5| Bob| | 2|Alice| | 2|Alice| +---+-----+ """ if isinstance(numPartitions, int): if len(cols) == 0: return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) else: return DataFrame( self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx) elif isinstance(numPartitions, (basestring, Column)): cols = (numPartitions, ) + cols return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx) else: raise TypeError("numPartitions should be an int or Column")
@since(1.3)
[docs] def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. >>> df.distinct().count() 2 """ return DataFrame(self._jdf.distinct(), self.sql_ctx)
@since(1.3)
[docs] def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 42).count() 2 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction seed = seed if seed is not None else random.randint(0, sys.maxsize) rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx)
@since(1.5)
[docs] def sampleBy(self, col, fractions, seed=None): """ Returns a stratified sample without replacement based on the fraction given on each stratum. :param col: column that defines strata :param fractions: sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as zero. :param seed: random seed :return: a new DataFrame that represents the stratified sample >>> from pyspark.sql.functions import col >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) >>> sampled.groupBy("key").count().orderBy("key").show() +---+-----+ |key|count| +---+-----+ | 0| 5| | 1| 9| +---+-----+ """ if not isinstance(col, str): raise ValueError("col must be a string, but got %r" % type(col)) if not isinstance(fractions, dict): raise ValueError("fractions must be a dict but got %r" % type(fractions)) for k, v in fractions.items(): if not isinstance(k, (float, int, long, basestring)): raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) fractions[k] = float(v) seed = seed if seed is not None else random.randint(0, sys.maxsize) return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
@since(1.4)
[docs] def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. :param weights: list of doubles as weights with which to split the DataFrame. Weights will be normalized if they don't sum up to 1.0. :param seed: The seed for sampling. >>> splits = df4.randomSplit([1.0, 2.0], 24) >>> splits[0].count() 1 >>> splits[1].count() 3 """ for w in weights: if w < 0.0: raise ValueError("Weights must be positive. Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed)) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
@property @since(1.3) def dtypes(self): """Returns all column names and their data types as a list. >>> df.dtypes [('age', 'int'), ('name', 'string')] """ return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property @since(1.3) def columns(self): """Returns all column names as a list. >>> df.columns ['age', 'name'] """ return [f.name for f in self.schema.fields] @ignore_unicode_prefix @since(1.3)
[docs] def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. >>> from pyspark.sql.functions import * >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").collect() [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] """ assert isinstance(alias, basestring), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
@ignore_unicode_prefix @since(1.3)
[docs] def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join :param on: a string for join column name, a list of column names, , a join expression (Column) or a list of Columns. If `on` is a string or a list of string indicating the name of the join column(s), the column(s) must exist on both sides, and this performs an equi-join. :param how: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> df.join(df2, 'name', 'outer').select('name', 'height').collect() [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> cond = [df.name == df3.name, df.age == df3.age] >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() [Row(name=u'Bob', age=5)] """ if on is not None and not isinstance(on, list): on = [on] if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) elif isinstance(on[0], basestring): if how is None: jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") else: assert isinstance(how, basestring), "how should be basestring" jdf = self._jdf.join(other._jdf, self._jseq(on), how) else: assert isinstance(on[0], Column), "on should be Column or list of Column" if len(on) > 1: on = reduce(lambda x, y: x.__and__(y), on) else: on = on[0] if how is None: jdf = self._jdf.join(other._jdf, on._jc, "inner") else: assert isinstance(how, basestring), "how should be basestring" jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx)
@since(1.6)
[docs] def sortWithinPartitions(self, *cols, **kwargs): """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s). :param cols: list of :class:`Column` or column names to sort by. :param ascending: boolean or list of boolean (default True). Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, length of the list must equal length of the `cols`. >>> df.sortWithinPartitions("age", ascending=False).show() +---+-----+ |age| name| +---+-----+ | 2|Alice| | 5| Bob| +---+-----+ """ jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix @since(1.3)
[docs] def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). :param cols: list of :class:`Column` or column names to sort by. :param ascending: boolean or list of boolean (default True). Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, length of the list must equal length of the `cols`. >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> df.sort("age", ascending=False).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> df.orderBy(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> from pyspark.sql.functions import * >>> df.sort(asc("age")).collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.orderBy(desc("age"), "name").collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx)
orderBy = sort def _jseq(self, cols, converter=None): """Return a JVM Seq of Columns from a list of Column or names""" return _to_seq(self.sql_ctx._sc, cols, converter) def _jmap(self, jm): """Return a JVM Scala Map from a dict""" return _to_scala_map(self.sql_ctx._sc, jm) def _jcols(self, *cols): """Return a JVM Seq of Columns from a list of Column or column names If `cols` has only one list in it, cols[0] will be used as the list. """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] return self._jseq(cols, _to_java_column) def _sort_cols(self, cols, kwargs): """ Return a JVM Seq of Columns that describes the sort order """ if not cols: raise ValueError("should sort by at least one column") if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] jcols = [_to_java_column(c) for c in cols] ascending = kwargs.get('ascending', True) if isinstance(ascending, (bool, int)): if not ascending: jcols = [jc.desc() for jc in jcols] elif isinstance(ascending, list): jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)] else: raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) return self._jseq(jcols) @since("1.3.1")
[docs] def describe(self, *cols): """Computes statistics for numeric columns. This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical columns. .. note:: This function is meant for exploratory data analysis, as we make no \ guarantee about the backward compatibility of the schema of the resulting DataFrame. >>> df.describe().show() +-------+------------------+ |summary| age| +-------+------------------+ | count| 2| | mean| 3.5| | stddev|2.1213203435596424| | min| 2| | max| 5| +-------+------------------+ >>> df.describe(['age', 'name']).show() +-------+------------------+-----+ |summary| age| name| +-------+------------------+-----+ | count| 2| 2| | mean| 3.5| null| | stddev|2.1213203435596424| null| | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix @since(1.3)
[docs] def head(self, n=None): """Returns the first ``n`` rows. Note that this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory. :param n: int, default 1. Number of rows to return. :return: If n is greater than 1, return a list of :class:`Row`. If n is 1, return a single Row. >>> df.head() Row(age=2, name=u'Alice') >>> df.head(1) [Row(age=2, name=u'Alice')] """ if n is None: rs = self.head(1) return rs[0] if rs else None return self.take(n)
@ignore_unicode_prefix @since(1.3)
[docs] def first(self): """Returns the first row as a :class:`Row`. >>> df.first() Row(age=2, name=u'Alice') """ return self.head()
@ignore_unicode_prefix @since(1.3) def __getitem__(self, item): """Returns the column as a :class:`Column`. >>> df.select(df['age']).collect() [Row(age=2), Row(age=5)] >>> df[ ["name", "age"]].collect() [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df[ df.age > 3 ].collect() [Row(age=5, name=u'Bob')] >>> df[df[0] > 3].collect() [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): return self.filter(item) elif isinstance(item, (list, tuple)): return self.select(*item) elif isinstance(item, int): jc = self._jdf.apply(self.columns[item]) return Column(jc) else: raise TypeError("unexpected item type: %s" % type(item)) @since(1.3) def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] """ if name not in self.columns: raise AttributeError( "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) jc = self._jdf.apply(name) return Column(jc) @ignore_unicode_prefix @since(1.3)
[docs] def select(self, *cols): """Projects a set of expressions and returns a new :class:`DataFrame`. :param cols: list of column names (string) or expressions (:class:`Column`). If one of the column names is '*', that column is expanded to include all columns in the current DataFrame. >>> df.select('*').collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.select('name', 'age').collect() [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df.select(df.name, (df.age + 10).alias('age')).collect() [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] """ jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sql_ctx)
@since(1.3)
[docs] def selectExpr(self, *expr): """Projects a set of SQL expressions and returns a new :class:`DataFrame`. This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix @since(1.3)
[docs] def filter(self, condition): """Filters rows using the given condition. :func:`where` is an alias for :func:`filter`. :param condition: a :class:`Column` of :class:`types.BooleanType` or a string of SQL expression. >>> df.filter(df.age > 3).collect() [Row(age=5, name=u'Bob')] >>> df.where(df.age == 2).collect() [Row(age=2, name=u'Alice')] >>> df.filter("age > 3").collect() [Row(age=5, name=u'Bob')] >>> df.where("age = 2").collect() [Row(age=2, name=u'Alice')] """ if isinstance(condition, basestring): jdf = self._jdf.filter(condition) elif isinstance(condition, Column): jdf = self._jdf.filter(condition._jc) else: raise TypeError("condition should be string or Column") return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix @since(1.3)
[docs] def groupBy(self, *cols): """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` for all the available aggregate functions. :func:`groupby` is an alias for :func:`groupBy`. :param cols: list of columns to group by. Each element should be a column name (string) or an expression (:class:`Column`). >>> df.groupBy().avg().collect() [Row(avg(age)=3.5)] >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> sorted(df.groupBy(df.name).avg().collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> sorted(df.groupBy(['name', df.age]).count().collect()) [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)] """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData return GroupedData(jgd, self.sql_ctx)
@since(1.4)
[docs] def rollup(self, *cols): """ Create a multi-dimensional rollup for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. >>> df.rollup("name", df.age).count().orderBy("name", "age").show() +-----+----+-----+ | name| age|count| +-----+----+-----+ | null|null| 2| |Alice|null| 1| |Alice| 2| 1| | Bob|null| 1| | Bob| 5| 1| +-----+----+-----+ """ jgd = self._jdf.rollup(self._jcols(*cols)) from pyspark.sql.group import GroupedData return GroupedData(jgd, self.sql_ctx)
@since(1.4)
[docs] def cube(self, *cols): """ Create a multi-dimensional cube for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. >>> df.cube("name", df.age).count().orderBy("name", "age").show() +-----+----+-----+ | name| age|count| +-----+----+-----+ | null|null| 2| | null| 2| 1| | null| 5| 1| |Alice|null| 1| |Alice| 2| 1| | Bob|null| 1| | Bob| 5| 1| +-----+----+-----+ """ jgd = self._jdf.cube(self._jcols(*cols)) from pyspark.sql.group import GroupedData return GroupedData(jgd, self.sql_ctx)
@since(1.3)
[docs] def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups (shorthand for ``df.groupBy.agg()``). >>> df.agg({"age": "max"}).collect() [Row(max(age)=5)] >>> from pyspark.sql import functions as F >>> df.agg(F.min(df.age)).collect() [Row(min(age)=2)] """ return self.groupBy().agg(*exprs)
@since(2.0)
[docs] def union(self, other): """ Return a new :class:`DataFrame` containing union of rows in this frame and another frame. This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by a distinct. """ return DataFrame(self._jdf.union(other._jdf), self.sql_ctx)
@since(1.3)
[docs] def unionAll(self, other): """ Return a new :class:`DataFrame` containing union of rows in this frame and another frame. .. note:: Deprecated in 2.0, use union instead. """ return self.union(other)
@since(1.3)
[docs] def intersect(self, other): """ Return a new :class:`DataFrame` containing rows only in both this frame and another frame. This is equivalent to `INTERSECT` in SQL. """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
@since(1.3)
[docs] def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. This is equivalent to `EXCEPT` in SQL. """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
@since(1.4)
[docs] def dropDuplicates(self, subset=None): """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. :func:`drop_duplicates` is an alias for :func:`dropDuplicates`. >>> from pyspark.sql import Row >>> df = sc.parallelize([ \ Row(name='Alice', age=5, height=80), \ Row(name='Alice', age=5, height=80), \ Row(name='Alice', age=10, height=80)]).toDF() >>> df.dropDuplicates().show() +---+------+-----+ |age|height| name| +---+------+-----+ | 5| 80|Alice| | 10| 80|Alice| +---+------+-----+ >>> df.dropDuplicates(['name', 'height']).show() +---+------+-----+ |age|height| name| +---+------+-----+ | 5| 80|Alice| +---+------+-----+ """ if subset is None: jdf = self._jdf.dropDuplicates() else: jdf = self._jdf.dropDuplicates(self._jseq(subset)) return DataFrame(jdf, self.sql_ctx)
@since("1.3.1")
[docs] def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other. :param how: 'any' or 'all'. If 'any', drop a row if it contains any nulls. If 'all', drop a row only if all its values are null. :param thresh: int, default None If specified, drop rows that have less than `thresh` non-null values. This overwrites the `how` parameter. :param subset: optional list of column names to consider. >>> df4.na.drop().show() +---+------+-----+ |age|height| name| +---+------+-----+ | 10| 80|Alice| +---+------+-----+ """ if how is not None and how not in ['any', 'all']: raise ValueError("how ('" + how + "') should be 'any' or 'all'") if subset is None: subset = self.columns elif isinstance(subset, basestring): subset = [subset] elif not isinstance(subset, (list, tuple)): raise ValueError("subset should be a list or tuple of column names") if thresh is None: thresh = len(subset) if how == 'any' else 1 return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
@since("1.3.1")
[docs] def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. :param value: int, long, float, string, or dict. Value to replace null values with. If the value is a dict, then `subset` is ignored and `value` must be a mapping from column name (string) to replacement value. The replacement value must be an int, long, float, or string. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. >>> df4.na.fill(50).show() +---+------+-----+ |age|height| name| +---+------+-----+ | 10| 80|Alice| | 5| 50| Bob| | 50| 50| Tom| | 50| 50| null| +---+------+-----+ >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() +---+------+-------+ |age|height| name| +---+------+-------+ | 10| 80| Alice| | 5| null| Bob| | 50| null| Tom| | 50| null|unknown| +---+------+-------+ """ if not isinstance(value, (float, int, long, basestring, dict)): raise ValueError("value should be a float, int, long, string, or dict") if isinstance(value, (int, long)): value = float(value) if isinstance(value, dict): return DataFrame(self._jdf.na().fill(value), self.sql_ctx) elif subset is None: return DataFrame(self._jdf.na().fill(value), self.sql_ctx) else: if isinstance(subset, basestring): subset = [subset] elif not isinstance(subset, (list, tuple)): raise ValueError("subset should be a list or tuple of column names") return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
@since(1.4)
[docs] def replace(self, to_replace, value, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. :param to_replace: int, long, float, string, or list. Value to be replaced. If the value is a dict, then `value` is ignored and `to_replace` must be a mapping from column name (string) to replacement value. The value to be replaced must be an int, long, float, or string. :param value: int, long, float, string, or list. Value to use to replace holes. The replacement value must be an int, long, float, or string. If `value` is a list or tuple, `value` should be of the same length with `to_replace`. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. >>> df4.na.replace(10, 20).show() +----+------+-----+ | age|height| name| +----+------+-----+ | 20| 80|Alice| | 5| null| Bob| |null| null| Tom| |null| null| null| +----+------+-----+ >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| +----+------+----+ | 10| 80| A| | 5| null| B| |null| null| Tom| |null| null|null| +----+------+----+ """ if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): raise ValueError( "to_replace should be a float, int, long, string, list, tuple, or dict") if not isinstance(value, (float, int, long, basestring, list, tuple)): raise ValueError("value should be a float, int, long, string, list, or tuple") rep_dict = dict() if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] if isinstance(to_replace, tuple): to_replace = list(to_replace) if isinstance(value, tuple): value = list(value) if isinstance(to_replace, list) and isinstance(value, list): if len(to_replace) != len(value): raise ValueError("to_replace and value lists should be of the same length") rep_dict = dict(zip(to_replace, value)) elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): rep_dict = dict([(tr, value) for tr in to_replace]) elif isinstance(to_replace, dict): rep_dict = to_replace if subset is None: return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) elif isinstance(subset, basestring): subset = [subset] if not isinstance(subset, (list, tuple)): raise ValueError("subset should be a list or tuple of column names") return DataFrame( self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
@since(2.0)
[docs] def approxQuantile(self, col, probabilities, relativeError): """ Calculates the approximate quantiles of a numerical column of a DataFrame. The result of this algorithm has the following deterministic bound: If the DataFrame has N elements and if we request the quantile at probability `p` up to error `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is close to (p * N). More precisely, floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). This method implements a variation of the Greenwald-Khanna algorithm (with some speed optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. :param col: the name of the numerical column :param probabilities: a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the minimum, 0.5 is the median, 1 is the maximum. :param relativeError: The relative target precision to achieve (>= 0). If set to zero, the exact quantiles are computed, which could be very expensive. Note that values greater than 1 are accepted but give the same result as 1. :return: the approximate quantiles at the given probabilities """ if not isinstance(col, str): raise ValueError("col should be a string.") if not isinstance(probabilities, (list, tuple)): raise ValueError("probabilities should be a list or tuple") if isinstance(probabilities, tuple): probabilities = list(probabilities) for p in probabilities: if not isinstance(p, (float, int, long)) or p < 0 or p > 1: raise ValueError("probabilities should be numerical (float, int, long) in [0,1].") probabilities = _to_list(self._sc, probabilities) if not isinstance(relativeError, (float, int, long)) or relativeError < 0: raise ValueError("relativeError should be numerical (float, int, long) >= 0.") relativeError = float(relativeError) jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError) return list(jaq)
@since(1.4)
[docs] def corr(self, col1, col2, method=None): """ Calculates the correlation of two columns of a DataFrame as a double value. Currently only supports the Pearson Correlation Coefficient. :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other. :param col1: The name of the first column :param col2: The name of the second column :param method: The correlation method. Currently only supports "pearson" """ if not isinstance(col1, str): raise ValueError("col1 should be a string.") if not isinstance(col2, str): raise ValueError("col2 should be a string.") if not method: method = "pearson" if not method == "pearson": raise ValueError("Currently only the calculation of the Pearson Correlation " + "coefficient is supported.") return self._jdf.stat().corr(col1, col2, method)
@since(1.4)
[docs] def cov(self, col1, col2): """ Calculate the sample covariance for the given columns, specified by their names, as a double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases. :param col1: The name of the first column :param col2: The name of the second column """ if not isinstance(col1, str): raise ValueError("col1 should be a string.") if not isinstance(col2, str): raise ValueError("col2 should be a string.") return self._jdf.stat().cov(col1, col2)
@since(1.4)
[docs] def crosstab(self, col1, col2): """ Computes a pair-wise frequency table of the given columns. Also known as a contingency table. The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero pair frequencies will be returned. The first column of each row will be the distinct values of `col1` and the column names will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no occurrences will have zero as their counts. :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases. :param col1: The name of the first column. Distinct items will make the first item of each row. :param col2: The name of the second column. Distinct items will make the column names of the DataFrame. """ if not isinstance(col1, str): raise ValueError("col1 should be a string.") if not isinstance(col2, str): raise ValueError("col2 should be a string.") return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
@since(1.4)
[docs] def freqItems(self, cols, support=None): """ Finding frequent items for columns, possibly with false positives. Using the frequent element count algorithm described in "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. .. note:: This function is meant for exploratory data analysis, as we make no \ guarantee about the backward compatibility of the schema of the resulting DataFrame. :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. :param support: The frequency with which to consider an item 'frequent'. Default is 1%. The support must be greater than 1e-4. """ if isinstance(cols, tuple): cols = list(cols) if not isinstance(cols, list): raise ValueError("cols must be a list or tuple of column names as strings.") if not support: support = 0.01 return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx)
@ignore_unicode_prefix @since(1.3)
[docs] def withColumn(self, colName, col): """ Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] """ assert isinstance(col, Column), "col should be Column" return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
@ignore_unicode_prefix @since(1.3)
[docs] def withColumnRenamed(self, existing, new): """Returns a new :class:`DataFrame` by renaming an existing column. :param existing: string, name of the existing column to rename. :param col: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] """ return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)
@since(1.4) @ignore_unicode_prefix
[docs] def drop(self, col): """Returns a new :class:`DataFrame` that drops the specified column. :param col: a string name of the column to drop, or a :class:`Column` to drop. >>> df.drop('age').collect() [Row(name=u'Alice'), Row(name=u'Bob')] >>> df.drop(df.age).collect() [Row(name=u'Alice'), Row(name=u'Bob')] >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect() [Row(age=5, height=85, name=u'Bob')] >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect() [Row(age=5, name=u'Bob', height=85)] """ if isinstance(col, basestring): jdf = self._jdf.drop(col) elif isinstance(col, Column): jdf = self._jdf.drop(col._jc) else: raise TypeError("col should be a string or a Column") return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
[docs] def toDF(self, *cols): """Returns a new class:`DataFrame` that with new specified column names :param cols: list of new column names (string) >>> df.toDF('f1', 'f2').collect() [Row(f1=2, f2=u'Alice'), Row(f1=5, f2=u'Bob')] """ jdf = self._jdf.toDF(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx)
@since(1.3)
[docs] def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. Note that this method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. This is only available if Pandas is installed and available. >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice 1 5 Bob """ import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns)
########################################################################################## # Pandas compatibility ########################################################################################## groupby = copy_func( groupBy, sinceversion=1.4, doc=":func:`groupby` is an alias for :func:`groupBy`.") drop_duplicates = copy_func( dropDuplicates, sinceversion=1.4, doc=":func:`drop_duplicates` is an alias for :func:`dropDuplicates`.") where = copy_func( filter, sinceversion=1.3, doc=":func:`where` is an alias for :func:`filter`.")
def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. """ return sc._jvm.PythonUtils.toScalaMap(jm)
[docs]class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. .. versionadded:: 1.4 """ def __init__(self, df): self.df = df
[docs] def drop(self, how='any', thresh=None, subset=None): return self.df.dropna(how=how, thresh=thresh, subset=subset)
drop.__doc__ = DataFrame.dropna.__doc__
[docs] def fill(self, value, subset=None): return self.df.fillna(value=value, subset=subset)
fill.__doc__ = DataFrame.fillna.__doc__
[docs] def replace(self, to_replace, value, subset=None): return self.df.replace(to_replace, value, subset)
replace.__doc__ = DataFrame.replace.__doc__
[docs]class DataFrameStatFunctions(object): """Functionality for statistic functions with :class:`DataFrame`. .. versionadded:: 1.4 """ def __init__(self, df): self.df = df
[docs] def approxQuantile(self, col, probabilities, relativeError): return self.df.approxQuantile(col, probabilities, relativeError)
approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__
[docs] def corr(self, col1, col2, method=None): return self.df.corr(col1, col2, method)
corr.__doc__ = DataFrame.corr.__doc__
[docs] def cov(self, col1, col2): return self.df.cov(col1, col2)
cov.__doc__ = DataFrame.cov.__doc__
[docs] def crosstab(self, col1, col2): return self.df.crosstab(col1, col2)
crosstab.__doc__ = DataFrame.crosstab.__doc__
[docs] def freqItems(self, cols, support=None): return self.df.freqItems(cols, support)
freqItems.__doc__ = DataFrame.freqItems.__doc__
[docs] def sampleBy(self, col, fractions, seed=None): return self.df.sampleBy(col, fractions, seed)
sampleBy.__doc__ = DataFrame.sampleBy.__doc__
def _test(): import doctest from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext, SparkSession import pyspark.sql.dataframe globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) globs['spark'] = SparkSession(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() globs['df3'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), Row(name=None, age=None, height=None)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) globs['sc'].stop() if failure_count: exit(-1) if __name__ == "__main__": _test()