I have a PySpark dataframe with ~4800 columns and am trying to find a way to identify and drop columns that have different column names but that are otherwise exact duplicates of one another.
For example, in the following dataframe, I would want to drop columns C and E (since they are duplicates of column A) and also know that columns C and E were the ones that were dropped.
+---+---+---+---+---+---+
|  A|  B|  C|  D|  E|  F|
+---+---+---+---+---+---+
|  1|  2|  1|  3|  1|  2|
|  1|  1|  1|  2|  1|  2|
|  1|  3|  1|  1|  1|  2|
+---+---+---+---+---+---+
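For reference, this example dataframe can be recreated with something like the following (the SparkSession setup here is just an illustrative local sketch):
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 2, 1, 3, 1, 2),
     (1, 1, 1, 2, 1, 2),
     (1, 3, 1, 1, 1, 2)],
    ["A", "B", "C", "D", "E", "F"],
)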
I found this post with a potential solution, but it runs very slowly. Is there a way to optimize it so it runs more quickly on a larger dataframe?
A relatively fast way is to calculate a simple hash of each column and compare the hashes instead of the full column contents (at the cost of a slight risk of collisions).
Scala:
// Grab a list of columns
val cols = df.columns
// >>> Array(A, B, C, D, E, F)
// Calculate column hashes
// Include a unique id, so that same values in different order get a different hash
val hashes = df.withColumn("id", monotonically_increasing_id)
.select(cols.map(c => hash($"id" * col(c)).as(c)):_*)
.agg(sum(lit(0)).as("dummy"), df.columns.map(c => sum(col(c)).as(c)):_*)
.drop("dummy")
// >>> +----------+-----------+----------+-----------+----------+-----------+
// >>> |         A|          B|         C|          D|         E|          F|
// >>> +----------+-----------+----------+-----------+----------+-----------+
// >>> |-515933930|-1948328522|-515933930|-2768907968|-515933930|-2362158726|
// >>> +----------+-----------+----------+-----------+----------+-----------+
// Group column names by their hash value
val groups = (hashes.columns zip hashes.head.toSeq).groupBy(_._2).mapValues(_.map(_._1))
// >>> Map(-515933930 -> Array(A, C, E), -1948328522 -> Array(B), -2768907968 -> Array(D), -2362158726 -> Array(F))
// Pick one column for each hash value and discard rest
val columnsToKeep = groups.values.map(_.head)
// >>> List(A, B, D, F)
val columnsToDrop = groups.values.flatMap(_.tail)
// >>> List(C, E)
val finalDf = df.select(columnsToKeep.toSeq.map(col):_*)
// >>> +---+---+---+---+
// >>> |  A|  B|  D|  F|
// >>> +---+---+---+---+
// >>> |  1|  2|  3|  2|
// >>> |  1|  1|  2|  2|
// >>> |  1|  3|  1|  2|
// >>> +---+---+---+---+
Python:
from pyspark.sql import functions as F

# Sum hash(row id, value) per column in a single aggregation.
# Columns are listed in reverse order so that, in the dict built below,
# the first duplicate in the original column order is the one that survives.
df_hashes = (df
    .withColumn("_id", F.monotonically_increasing_id())
    .agg(*[F.sum(F.hash("_id", c)).alias(c) for c in df.columns[::-1]])
)
# Map each hash value to a column name; columns sharing a hash collapse to one entry
keep = dict(zip(df_hashes.head(), df_hashes.columns)).values()
cols_to_keep = [c for c in df.columns if c in keep]
# ['A', 'B', 'D', 'F']
cols_to_drop = set(df.columns) - set(keep)
# {'C', 'E'}
df_final = df.select(cols_to_keep)
df_final.show()
# +---+---+---+---+
# |  A|  B|  D|  F|
# +---+---+---+---+
# |  1|  2|  3|  2|
# |  1|  1|  2|  2|
# |  1|  3|  1|  2|
# +---+---+---+---+
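The hash sums only identify candidate duplicates, so if the small collision risk matters for your data, you can verify each candidate against the column it duplicates before actually dropping it. A rough sketch, reusing df, df_hashes and the F alias from the snippet above (the group-building step just mirrors the Scala groups map):
# Group column names by their hash sum (columns were hashed in reverse order,
# so the last name in each group is the first column in the original order)
groups = {}
for h, c in zip(df_hashes.head(), df_hashes.columns):
    groups.setdefault(h, []).append(c)

# Exact check: a candidate is only dropped if no row differs from the kept column
confirmed_drops = []
for names in groups.values():
    keeper, candidates = names[-1], names[:-1]
    for c in candidates:
        mismatches = df.filter(~F.col(c).eqNullSafe(F.col(keeper))).limit(1).count()
        if mismatches == 0:
            confirmed_drops.append(c)

df_final = df.drop(*confirmed_drops)
eqNullSafe treats two nulls as equal, so columns that match except for shared nulls are still caught. In the unverified version, df.drop(*cols_to_drop) gives the same result as selecting cols_to_keep, if you prefer to express it as a drop.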