Spark StringIndexer.fit is very slow on large records

I have a large dataset whose records are formatted like the following sample:

// +---+------+------+
// |cid|itemId|bought|
// +---+------+------+
// |abc|   123|  true|
// |abc|   345|  true|
// |abc|   567|  true|
// |def|   123|  true|
// |def|   345|  true|
// |def|   567|  true|
// |def|   789| false|
// +---+------+------+

cid and itemId are strings.

There are 965,964,223 records.

I am trying to convert cid to an integer using StringIndexer as follows:

import org.apache.spark.ml.feature.StringIndexer

// repartition returns a new Dataset; the result must be captured or the call has no effect
val repartitioned = dataset.repartition(50)
val cidIndexer = new StringIndexer().setInputCol("cid").setOutputCol("cidIndex")
val cidIndexedMatrix = cidIndexer.fit(repartitioned).transform(repartitioned)

But these lines of code are very slow (they take around 30 minutes), and the dataset is so huge that I cannot do anything further after this step.

I am using an Amazon EMR cluster with 2 r4.2xlarge nodes (61 GB of memory each).

Is there any further performance improvement I can make? Any help would be much appreciated.

asked Jul 23 '18 by Rengasami Ramanujam
1 Answer

That is expected behavior when the cardinality of the column is high. As part of the fitting process, StringIndexer collects all the labels to create a label-to-index mapping (using Spark's o.a.s.util.collection.OpenHashMap).

This process requires O(N) memory in the worst case, and is both computationally and memory intensive.
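As a quick diagnostic, you can estimate the cardinality of cid before fitting. Here is a minimal sketch using Spark's approx_count_distinct (HyperLogLog++ based, so much cheaper than an exact distinct count); what counts as "too high" is a judgment call that depends on your driver memory:

import org.apache.spark.sql.functions.approx_count_distinct

// Estimate how many distinct labels StringIndexer.fit would have to collect.
val approxCids = dataset
  .agg(approx_count_distinct("cid").as("approx_distinct_cids"))
  .first()
  .getLong(0)

println(s"Approximate distinct cids: $approxCids")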

In cases where the cardinality of the column is high and its content is going to be used as a feature, it is better to apply FeatureHasher (Spark 2.3 or later):

import org.apache.spark.ml.feature.FeatureHasher

val hasher = new FeatureHasher()
  .setInputCols("cid")
  .setOutputCol("cid_hash_vec") // note: the output parameter is outputCol (singular)

val hashed = hasher.transform(dataset)

It doesn't guarantee uniqueness and it is not reversible, but it is good enough for many applications and doesn't require a fitting process.
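A minimal usage sketch, assuming the hashed Dataset from above (the default numFeatures is 2^18; raising it via setNumFeatures reduces hash collisions at the cost of a wider vector):

// Inspect the resulting sparse feature vectors.
hashed.select("cid", "cid_hash_vec").show(5, truncate = false)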

For a column that won't be used as a feature, you can also use the hash function:

import org.apache.spark.sql.functions.hash
import spark.implicits._ // for the $"cid" column syntax

dataset.withColumn("cid_hash", hash($"cid"))
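One caveat worth checking: hash is a 32-bit Murmur3 hash, so with close to a billion distinct values some collisions are effectively guaranteed. A sanity check could look like this sketch (both counts shuffle the full column, so consider running it on a sample):

// Compare distinct cids against distinct hash values; any gap indicates collisions.
val distinctCids   = dataset.select("cid").distinct().count()
val distinctHashes = dataset.select(hash($"cid")).distinct().count()
println(s"cids: $distinctCids, hashes: $distinctHashes, collisions: ${distinctCids - distinctHashes}")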
answered Sep 21 '22 by 10465355