
How to index categorical features in another way when using Spark ML

The VectorIndexer in Spark indexes categorical features according to the frequency of their values. But I want to index the categorical features in a different way.

For example, with the dataset below, "a", "b", "c" will be indexed as 0, 1, 2 if I use the VectorIndexer in Spark. But I want to index them according to the label. There are 4 rows labeled 1, and among them 3 rows have feature 'a' and 1 row has feature 'c'. So here I would index 'a' as 0, 'c' as 1 and 'b' as 2.

Is there any convenient way to implement this?

 label|feature
-----------------
    1 | a
    1 | c
    0 | a
    0 | b
    1 | a
    0 | b
    0 | b
    0 | c
    1 | a
April asked Dec 20 '25 17:12


1 Answer

If I understand your question correctly, you are looking to replicate the behaviour of StringIndexer() on grouped data. To do so (in PySpark), we first define a UDF that operates on a list column containing all the values per group. Note that elements with equal counts will be ordered arbitrarily.

from collections import Counter

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

def encoder(col):
  # Count occurrences of each letter within the group
  counts = Counter(col)

  # Map each letter to its rank (0 = most frequent)
  ranking = {letter: rank
             for rank, (letter, _) in enumerate(counts.most_common())}

  # Replace each letter by its rank
  return [ranking[letter] for letter in col]

encoder_udf = udf(encoder, ArrayType(IntegerType()))
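To see what the UDF computes without starting a Spark session, the same ranking logic can be run on plain Python lists. This is only a sketch: `rank_by_frequency` is a hypothetical stand-in for the body of `encoder`, and the two input lists are the label groups from the question in the order they appear there (Spark may deliver them in a different order).

```python
from collections import Counter

def rank_by_frequency(values):
    # Hypothetical helper mirroring the encoder logic above:
    # the most frequent value gets rank 0, the next one rank 1, and so on.
    counts = Counter(values)
    ranking = {value: rank
               for rank, (value, _) in enumerate(counts.most_common())}
    return [ranking[value] for value in values]

# Features of the rows with label 1 and label 0, taken from the question
label_1 = ["a", "c", "a", "a"]
label_0 = ["a", "b", "b", "b", "c"]

print(rank_by_frequency(label_1))  # [0, 1, 0, 0]
print(rank_by_frequency(label_0))  # [1, 0, 0, 0, 2]
```

In the label 0 group, 'a' and 'c' both occur once. On Python 3.7+ `most_common()` breaks such ties by first occurrence, so the result depends on the order of the input list.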

Now we can aggregate the feature column into a list grouped by the label column using collect_list(), and apply our UDF row-wise:

from pyspark.sql.functions import collect_list, explode

df1 = (df.groupBy("label")
       .agg(collect_list("feature")
            .alias("features"))
       .withColumn("index", 
                   encoder_udf("features")))

Finally, you can explode the index column to get the encoded values instead of the letters:

df1.select("label", explode(df1.index).alias("index")).show()
+-----+-----+
|label|index|
+-----+-----+
|    0|    1|
|    0|    0|
|    0|    0|
|    0|    0|
|    0|    2|
|    1|    0|
|    1|    1|
|    1|    0|
|    1|    0|
+-----+-----+
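Note that the result above ranks the letters separately within each label group. If a single mapping for the whole column is wanted instead, ordered by frequency among the rows with label 1 as described in the question, one possible sketch in plain Python follows. The name `label_frequency_index` and its `target_label` parameter are hypothetical, and appending categories that never occur with the target label in order of overall frequency is an assumption, not something the question specifies.

```python
from collections import Counter

def label_frequency_index(rows, target_label=1):
    # rows: iterable of (label, feature) pairs
    rows = list(rows)
    # Frequency of each feature among rows carrying the target label
    target_counts = Counter(f for lbl, f in rows if lbl == target_label)
    # Overall frequency, used only for features never seen with the target label
    overall_counts = Counter(f for _, f in rows)

    ordered = [f for f, _ in target_counts.most_common()]
    ordered += [f for f, _ in overall_counts.most_common()
                if f not in target_counts]
    return {feature: index for index, feature in enumerate(ordered)}

# The dataset from the question as (label, feature) pairs
rows = [(1, "a"), (1, "c"), (0, "a"), (0, "b"), (1, "a"),
        (0, "b"), (0, "b"), (0, "c"), (1, "a")]

print(label_frequency_index(rows))  # {'a': 0, 'c': 1, 'b': 2}
```

With a small number of distinct categories, the counts could be computed in Spark and collected to the driver, and the resulting dictionary then applied to the feature column, for example through a UDF that looks each value up.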
mtoto answered Dec 24 '25 02:12


