pyspark corr for each group in DF (more than 5K columns)

I have a DF with 100 million rows and 5000+ columns. I am trying to find the correlation between colx and each of the remaining 5000+ columns.

from pyspark.sql import functions as func
from pyspark.sql.functions import broadcast, corr, mean

# subtract the per-group means from each of the 5000+ columns
aggList1 = [mean(col).alias(col + '_m') for col in df.columns]  # exclude keys
df21 = df.groupBy('key1', 'key2', 'key3', 'key4').agg(*aggList1)
df = df.join(broadcast(df21), ['key1', 'key2', 'key3', 'key4'])
df = df.select(['key1', 'key2', 'key3', 'key4'] +  # keep the keys for the groupBy below
               [func.round((func.col(colmd) - func.col(colmd + '_m')), 8).alias(colmd)
                for colmd in all5Kcolumns])

# correlation of colx against every other column, per group
aggCols = [corr(colx, col).alias(col) for col in colsall5K]
df2 = df.groupBy('key1', 'key2', 'key3').agg(*aggCols)

Right now this does not work because of Spark's 64KB codegen limit (even on Spark 2.2), so I am looping over the columns in batches of 300 and merging everything at the end. But that takes more than 30 hours on a cluster with 40 nodes (10 cores each, 100GB per node). Any help to tune this?
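For reference, a rough sketch of the batching workaround described above (colx, the column list and the key names come from the question; the batch handling and the join-based merge are my assumptions):

results = []
for i in range(0, len(colsall5K), 300):
    batch = colsall5K[i:i + 300]
    aggBatch = [corr(colx, c).alias(c) for c in batch]
    part = df.groupBy('key1', 'key2', 'key3').agg(*aggBatch).checkpoint()
    results.append(part)

# merge the per-batch results on the keys
result = results[0]
for part in results[1:]:
    result = result.join(part, ['key1', 'key2', 'key3'])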

Things already tried:

- Repartition the DF to 10,000 partitions
- Checkpoint in each loop
- Cache in each loop

asked Oct 30 '22 by Harish
1 Answer

You can try with a bit of NumPy and RDDs. First a bunch of imports:

from operator import itemgetter
import numpy as np
from pyspark.statcounter import StatCounter

Let's define a few variables:

keys = ["key1", "key2", "key3"] # list of key column names
xs = ["x1", "x2", "x3"]    # list of column names to compare
y = "y"                         # name of the reference column

And some helpers:

def as_pair(keys, y, xs):
    """ Given key names, y name, and xs names
    return a tuple of key, array-of-values"""
    key = itemgetter(*keys)
    value = itemgetter(y, * xs)  # Python 3 syntax

    def as_pair_(row):
        return key(row), np.array(value(row))
    return as_pair_

def init(x):
    """ Init function for combineByKey
    Initialize new StatCounter and merge first value"""
    return StatCounter().merge(x)

def center(means):
    """Center a row value given a 
    dictionary of mean arrays
    """
    def center_(row):
        key, value = row
        return key, value - means[key]
    return center_

def prod(arr):
    """Multiply the centered y (first element) with each centered x"""
    return arr[0] * arr[1:]

def corr(stddev_prods):
    """Turn a row of covariances into correlations
    by dividing by the per-group product of stddevs
    """
    def corr_(row):
        key, value = row
        return key, value / stddev_prods[key]
    return corr_

and convert the DataFrame to an RDD of pairs:

pairs = df.rdd.map(as_pair(keys, y, xs))
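For intuition, each element of pairs is a (key tuple, NumPy values array) pair; with the example data defined further down it would look roughly like this:

pairs.first()
# roughly: (('a', 'b', 'c'), array([0.5, 0.5, 0.3, 1.0]))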

Next let's compute statistics per group:

stats = (pairs
    .combineByKey(init, StatCounter.merge, StatCounter.mergeStats)
    .collectAsMap())

means = {k: v.mean() for k, v in stats.items()}

Note: With 5000 features and 7000 groups there should be no issue with keeping this structure in memory. With larger datasets you may have to use an RDD and join, but that will be slower.
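The reason a single StatCounter per group covers all columns at once is that its arithmetic works element-wise on NumPy arrays, so mean() and stdev() come back as arrays as well. A quick local sanity check (my own illustration, not part of the pipeline):

s = StatCounter()
s.merge(np.array([1.0, 2.0]))
s.merge(np.array([3.0, 6.0]))
s.mean()   # array([ 2.,  4.])
s.stdev()  # population stddev per component: array([ 1.,  2.])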

Center the data:

centered = pairs.map(center(means))

Compute the covariance, i.e. the per-group mean of the products of centered values:

covariance = (centered
    .mapValues(prod)
    .combineByKey(init, StatCounter.merge, StatCounter.mergeStats)
    .mapValues(StatCounter.mean))

And finally the correlation, i.e. the covariance divided by the product of the standard deviations:

stddev_prods = {k: prod(v.stdev()) for k, v in stats.items()}

correlations = covariance.map(corr(stddev_prods))
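If you need the result back as a DataFrame (one row per group, one column per x), one possible conversion is the following; result_df is just an illustrative name and this step is my addition, not part of the original approach:

result_df = correlations.map(
    lambda kv: tuple(kv[0]) + tuple(float(v) for v in kv[1])
).toDF(keys + xs)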

Example data:

df = sc.parallelize([
    ("a", "b", "c", 0.5, 0.5, 0.3, 1.0),
    ("a", "b", "c", 0.8, 0.8, 0.9, -2.0), 
    ("a", "b", "c", 1.5, 1.5, 2.9, 3.6),
    ("d", "e", "f", -3.0, 4.0, 5.0, -10.0),
    ("d", "e", "f", 15.0, -1.0, -5.0, 10.0),
]).toDF(["key1", "key2", "key3", "y", "x1", "x2", "x3"])

Results computed with the DataFrame API, for comparison:

from pyspark.sql import functions as F  # the local corr() above shadows the SQL function

df.groupBy(*keys).agg(*[F.corr(y, x) for x in xs]).show()
+----+----+----+-----------+------------------+------------------+
|key1|key2|key3|corr(y, x1)|       corr(y, x2)|       corr(y, x3)|
+----+----+----+-----------+------------------+------------------+
|   d|   e|   f|       -1.0|              -1.0|               1.0|
|   a|   b|   c|        1.0|0.9972300220940342|0.6513360726920862|
+----+----+----+-----------+------------------+------------------+

and the method provided above:

correlations.collect()
[(('a', 'b', 'c'), array([ 1.        ,  0.99723002,  0.65133607])),
 (('d', 'e', 'f'), array([-1., -1.,  1.]))]
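As a quick sanity check, the same number can be reproduced locally with plain NumPy for, say, x2 in group ('a', 'b', 'c'):

np.corrcoef([0.5, 0.8, 1.5], [0.3, 0.9, 2.9])[0, 1]
# ~0.9972300220940342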

This solution, while a bit involved, is quite flexible and can easily be adjusted to handle different data distributions. It should also be possible to get a further boost with JIT compilation.
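For example (my own sketch, not from the original answer, and assuming Numba is available on the executors), the per-row arithmetic in helpers like prod could be JIT-compiled; the gain is likely modest since most of the time is spent inside StatCounter:

from numba import njit

@njit
def prod_jit(arr):
    # same computation as prod above, compiled to machine code
    return arr[0] * arr[1:]

centered.mapValues(prod_jit)  # drop-in replacement for prod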

answered Nov 15 '22 by zero323