Stratified sampling with PySpark

I have a Spark DataFrame with one column that has lots of zeros and very few ones (only 0.01% of the values are ones).

I'd like to take a random subsample, but a stratified one, so that it keeps the ratio of 1s to 0s in that column.

Is it possible to do this in PySpark?

I am looking for a non-Scala solution, one based on DataFrames rather than RDDs.

asked Dec 04 '17 by user3245256


2 Answers

The solution I suggested in Stratified sampling in Spark is pretty straightforward to convert from Scala to Python (or even to Java - see What's the easiest way to stratify a Spark Dataset?).

Nevertheless, I'll rewrite it in Python. Let's start by creating a toy DataFrame:

from pyspark.sql.functions import lit

# Toy data: two distinct keys in x1, distributed 10 to 2
rows = [(2147481832,23355149,1),(2147481832,973010692,1),(2147481832,2134870842,1),
        (2147481832,541023347,1),(2147481832,1682206630,1),(2147481832,1138211459,1),
        (2147481832,852202566,1),(2147481832,201375938,1),(2147481832,486538879,1),
        (2147481832,919187908,1),(214748183,919187908,1),(214748183,91187908,1)]
df = spark.createDataFrame(rows, ["x1","x2","x3"])
df.show()
# +----------+----------+---+
# |        x1|        x2| x3|
# +----------+----------+---+
# |2147481832|  23355149|  1|
# |2147481832| 973010692|  1|
# |2147481832|2134870842|  1|
# |2147481832| 541023347|  1|
# |2147481832|1682206630|  1|
# |2147481832|1138211459|  1|
# |2147481832| 852202566|  1|
# |2147481832| 201375938|  1|
# |2147481832| 486538879|  1|
# |2147481832| 919187908|  1|
# | 214748183| 919187908|  1|
# | 214748183|  91187908|  1|
# +----------+----------+---+

This DataFrame has 12 elements, as you can see:

df.count()
# 12

They are distributed as follows:

df.groupBy("x1").count().show()
# +----------+-----+
# |        x1|count|
# +----------+-----+
# |2147481832|   10|
# | 214748183|    2|
# +----------+-----+

Now let's sample.

First we'll set the seed:

seed = 12

Then find the keys to fraction on and sample:

fractions = df.select("x1").distinct().withColumn("fraction", lit(0.8)).rdd.collectAsMap()
print(fractions)                                                            
# {2147481832: 0.8, 214748183: 0.8}
sampled_df = df.stat.sampleBy("x1", fractions, seed)
sampled_df.show()
# +----------+---------+---+
# |        x1|       x2| x3|
# +----------+---------+---+
# |2147481832| 23355149|  1|
# |2147481832|973010692|  1|
# |2147481832|541023347|  1|
# |2147481832|852202566|  1|
# |2147481832|201375938|  1|
# |2147481832|486538879|  1|
# |2147481832|919187908|  1|
# | 214748183|919187908|  1|
# | 214748183| 91187908|  1|
# +----------+---------+---+

We can now check the content of our sample. Note that sampleBy does a per-row Bernoulli draw within each stratum, so the returned counts only approximate the requested fractions (9 rows here against an expected 12 × 0.8 = 9.6):

sampled_df.count()
# 9

sampled_df.groupBy("x1").count().show()
# +----------+-----+
# |        x1|count|
# +----------+-----+
# |2147481832|    7|
# | 214748183|    2|
# +----------+-----+
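To tie this back to the original question: for a binary column the recipe is identical. A minimal sketch, assuming the asker's DataFrame is also called df and its 0/1 column is named label (hypothetical names, the question doesn't give them):

# Equal per-class fractions preserve the 1s-to-0s ratio in expectation;
# unequal values would deliberately rebalance the classes instead.
fractions = {0: 0.1, 1: 0.1}
stratified_sample = df.stat.sampleBy("label", fractions, seed=12)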
answered by eliasah


Assume you have the Titanic dataset in a 'data' DataFrame, which you want to split into train and test sets using stratified sampling based on the 'Survived' target variable.

  # Check initial distributions of 0's and 1's
data.groupBy("Survived").count().show()

+--------+-----+
|Survived|count|
+--------+-----+
|       1|  342|
|       0|  549|
+--------+-----+


  # Taking 70% of both 0's and 1's into the training set
train = data.sampleBy("Survived", fractions={0: 0.7, 1: 0.7}, seed=10)

  # Subtracting 'train' from the original 'data' to get the test set
test = data.subtract(train)
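One caveat not in the original answer: subtract performs a set difference, so any exact duplicate rows in data would vanish from the test set, and it compares entire rows, which can be expensive. A hedged alternative, assuming it is acceptable to add a helper id column first:

from pyspark.sql import functions as F

# Tag each row with a unique id, then keep in 'test' only ids absent from 'train'
data_with_id = data.withColumn("_row_id", F.monotonically_increasing_id())
train = data_with_id.sampleBy("Survived", fractions={0: 0.7, 1: 0.7}, seed=10)
test = data_with_id.join(train.select("_row_id"), on="_row_id", how="left_anti")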



  # Checking distributions of 0's and 1's in train and test sets after the sampling
train.groupBy("Survived").count().show()

+--------+-----+
|Survived|count|
+--------+-----+
|       1|  239|
|       0|  399|
+--------+-----+

test.groupBy("Survived").count().show()

+--------+-----+
|Survived|count|
+--------+-----+
|       1|  103|
|       0|  150|
+--------+-----+
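The absolute counts differ between the splits, but what matters is the class ratio. A quick sketch to verify it, reusing the train and test DataFrames from above:

from pyspark.sql import functions as F

# Survival rate per split; both should sit near the original 342/891 ≈ 0.38
for name, split in [("train", train), ("test", test)]:
    total = split.count()
    print(name)
    split.groupBy("Survived") \
         .agg((F.count("*") / total).alias("ratio")) \
         .show()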
answered by Ankit Sharma