
Aggregate sparse vector in PySpark

I have a Hive table that contains text data and some metadata associated with each document. It looks like this:

from pyspark.ml.feature import Tokenizer
from pyspark.ml.feature import CountVectorizer

df = sc.parallelize([
  ("1", "doc_1", "fruit is good for you"),
  ("2", "doc_2", "you should eat fruit and veggies"),
  ("2", "doc_3", "kids eat fruit but not veggies")
]).toDF(["month","doc_id", "text"])
+-----+------+--------------------+
|month|doc_id|                text|
+-----+------+--------------------+
|    1| doc_1|fruit is good for...|
|    2| doc_2|you should eat fr...|
|    2| doc_3|kids eat fruit bu...|
+-----+------+--------------------+

I want to count words by month. So far I've taken a CountVectorizer approach:

tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
tokenized = tokenizer.transform(df)

cvModel = CountVectorizer().setInputCol("words").setOutputCol("features").fit(tokenized)
counted = cvModel.transform(tokenized)
+-----+------+--------------------+--------------------+--------------------+
|month|doc_id|                text|               words|            features|
+-----+------+--------------------+--------------------+--------------------+
|    1| doc_1|fruit is good for...|[fruit, is, good,...|(12,[0,3,4,7,8],[...|
|    2| doc_2|you should eat fr...|[you, should, eat...|(12,[0,1,2,3,9,11...|
|    2| doc_3|kids eat fruit bu...|[kids, eat, fruit...|(12,[0,1,2,5,6,10...|
+-----+------+--------------------+--------------------+--------------------+
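For reference, the `features` column holds `SparseVector`s: the string form `(12,[0,3,4,7,8],[...])` lists the vector size, the non-zero indices, and their values. A minimal plain-Python expansion of that representation (the indices and values below are chosen for illustration, not taken from the fitted model):

```python
# A sparse vector (size, indices, values) stores only the non-zero entries.
# For example (12, [0, 3, 4], [1.0, 1.0, 1.0]) is a length-12 vector
# with ones at positions 0, 3 and 4. Expanded to a dense list:
size, indices, values = 12, [0, 3, 4], [1.0, 1.0, 1.0]
dense = [0.0] * size
for i, v in zip(indices, values):
    dense[i] = v
```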

Now I want to group by month and return something that looks like:

month  word   count
1      fruit  1
1      is     1
...
2      fruit  2
2      kids   1
2      eat    2
... 

How could I do that?

asked Oct 30 '22 by ADJ

1 Answer

There is no built-in mechanism for Vector* aggregation, but you don't need one here. Once you have tokenized data you can just explode and aggregate:

from pyspark.sql.functions import explode

(counted
    .select("month", explode("words").alias("word"))
    .groupBy("month", "word")
    .count())
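As a sanity check of what explode-then-count produces, the same per-month word counts can be reproduced in plain Python with `collections.Counter` (data hard-coded from the example above; `lower().split()` mimics `Tokenizer`'s default behavior):

```python
from collections import Counter

rows = [
    ("1", "fruit is good for you"),
    ("2", "you should eat fruit and veggies"),
    ("2", "kids eat fruit but not veggies"),
]

# Explode each document into (month, word) pairs, then count duplicates --
# the same shape that groupBy("month", "word").count() returns.
counts = Counter(
    (month, word)
    for month, text in rows
    for word in text.lower().split()
)
```

This matches the expected output in the question, e.g. `("2", "fruit")` maps to 2 and `("1", "is")` to 1.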

If you prefer to limit the results to the vocabulary, just add a filter:

from pyspark.sql.functions import col

(counted
    .select("month", explode("words").alias("word"))
    .where(col("word").isin(cvModel.vocabulary))
    .groupBy("month", "word")
    .count())

* Since Spark 2.4 we have access to Summarizer but it won't be useful here.

answered Nov 03 '22 by zero323