Let's say I have a PySpark data frame, like so:
+--+--+--+--+
|a |b |c |d |
+--+--+--+--+
|1 |0 |1 |2 |
|0 |2 |0 |1 |
|1 |0 |1 |2 |
|0 |4 |3 |1 |
+--+--+--+--+
How can I create a column marking all of the duplicate rows, like so:
+--+--+--+--+--+
|a |b |c |d |e |
+--+--+--+--+--+
|1 |0 |1 |2 |1 |
|0 |2 |0 |1 |0 |
|1 |0 |1 |2 |1 |
|0 |4 |3 |1 |0 |
+--+--+--+--+--+
I attempted it using the groupBy and aggregate functions to no avail.
Just to expand on my comment:
You can group by all of the columns and use pyspark.sql.functions.count() to determine if a row is duplicated:
import pyspark.sql.functions as f
df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")).show()
#+---+---+---+---+---+
#| a| b| c| d| e|
#+---+---+---+---+---+
#| 1| 0| 1| 2| 1|
#| 0| 2| 0| 1| 0|
#| 0| 4| 3| 1| 0|
#+---+---+---+---+---+
Here we use count("*") > 1 as the aggregate function, and cast the result to an int. The groupBy() will have the consequence of dropping the duplicate rows. Depending on your needs, this may be sufficient.
However, if you'd like to keep all of the rows, you can use a Window function as shown in the other answers, or you can use a join():
df.join(
df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")),
on=df.columns,
how="inner"
).show()
#+---+---+---+---+---+
#| a| b| c| d| e|
#+---+---+---+---+---+
#| 1| 0| 1| 2| 1|
#| 1| 0| 1| 2| 1|
#| 0| 2| 0| 1| 0|
#| 0| 4| 3| 1| 0|
#+---+---+---+---+---+
Here we inner join the original dataframe with the one that is the result of the groupBy()
above on all of the columns.
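To see what the grouped count and the join back are doing, the same logic can be mimicked in plain Python on the sample rows from the question. This is only an illustrative sketch, not Spark code; the names rows, counts, and flagged are made up for this example.

```python
from collections import Counter

# The sample rows from the question, as (a, b, c, d) tuples.
rows = [(1, 0, 1, 2), (0, 2, 0, 1), (1, 0, 1, 2), (0, 4, 3, 1)]

# Rough equivalent of groupBy(all columns) + count("*"):
# the number of occurrences of each distinct row.
counts = Counter(rows)

# Rough equivalent of joining back on all columns: every original row
# looks up its group's count, and (count > 1) becomes the int column e.
flagged = [row + (int(counts[row] > 1),) for row in rows]

for row in flagged:
    print(row)
```

Unlike this list-based sketch, Spark does not guarantee that the joined result keeps the original row order, which is why the output above lists the two duplicate rows next to each other.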
Define a window function to check whether the count of rows, when grouped by all columns, is greater than 1. If so, the row is a duplicate (1); otherwise it is not (0):
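from pyspark.sql import functions as f
from pyspark.sql import window as w
allColumns = df.columns
# Partition by every column; the frame spans the whole partition.
# Window.unboundedPreceding/unboundedFollowing work on both Python 2 and 3 (sys.maxint is Python 2 only).
windowSpec = w.Window.partitionBy(allColumns).rowsBetween(w.Window.unboundedPreceding, w.Window.unboundedFollowing)
# Flag a row with 1 when its partition holds more than one row, else 0
df.withColumn('e', f.when(f.count(f.col('d')).over(windowSpec) > 1, f.lit(1)).otherwise(f.lit(0))).show(truncate=False)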
which should give you
+---+---+---+---+---+
|a |b |c |d |e |
+---+---+---+---+---+
|1 |0 |1 |2 |1 |
|1 |0 |1 |2 |1 |
|0 |2 |0 |1 |0 |
|0 |4 |3 |1 |0 |
+---+---+---+---+---+
I hope the answer is helpful
Updated
As @pault commented, you can eliminate when, col and lit by casting the boolean to integer:
df.withColumn('e', (f.count('*').over(windowSpec) > 1).cast('int')).show(truncate=False)