
Preserve index-string correspondence in Spark StringIndexer

Spark's StringIndexer is quite useful, but it's common to need to retrieve the correspondences between the generated index values and the original strings, and it seems like there should be a built-in way to accomplish this. I'll illustrate using this simple example from the Spark documentation:

from pyspark.ml.feature import StringIndexer

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
    ["id", "category"])
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexed_df = indexer.fit(df).transform(df)

This simplified case gives us:

+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+
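
The index order in this table is not arbitrary: by default StringIndexer assigns indices by descending label frequency, so the most frequent label gets 0.0. A minimal plain-Python sketch of that ordering rule, assuming no two labels have the same count (the helper name `frequency_index` is hypothetical and not part of Spark):

```python
from collections import Counter

def frequency_index(values):
    """Map each label to an index, most frequent label first.

    Illustrative stand-in for StringIndexer's default ordering;
    assumes no two labels share the same count.
    """
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: counts[label], reverse=True)
    return {label: float(i) for i, label in enumerate(ordered)}

# Same category values as the example DataFrame above.
categories = ["a", "b", "c", "a", "a", "c"]
label_to_index = frequency_index(categories)
print(label_to_index)
```

Here "a" appears three times, "c" twice and "b" once, which matches the indices shown in the table.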

All fine and dandy, but for many use cases I want to know the mapping between my original strings and the index labels. The simplest way I can think of offhand is something like this:

In [8]: indexed_df.select('category','categoryIndex').distinct().show()
+--------+-------------+
|category|categoryIndex|
+--------+-------------+
|       b|          2.0|
|       c|          1.0|
|       a|          0.0|
+--------+-------------+

The result of which I could store as a dictionary or similar if I wanted:

In [12]: mapping = {row.categoryIndex:row.category for row in
           indexed_df.select('category','categoryIndex').distinct().collect()}

In [13]: mapping
Out[13]: {0.0: u'a', 1.0: u'c', 2.0: u'b'}

My question is this: Since this is such a common task, and I'm guessing (but could of course be wrong) that the string indexer is somehow storing this mapping anyway, is there a way to accomplish the above task more simply?

My solution is more or less straightforward, but for large data structures this involves a bunch of extra computation that (perhaps) I can avoid. Ideas?

moustachio asked Nov 10 '15 18:11


1 Answer

The label mapping can be extracted from the column metadata:

meta = [
    f.metadata for f in indexed_df.schema.fields if f.name == "categoryIndex"
]
meta[0]
## {'ml_attr': {'name': 'category', 'type': 'nominal', 'vals': ['a', 'c', 'b']}}

where ml_attr.vals provides a mapping between position and label:

dict(enumerate(meta[0]["ml_attr"]["vals"]))
## {0: 'a', 1: 'c', 2: 'b'}
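
If lookups are needed in both directions, the same metadata can be turned into two dictionaries. A small sketch, assuming the metadata has the shape shown above; `mappings_from_metadata` is a hypothetical helper, and the literal dict stands in for `meta[0]` so the snippet runs without a Spark session:

```python
def mappings_from_metadata(metadata):
    """Return (index -> label, label -> index) built from the 'vals' list."""
    vals = metadata["ml_attr"]["vals"]
    index_to_label = dict(enumerate(vals))
    label_to_index = {label: index for index, label in index_to_label.items()}
    return index_to_label, label_to_index

# Stand-in for meta[0]; copied from the metadata printed above.
metadata = {
    "ml_attr": {"name": "category", "type": "nominal", "vals": ["a", "c", "b"]}
}

index_to_label, label_to_index = mappings_from_metadata(metadata)

# categoryIndex values are floats (e.g. 2.0); casting to int keeps the
# lookup explicit about which position in 'vals' is meant.
print(index_to_label[int(2.0)])
print(label_to_index["c"])
```

Because the labels come from the schema rather than from the rows, building these dictionaries does not trigger a job over the data.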

Spark 1.6+

You can convert numeric values back to labels using IndexToString. It reads the labels from the column metadata shown above.

from pyspark.ml.feature import IndexToString

idx_to_string = IndexToString(
    inputCol="categoryIndex", outputCol="categoryValue")

idx_to_string.transform(indexed_df).drop("id").distinct().show()
## +--------+-------------+-------------+
## |category|categoryIndex|categoryValue|
## +--------+-------------+-------------+
## |       b|          2.0|            b|
## |       a|          0.0|            a|
## |       c|          1.0|            c|
## +--------+-------------+-------------+

Spark <= 1.5

It is a dirty hack, but you can extract the labels from the underlying Java indexer model as follows:

from pyspark.ml.feature import StringIndexer, StringIndexerModel

# A simple monkey patch so we don't have to _call_java later 
def labels(self):
    return self._call_java("labels")

StringIndexerModel.labels = labels

# Fit indexer model
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex").fit(df)

# Extract mapping
mapping = dict(enumerate(indexer.labels()))
mapping
## {0: 'a', 1: 'c', 2: 'b'}
zero323 answered Oct 21 '22 22:10