I have a DataFrame in Spark 2, shown below, in which each user has anywhere from 50 to several thousand posts. I would like to create a new DataFrame that keeps every user from the original but contains only 5 randomly sampled posts per user.
+--------+--------------+--------------------+
| user_id| post_id| text|
+--------+--------------+--------------------+
|67778705|44783131591473|some text...........|
|67778705|44783134580755|some text...........|
|67778705|44783136367108|some text...........|
|67778705|44783136970669|some text...........|
|67778705|44783138143396|some text...........|
|67778705|44783155162624|some text...........|
|67778705|44783688650554|some text...........|
|68950272|88655645825660|some text...........|
|68950272|88651393135293|some text...........|
|68950272|88652615409812|some text...........|
|68950272|88655744880460|some text...........|
|68950272|88658059871568|some text...........|
|68950272|88656994832475|some text...........|
+--------+--------------+--------------------+
Something like posts.groupby('user_id').agg(sample('post_id')), but there is no such aggregate function in PySpark.
Any advice?
Update:
This question is different from another closely related question, stratified-sampling-in-spark, in two ways:
I have also updated the question's title to clarify this.
Using sampleBy
will result in an approximate solution. Here is an alternative approach that is a little more hacky than sampleBy, but it always results in exactly the same sample size per user.
import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

df.withColumn("row_num", row_number().over(Window.partitionBy($"user_id").orderBy($"something_random")))
  .filter($"row_num" <= 5)  // keep exactly 5 rows per user
If you don't already have a random column, you can use org.apache.spark.sql.functions.rand
to create one with random values, which guarantees that the per-user ordering, and therefore the sample, is random.
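For completeness, here is a minimal PySpark sketch of the same idea, assuming the posts DataFrame from the question and a sample size of 5 posts per user; rand() supplies the random ordering:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# rank each user's posts in a random order, then keep the first 5
w = Window.partitionBy('user_id').orderBy(F.rand())
sampled = (posts
           .withColumn('row_num', F.row_number().over(w))
           .filter(F.col('row_num') <= 5)
           .drop('row_num'))

Unlike sampleBy, this returns exactly 5 posts for every user who has at least 5 (and all of their posts otherwise).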
You can use the .sampleBy(...) method for DataFrames:
http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.sampleBy
Here's a working example:
import numpy as np
import string
import random
from pyspark.sql import functions as func  # needed for the commented-out variants below

# generate some fake data: (label, random 10-character string)
p = [(str(int(e)),
      ''.join(random.choice(string.ascii_uppercase + string.digits)
              for _ in range(10)))
     for e in np.random.normal(10, 1, 10000)]
posts = spark.createDataFrame(p, ['label', 'val'])
# define the sample size
percent_back = 0.05
# use this if you want an (almost) exact number of samples
# sample_count = 200
# percent_back = sample_count / posts.count()
frac = dict(
    (e.label, percent_back)
    for e in posts.select('label').distinct().collect()
)
# use this if you want (almost) balanced sample
# f = posts.groupby('label').count()
# f_min_count can also be specified to be exact number
# e.g. f_min_count = 5
# as long as it is less than the minimum count of posts per user
# calculated from all the users
# alternatively, you can take the minimum post count
# f_min_count = f.select('count').agg(func.min('count').alias('minVal')).collect()[0].minVal
# f = f.withColumn('frac',f_min_count/func.col('count'))
# frac = dict(f.select('label', 'frac').collect())
# sample the data
sampled = posts.sampleBy('label', fractions=frac)
# compare the original counts with sampled
original_total_count = posts.count()
original_counts = posts.groupby('label').count()
original_counts = original_counts.withColumn(
    'count_perc', original_counts['count'] / original_total_count)
sampled_total_count = sampled.count()
sampled_counts = sampled.groupBy('label').count()
sampled_counts = sampled_counts.withColumn(
    'count_perc', sampled_counts['count'] / sampled_total_count)
original_counts.sort('label').show(100)
sampled_counts.sort('label').show(100)
print(sampled_total_count)
print(sampled_total_count / original_total_count)
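Tying this back to the original question (roughly 5 posts per user), here is a hedged sketch, assuming the posts DataFrame from the question, that derives the per-user fractions from the post counts. The resulting sample sizes are still approximate, because sampleBy keeps each row independently with the given probability:

# target roughly 5 posts per user; fractions are capped at 1.0
counts = posts.groupBy('user_id').count()
fractions = {row['user_id']: min(1.0, 5.0 / row['count'])
             for row in counts.collect()}
sampled_posts = posts.sampleBy('user_id', fractions=fractions, seed=42)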