Get IDs for duplicate rows (considering all other columns) in Apache Spark

I have a Spark SQL DataFrame consisting of an ID column and n "data" columns, i.e.

id | dat1 | dat2 | ... | datn

The id column is unique, whereas, looking at dat1 ... datn, there may be duplicates.

My goal is to find the ids of those duplicates.

My approach so far:

  • get the duplicate rows using groupBy:

    dup_df = df.groupBy(df.columns[1:]).count().filter('count > 1')

  • join the dup_df with the entire df to get the duplicate rows including id:

    df.join(dup_df, df.columns[1:])

I am quite certain that this approach is basically correct, but it fails because the dat1 ... datn columns contain null values.

To do the join on null values, I found e.g. this SO post. But this would require constructing a huge "string join condition".
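To illustrate, here is a rough sketch of the kind of condition I would have to assemble by hand (the a / b aliases are only for illustration; Column.eqNullSafe, added in later Spark versions, would shorten this, but it is not in 2.1's Python API):

from functools import reduce
from operator import and_
from pyspark.sql.functions import col

a = df.alias("a")
b = dup_df.alias("b")

# one null-safe equality check per data column, ANDed together
cond = reduce(and_, [
    (col("a." + c) == col("b." + c)) |
    (col("a." + c).isNull() & col("b." + c).isNull())
    for c in df.columns[1:]
])

ids_of_duplicates = a.join(b, cond).select("a.id")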

Thus my questions:

  1. Is there a simple / more generic / more pythonic way to do joins on null values?
  2. Or, even better, is there another (easier, more beautiful, ...) method to get the desired ids?

BTW: I am using Spark 2.1.0 and Python 3.5.3

asked Mar 29 '17 by akoeltringer

People also ask

How do I find duplicate rows in spark data frame?

Find complete-row duplicates: group by all columns with a count() aggregate, then filter for count > 1. Find column-level duplicates: group by the required columns with count() and filter in the same way to get the duplicate records.

How do you drop duplicate rows based on one column in PySpark?

PySpark's distinct() function is used to drop duplicate rows (considering all columns) from a DataFrame, while dropDuplicates() is used to drop rows based on one or more selected columns.
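A minimal sketch of both calls, assuming a DataFrame df like the one described in the question (the column name "dat1" is just a placeholder):

# drop duplicates considering all columns
deduped_all = df.distinct()

# drop duplicates considering only selected column(s)
deduped_one = df.dropDuplicates(["dat1"])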

How can I see the number of duplicate rows?

You can count the number of duplicate rows by counting True in the pandas.Series returned by duplicated(); the number of True values can be counted with the sum() method. If you want to count the number of False values (i.e. the number of non-duplicate rows), invert it with the negation operator ~ and then count True with sum().
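A minimal pandas sketch of this (pdf is a small hypothetical pandas DataFrame, unrelated to the Spark one above):

import pandas as pd

pdf = pd.DataFrame({"a": [1, 1, 2], "b": ["x", "x", "y"]})

# rows that repeat an earlier row
n_duplicates = pdf.duplicated().sum()

# non-duplicate rows (first occurrences), via negation
n_non_duplicates = (~pdf.duplicated()).sum()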

How do you count duplicate rows in PySpark?

In PySpark, you can use distinct().count() on a DataFrame or the countDistinct() SQL function to get the distinct count. distinct() eliminates duplicate records (matching all columns of a row) from the DataFrame, and count() returns the number of records in the DataFrame.
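For example, a rough sketch of counting how many rows are exact copies of another row, as the difference between the total and the distinct count (assuming a DataFrame df without a unique key column):

total_rows = df.count()
distinct_rows = df.distinct().count()

# rows that duplicate some other row (beyond the first occurrence)
num_duplicate_rows = total_rows - distinct_rows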


1 Answer

If the number of ids per group is relatively small, you can group by and use collect_list. Required imports:

from pyspark.sql.functions import collect_list, size

example data:

df = sc.parallelize([
    (1, "a", "b", 3),
    (2, None, "f", None),
    (3, "g", "h", 4),
    (4, None, "f", None),
    (5, "a", "b", 3)
]).toDF(["id"])  # only the first column is named; the rest keep default names _2, _3, _4

query:

(df
   .groupBy(df.columns[1:])
   .agg(collect_list("id").alias("ids"))
   .where(size("ids") > 1))

and the result:

+----+---+----+------+
|  _2| _3|  _4|   ids|
+----+---+----+------+
|null|  f|null|[2, 4]|
|   a|  b|   3|[1, 5]|
+----+---+----+------+

You can apply explode twice (or use a udf) to get an output equivalent to the one returned by the join.
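A minimal sketch of that flattening step, assuming the grouped result above is assigned to dup_groups; a single explode over the ids column already yields one row per duplicate id:

from pyspark.sql.functions import explode

# dup_groups is assumed to hold the collect_list result shown above
flat = dup_groups.withColumn("id", explode("ids")).drop("ids")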

You can also identify groups using the minimal id per group. A few additional imports:

from pyspark.sql.window import Window
from pyspark.sql.functions import col, count, min

window definition:

w = Window.partitionBy(df.columns[1:])

query:

(df
    .select(
        "*", 
        count("*").over(w).alias("_cnt"), 
        min("id").over(w).alias("group"))
    .where(col("_cnt") > 1))

and the result:

+---+----+---+----+----+-----+
| id|  _2| _3|  _4|_cnt|group|
+---+----+---+----+----+-----+
|  2|null|  f|null|   2|    2|
|  4|null|  f|null|   2|    2|
|  1|   a|  b|   3|   2|    1|
|  5|   a|  b|   3|   2|    1|
+---+----+---+----+----+-----+

You can further use the group column for a self join.
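A rough sketch of that self join, assuming the result of the window query above is assigned to marked (it pairs every duplicate id with the other ids in its group):

from pyspark.sql.functions import col

# "marked" is assumed to hold the result of the window query above
pairs = (marked.alias("a")
    .join(marked.alias("b"), "group")
    .where(col("a.id") < col("b.id"))
    .select(col("a.id").alias("id"), col("b.id").alias("duplicate_id")))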

answered Oct 24 '22 by zero323