 

Adding a group count column to a PySpark dataframe

I am coming from R and the tidyverse to PySpark due to its superior Spark handling, and I am struggling to map certain concepts from one context to the other.

In particular, suppose that I had a dataset like the following

x | y
--+--
a | 5
a | 8
a | 7
b | 1

and I wanted to add a column containing the number of rows for each x value, like so:

x | y | n
--+---+---
a | 5 | 3
a | 8 | 3
a | 7 | 3
b | 1 | 1

In dplyr, I would just say:

library(tidyverse)

df <- read_csv("...")
df %>%
    group_by(x) %>%
    mutate(n = n()) %>%
    ungroup()

and that would be that. I can do something almost as simple in PySpark if I'm looking to summarize by number of rows:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

spark.read.csv("...") \
    .groupBy(col("x")) \
    .count() \
    .show()

And I thought I understood that withColumn was equivalent to dplyr's mutate. However, when I do the following, PySpark tells me that withColumn is not defined for groupBy data:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()

spark.read.csv("...") \
    .groupBy(col("x")) \
    .withColumn("n", count("x")) \
    .show()
# AttributeError: 'GroupedData' object has no attribute 'withColumn'

In the short run, I can simply create a second dataframe containing the counts and join it to the original dataframe. However, it seems like this could become inefficient in the case of large tables. What is the canonical way to accomplish this?
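For reference, the join-based workaround described above would look roughly like this (a minimal sketch, reusing the elided CSV path and the x column from the snippets above):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.read.csv("...")

# Aggregate once, then join the per-group counts back onto the original rows
counts = df.groupBy("x").count().withColumnRenamed("count", "n")
df.join(counts, on="x", how="left").show()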

asked Feb 14 '18 by David Bruce Borenstein



2 Answers

When you do a groupBy(), you have to specify the aggregation before you can display the results. For example:

import pyspark.sql.functions as f

data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
]
df = sqlCtx.createDataFrame(data, ["x", "y"])
df.groupBy('x').count().select('x', f.col('count').alias('n')).show()
#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+

Here I used alias() to rename the column. But this only returns one row per group. If you want all rows with the count appended, you can do this with a Window:

from pyspark.sql import Window

w = Window.partitionBy('x')
df.select('x', 'y', f.count('x').over(w).alias('n')).sort('x', 'y').show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+
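Since the question asked about withColumn as the analogue of dplyr's mutate: the same window expression also works inside withColumn, which appends the count while keeping every existing column without listing them (a sketch reusing the w defined above):

df.withColumn('n', f.count('x').over(w)).sort('x', 'y').show()
# Same output as above, with no need to enumerate the other columns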

Or if you're more comfortable with SQL, you can register the dataframe as a temporary table and take advantage of pyspark-sql to do the same thing:

df.registerTempTable('table')
sqlCtx.sql(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM table ORDER BY x, y'
).show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+
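A side note: registerTempTable and sqlCtx belong to the Spark 1.x API. On Spark 2.x and later the equivalent would be, as a sketch (assuming a SparkSession named spark):

df.createOrReplaceTempView('table')
spark.sql(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM table ORDER BY x, y'
).show()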
answered Sep 30 '22 by pault

As an appendix to @pault's answer:

import pyspark.sql.functions as F

...

(df
 .groupBy(F.col('x'))
 .agg(F.count('x').alias('n'))
 .show())
#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+

enjoy

answered Sep 30 '22 by wiesiu_p