
Retrieve Spark MLlib StringIndexer column mapping

How do I get the mapping out of a trained Spark MLlib StringIndexerModel?

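import org.apache.spark.ml.feature.StringIndexer

// Fit the indexer on myCol, then append the index column myColIdx
val stringIndexer = new StringIndexer()
    .setInputCol("myCol")
    .setOutputCol("myColIdx")
val stringIndexerModel = stringIndexer.fit(data)
val res = stringIndexerModel.transform(data)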

The code above adds a myColIdx column to my DataFrame, mapping each value in myCol to an index based on its frequency: the most frequent value -> 0, the second most frequent -> 1, and so on.

How do I retrieve that mapping from the model? And if I serialize and deserialize the model, will the mapping be stable (i.e., am I guaranteed to get the same result from transform)?

Marsellus Wallace asked Apr 23 '17 19:04



1 Answer

StringIndexerModel exposes the mapping through its labels attribute:

stringIndexerModel.labels: Array[String]

where each label's position in the array is the index assigned to it. For example, for:

val data = Seq("foo", "bar", "foo", "bar", "foobar", "bar").toDF("myCol")

you'll get the following labels:

Array(bar, foo, foobar)

with bar indexed as 0.0, foo as 1.0, and foobar as 2.0. This is a property of the model and is preserved when the model is saved.
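If you need the mapping as an explicit lookup rather than a positional array, something along these lines should work. This is a minimal sketch in plain Scala, with the labels hard-coded from the example above; in practice you would pass stringIndexerModel.labels instead. The names LabelMapping and labelToIndex are made up for illustration and are not part of Spark.

```scala
object LabelMapping {
  // The position of a label in the array is the index the model assigns to it,
  // so zipping with the position gives the label -> index lookup.
  def labelToIndex(labels: Array[String]): Map[String, Double] =
    labels.zipWithIndex.map { case (label, idx) => label -> idx.toDouble }.toMap

  def main(args: Array[String]): Unit = {
    // Assumption: stand-in for stringIndexerModel.labels from the example data.
    val labels = Array("bar", "foo", "foobar")
    labelToIndex(labels).toSeq
      .sortBy { case (_, idx) => idx }
      .foreach { case (label, idx) => println(s"$label -> $idx") }
  }
}
```

The indices are converted to Double only to match the type of the myColIdx column; keep them as Int if that is more convenient.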

When used in a Pipeline, you can also use IndexToString, which uses the column metadata to map indices back to labels:

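import org.apache.spark.ml.feature.IndexToString

// Reverse the indexing: myColIdx -> myColReversed
val indexToString = new IndexToString()
    .setInputCol("myColIdx")
    .setOutputCol("myColReversed")

indexToString.transform(stringIndexerModel.transform(data)).show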
+------+--------+-------------+
| myCol|myColIdx|myColReversed|
+------+--------+-------------+
|   foo|     1.0|          foo|
|   bar|     0.0|          bar|
|   foo|     1.0|          foo|
|   bar|     0.0|          bar|
|foobar|     2.0|       foobar|
|   bar|     0.0|          bar|
+------+--------+-------------+
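
On the serialization part of the question, one way to check that the mapping survives is a save/load round trip. The sketch below assumes a local SparkSession, a temporary directory as the save location, and a Spark version where labels is available as in the snippets above. The object and method names (StringIndexerRoundTrip, roundTrip) are hypothetical.

```scala
import java.nio.file.Files

import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}
import org.apache.spark.sql.SparkSession

object StringIndexerRoundTrip {
  // Fits a model on the example data, saves it, loads it back, and returns
  // (labels before saving, labels after loading).
  def roundTrip(spark: SparkSession): (Seq[String], Seq[String]) = {
    import spark.implicits._

    val data = Seq("foo", "bar", "foo", "bar", "foobar", "bar").toDF("myCol")
    val model = new StringIndexer()
      .setInputCol("myCol")
      .setOutputCol("myColIdx")
      .fit(data)

    // Assumption: a fresh temp directory, chosen only for this sketch.
    val path = Files.createTempDirectory("string-indexer").resolve("model").toString
    model.write.overwrite().save(path)
    val reloaded = StringIndexerModel.load(path)

    (model.labels.toSeq, reloaded.labels.toSeq)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("string-indexer-round-trip")
      .getOrCreate()
    val (before, after) = roundTrip(spark)
    println(s"before save: ${before.mkString(", ")}")
    println(s"after load:  ${after.mkString(", ")}")
    spark.stop()
  }
}
```

If the two label sequences match, transform on the reloaded model assigns the same indices as the original, since the labels array is what defines the mapping.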
zero323 answered Oct 20 '22 22:10
