
Find columns that are exact duplicates (i.e., that contain duplicate values across all rows) in PySpark dataframe

I have a PySpark dataframe with ~4800 columns and am trying to find a way to identify and drop columns that have different column names but that are otherwise exact duplicates of one another.

For example, in the following dataframe, I would want to drop columns C and E (since they are duplicates of column A) and also know that columns C and E were the ones that were dropped.

+---+---+---+---+---+---+
|  A|  B|  C|  D|  E|  F|
+---+---+---+---+---+---+
|  1|  2|  1|  3|  1|  2|
|  1|  1|  1|  2|  1|  2|
|  1|  3|  1|  1|  1|  2|
+---+---+---+---+---+---+
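For reference, this example dataframe can be built with something like the following (a minimal sketch, assuming an existing SparkSession named spark):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(1, 2, 1, 3, 1, 2),
     (1, 1, 1, 2, 1, 2),
     (1, 3, 1, 1, 1, 2)],
    ["A", "B", "C", "D", "E", "F"],
)
df.show()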

I found this post with a potential solution, but it runs very slowly. Is there a way to optimize it so it runs more quickly on a larger dataframe?

1 Answer

A relatively fast approach is to compute a simple hash of each column and compare those, at the cost of a slight risk of collision.

Scala:

import org.apache.spark.sql.functions._
import spark.implicits._

// Grab a list of columns
val cols = df.columns
// >>> Array(A, B, C, D, E, F)

// Calculate per-column hash sums
// Include a unique id, so that the same values in a different row order get a different hash
val hashes = df.withColumn("id", monotonically_increasing_id)
  .select(cols.map(c => hash($"id" * col(c)).as(c)):_*)
  // agg needs one leading Column plus varargs, so pass a dummy aggregate and drop it afterwards
  .agg(sum(lit(0)).as("dummy"), cols.map(c => sum(col(c)).as(c)):_*)
  .drop("dummy")
// >>> +----------+-----------+----------+-----------+----------+-----------+
// >>> |         A|          B|         C|          D|         E|          F|
// >>> +----------+-----------+----------+-----------+----------+-----------+
// >>> |-515933930|-1948328522|-515933930|-2768907968|-515933930|-2362158726|
// >>> +----------+-----------+----------+-----------+----------+-----------+

// Group column names by their hash value
val groups = (hashes.columns zip hashes.head.toSeq).groupBy(_._2).mapValues(_.map(_._1))
// >>> Map(-515933930 -> Array(A, C, E), -1948328522 -> Array(B), -2768907968 -> Array(D), -2362158726 -> Array(F))

// Pick one column for each hash value and discard rest
val columnsToKeep = groups.values.map(_.head)
// >>> List(A, B, D, F)
val columnsToDrop = groups.values.flatMap(_.tail)
// >>> List(C, E)
val finalDf = df.select(columnsToKeep.toSeq.map(col):_*)
// >>> +---+---+---+---+
// >>> |  A|  B|  D|  F|
// >>> +---+---+---+---+
// >>> |  1|  2|  3|  2|
// >>> |  1|  1|  2|  2|
// >>> |  1|  3|  1|  2|
// >>> +---+---+---+---+

Python:

from pyspark.sql import functions as F

# Hash each value together with a per-row id, then sum the hashes per column.
# Columns are iterated in reverse so that, when duplicate hash values collapse
# in the dict below, the first column in the original order is the one kept.
df_hashes = (df
    .withColumn("_id", F.monotonically_increasing_id())
    .agg(*[F.sum(F.hash("_id", c)).alias(c) for c in df.columns[::-1]])
)
# Map each hash sum to a single column name (later entries overwrite earlier ones)
keep = dict(zip(df_hashes.head(), df_hashes.columns)).values()

cols_to_keep = [c for c in df.columns if c in keep]
# ['A', 'B', 'D', 'F']

cols_to_drop = set(df.columns) - set(keep)
# {'C', 'E'}

df_final = df.select(cols_to_keep)
df_final.show()
# +---+---+---+---+
# |  A|  B|  D|  F|
# +---+---+---+---+
# |  1|  2|  3|  2|
# |  1|  1|  2|  2|
# |  1|  3|  1|  2|
# +---+---+---+---+
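
If the slight collision risk matters, the hash-based groups could be double-checked with an exact comparison before dropping anything. A rough sketch (the names col_hash, groups, confirmed_drops, and df_verified are just illustrative; it reuses df_hashes and F from above):

from collections import defaultdict

# Map each column name to its hash sum computed above
col_hash = dict(zip(df_hashes.columns, df_hashes.head()))

# Bucket column names by hash sum, preserving the original column order
groups = defaultdict(list)
for c in df.columns:
    groups[col_hash[c]].append(c)

# Within each bucket, keep the first column and confirm the rest really match it
confirmed_drops = []
for cols in groups.values():
    ref = cols[0]
    for other in cols[1:]:
        # null-safe comparison; a single mismatching row disqualifies the pair
        if df.filter(~F.col(ref).eqNullSafe(F.col(other))).limit(1).count() == 0:
            confirmed_drops.append(other)

df_verified = df.drop(*confirmed_drops)
# confirmed_drops -> ['C', 'E'] for the example above

This runs one small job per candidate pair, so it is only worthwhile when the number of suspected duplicates is modest.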
