 

PySpark DataFrame: group by and filter

I have a DataFrame like the one below:

cust_id   req    req_met
-------   ---    -------
 1         r1      1
 1         r2      0
 1         r2      1
 2         r1      1
 3         r1      1
 3         r2      1
 4         r1      0
 5         r1      1
 5         r2      0
 5         r1      1

For each customer, I need to find how many distinct requirements they have and check whether every one of them has been met at least once. There can be multiple records with the same customer and requirement, some met and some not met. For the data above, the output should be:

cust_id
-------
  1
  2
  3
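To pin down the rule, here is a minimal plain-Python sketch (not Spark code) of the expected behaviour. It assumes req_met is always 0 or 1 and that a customer qualifies only if every distinct requirement has at least one row with req_met == 1; the function name customers_meeting_all is hypothetical.

```python
from collections import defaultdict

# Toy rows copied from the question: (cust_id, req, req_met)
rows = [
    (1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1),
    (2, 'r1', 1),
    (3, 'r1', 1), (3, 'r2', 1),
    (4, 'r1', 0),
    (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1),
]


def customers_meeting_all(rows):
    """Return sorted cust_ids whose every distinct req was met at least once."""
    met = defaultdict(dict)  # cust_id -> {req: met_at_least_once}
    for cust_id, req, req_met in rows:
        # once a requirement is met, later unmet rows cannot "unmeet" it
        met[cust_id][req] = met[cust_id].get(req, False) or req_met == 1
    return sorted(c for c, reqs in met.items() if all(reqs.values()))


print(customers_meeting_all(rows))  # [1, 2, 3]
```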

Here is what I have tried:

# say the initial DataFrame is df
import pyspark.sql.functions as fn

df1 = df\
    .groupby('cust_id')\
    .agg(fn.countDistinct('req').alias('num_of_req'),
         fn.sum('req_met').alias('sum_req_met'))

df2 = df1.filter(df1.num_of_req == df1.sum_req_met)

But in a few cases it does not return the correct results.
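One likely reason for the wrong results: summing req_met over all of a customer's rows also counts repeated met rows for the same requirement. A rough plain-Python mirror of the attempted comparison (hypothetical helper naive_filter, not Spark code) shows customer 5 slipping through: r1 is met twice and r2 is never met, so 2 distinct requirements equals a sum of 2.

```python
from collections import defaultdict

# Toy rows copied from the question: (cust_id, req, req_met)
rows = [
    (1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1),
    (2, 'r1', 1),
    (3, 'r1', 1), (3, 'r2', 1),
    (4, 'r1', 0),
    (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1),
]


def naive_filter(rows):
    """Mimic the attempt: keep customers where distinct req count == sum of req_met."""
    reqs = defaultdict(set)       # cust_id -> distinct requirements
    total_met = defaultdict(int)  # cust_id -> sum of req_met over ALL rows
    for cust_id, req, req_met in rows:
        reqs[cust_id].add(req)
        total_met[cust_id] += req_met
    return sorted(c for c in reqs if len(reqs[c]) == total_met[c])


print(naive_filter(rows))  # [1, 2, 3, 5]  (5 should not be here)
```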

How can this be done?

Lijju Mathew asked Mar 16 '17 06:03


1 Answer

First, I'll create a toy dataset from the data given above:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType
import pyspark.sql.functions as fn

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([[1, 'r1', 1],
                            [1, 'r2', 0],
                            [1, 'r2', 1],
                            [2, 'r1', 1],
                            [3, 'r1', 1],
                            [3, 'r2', 1],
                            [4, 'r1', 0],
                            [5, 'r1', 1],
                            [5, 'r2', 0],
                            [5, 'r1', 1]], schema=['cust_id', 'req', 'req_met'])
# Python ints are inferred as LongType; cast them to IntegerType
df = df.withColumn('req_met', col('req_met').cast(IntegerType()))
df = df.withColumn('cust_id', col('cust_id').cast(IntegerType()))

I do something similar: group by cust_id and req, then sum req_met. After that, I use a UDF to cap each sum at 1, so every requirement is flagged 1 if it was met at least once and 0 otherwise.

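from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def floor_req(r):
    # 1 if the requirement was met at least once, else 0
    if r >= 1:
        return 1
    else:
        return 0

udf_floor_req = udf(floor_req, IntegerType())

# sum req_met for every (cust_id, req) pair
gr = df.groupby(['cust_id', 'req'])
df_grouped = gr.agg(fn.sum(col('req_met')).alias('sum_req_met'))
# cap each per-requirement sum at 1
df_grouped_floor = df_grouped.withColumn('sum_req_met', udf_floor_req('sum_req_met'))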
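Since req_met only holds 0 and 1 here, capping each per-group sum at 1 gives the same flag as taking the per-group maximum. In PySpark that would likely mean using the built-in `fn.max('req_met')` aggregate in place of the sum-plus-UDF pair; this is a suggested alternative, not the method used above. A plain-Python sketch of the collapse step (hypothetical helper met_per_requirement, not Spark code):

```python
from collections import defaultdict

# Toy rows copied from the question: (cust_id, req, req_met)
rows = [
    (1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1),
    (2, 'r1', 1),
    (3, 'r1', 1), (3, 'r2', 1),
    (4, 'r1', 0),
    (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1),
]


def met_per_requirement(rows):
    """Collapse rows to one 0/1 flag per (cust_id, req).

    For 0/1 inputs, max over a group equals min(sum over the group, 1),
    so no custom capping function is needed.
    """
    flags = defaultdict(int)
    for cust_id, req, req_met in rows:
        flags[(cust_id, req)] = max(flags[(cust_id, req)], req_met)
    return dict(flags)


flags = met_per_requirement(rows)
print(flags[(1, 'r2')], flags[(5, 'r2')])  # 1 0
```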

Now we can check whether each customer has met all of their requirements by comparing the number of distinct requirements with the number of requirements met.

df_req = df_grouped_floor.groupby('cust_id').agg(fn.sum('sum_req_met').alias('sum_req'), 
                                                 fn.count('req').alias('n_req'))

Finally, keep only the customers for which the two columns are equal:

df_req.filter(df_req['sum_req'] == df_req['n_req'])[['cust_id']].orderBy('cust_id').show()
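For 0/1 flags, "the sum equals the count" holds exactly when the smallest flag is 1, so the final filter can also be read as a per-customer minimum (in PySpark, presumably `fn.min` over the collapsed flags). A plain-Python sketch of that check, assuming the flags were already collapsed to one per (cust_id, req); the helper name customers_with_all_met is hypothetical:

```python
from collections import defaultdict

# One 0/1 flag per (cust_id, req), i.e. the already-collapsed toy data
flags = {
    (1, 'r1'): 1, (1, 'r2'): 1,
    (2, 'r1'): 1,
    (3, 'r1'): 1, (3, 'r2'): 1,
    (4, 'r1'): 0,
    (5, 'r1'): 1, (5, 'r2'): 0,
}


def customers_with_all_met(flags):
    """Keep customers whose smallest per-requirement flag is 1.

    For 0/1 flags, min == 1 holds exactly when sum == count.
    """
    per_customer = defaultdict(list)
    for (cust_id, _req), flag in flags.items():
        per_customer[cust_id].append(flag)
    return sorted(c for c, fs in per_customer.items() if min(fs) == 1)


print(customers_with_all_met(flags))  # [1, 2, 3]
```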
titipata answered Oct 10 '22 11:10