
PySpark - Get indices of duplicate rows

Let's say I have a PySpark data frame, like so:

+--+--+--+--+
|a |b |c |d |
+--+--+--+--+
|1 |0 |1 |2 |
|0 |2 |0 |1 |
|1 |0 |1 |2 |
|0 |4 |3 |1 |
+--+--+--+--+

How can I create a column marking all of the duplicate rows, like so:

+--+--+--+--+--+
|a |b |c |d |e |
+--+--+--+--+--+
|1 |0 |1 |2 |1 |
|0 |2 |0 |1 |0 |
|1 |0 |1 |2 |1 |
|0 |4 |3 |1 |0 |
+--+--+--+--+--+

I attempted this using groupBy and aggregate functions, to no avail.
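
For reference, this frame can be built like so (a minimal sketch, assuming an active SparkSession named spark):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Reproduce the example frame from the question
df = spark.createDataFrame(
    [(1, 0, 1, 2), (0, 2, 0, 1), (1, 0, 1, 2), (0, 4, 3, 1)],
    ["a", "b", "c", "d"]
)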

asked Jun 14 '18 by Chris C

2 Answers

Just to expand on my comment:

You can group by all of the columns and use pyspark.sql.functions.count() to determine if a column is duplicated:

import pyspark.sql.functions as f
df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")).show()
#+---+---+---+---+---+
#|  a|  b|  c|  d|  e|
#+---+---+---+---+---+
#|  1|  0|  1|  2|  1|
#|  0|  2|  0|  1|  0|
#|  0|  4|  3|  1|  0|
#+---+---+---+---+---+

Here we use count("*") > 1 as the aggregate function and cast the result to an int. Note that the groupBy() collapses duplicate rows, so the output contains one row per distinct combination of values. Depending on your needs, this may be sufficient.

However, if you'd like to keep all of the rows, you can use a Window function as shown in the other answer, OR you can use a join():

df.join(
    df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")),
    on=df.columns,
    how="inner"
).show()
#+---+---+---+---+---+
#|  a|  b|  c|  d|  e|
#+---+---+---+---+---+
#|  1|  0|  1|  2|  1|
#|  1|  0|  1|  2|  1|
#|  0|  2|  0|  1|  0|
#|  0|  4|  3|  1|  0|
#+---+---+---+---+---+

Here we inner join the original DataFrame with the result of the groupBy() above, joining on all of the columns.
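
If you only need the rows that actually are duplicated, one option (just a sketch building on the join above) is to filter on the new column:

# Keep only the rows flagged as duplicates (e == 1)
df.join(
    df.groupBy(df.columns).agg((f.count("*") > 1).cast("int").alias("e")),
    on=df.columns,
    how="inner"
).filter(f.col("e") == 1).show()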

answered Oct 01 '22 by pault

Define a window function to check whether the count of rows, when grouped by all columns, is greater than 1. If it is, the row is a duplicate (1); otherwise it is not (0).

from pyspark.sql import functions as f
from pyspark.sql import window as w

allColumns = df.columns

# Unbounded window over all rows that share the same values in every column
# (Window.unboundedPreceding/unboundedFollowing replaces the Python 2-only sys.maxint)
windowSpec = w.Window.partitionBy(allColumns).rowsBetween(
    w.Window.unboundedPreceding, w.Window.unboundedFollowing
)

df.withColumn('e', f.when(f.count(f.col('d')).over(windowSpec) > 1, f.lit(1)).otherwise(f.lit(0))).show(truncate=False)

which should give you

+---+---+---+---+---+
|a  |b  |c  |d  |e  |
+---+---+---+---+---+
|1  |0  |1  |2  |1  |
|1  |0  |1  |2  |1  |
|0  |2  |0  |1  |0  |
|0  |4  |3  |1  |0  |
+---+---+---+---+---+

I hope the answer is helpful.

Updated

As @pault commented, you can eliminate when, col and lit by casting the boolean to integer:

df.withColumn('e', (f.count('*').over(windowSpec) > 1).cast('int')).show(truncate=False)
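
As a side note, when no orderBy is defined on the window, Spark defaults to an unbounded frame over the whole partition, so the rowsBetween() bounds can likely be dropped as well; a sketch of that variant:

# With no ordering, the default window frame already spans the whole partition
windowSpec = w.Window.partitionBy(allColumns)
df.withColumn('e', (f.count('*').over(windowSpec) > 1).cast('int')).show(truncate=False)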
answered Oct 01 '22 by Ramesh Maharjan