 

Choosing random items from a Spark GroupedData Object


I'm new to using Spark in Python and have been unable to solve this problem: After running groupBy on a pyspark.sql.dataframe.DataFrame

    df = sqlsc.read.json("data.json")
    df.groupBy('teamId')

how can you choose N random samples from each resulting group (grouped by teamId) without replacement?

I'm basically trying to choose N random users from each team; maybe using groupBy is the wrong place to start?

Asked Nov 17 '15 by Nyxynyx


1 Answer

Well, it is kind of wrong. GroupedData is not really designed for data access. It just describes grouping criteria and provides aggregation methods. See my answer to Using groupBy in Spark and getting back to a DataFrame for more details.
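To make that concrete, here is a minimal sketch (reusing df from the question; the count and agg calls are only illustrative) of what GroupedData does and does not expose:

    # A minimal sketch, reusing `df` from the question.
    grouped = df.groupBy("teamId")        # pyspark.sql.GroupedData

    grouped.count().show()                # aggregations work ...
    grouped.agg({"*": "count"}).show()    # ... and so does agg()

    # ... but GroupedData has no collect(), take() or sample(),
    # so there is no way to pull N rows out of each group from here.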

Another problem with this idea is selecting N random samples per group. That is really hard to achieve in parallel without physically grouping the data, and physical grouping is not something that happens when you call groupBy on a DataFrame.

There are at least two ways to handle this:

  • convert to an RDD, groupByKey and perform local sampling

    import random

    n = 3

    def sample(iter, n):
        rs = random.Random()  # We should probably use os.urandom as a seed
        return rs.sample(list(iter), n)

    df = sqlContext.createDataFrame(
        [(x, y, random.random()) for x in (1, 2, 3) for y in "abcdefghi"],
        ("teamId", "x1", "x2"))

    grouped = df.rdd.map(lambda row: (row.teamId, row)).groupByKey()

    sampled = sqlContext.createDataFrame(
        grouped.flatMap(lambda kv: sample(kv[1], n)))

    sampled.show()

    ## +------+---+-------------------+
    ## |teamId| x1|                 x2|
    ## +------+---+-------------------+
    ## |     1|  g|   0.81921738561455|
    ## |     1|  f| 0.8563875814036598|
    ## |     1|  a| 0.9010425238735935|
    ## |     2|  c| 0.3864428179837973|
    ## |     2|  g|0.06233470405822805|
    ## |     2|  d|0.37620872770129155|
    ## |     3|  f| 0.7518901502732027|
    ## |     3|  e| 0.5142305439671874|
    ## |     3|  d| 0.6250620479303716|
    ## +------+---+-------------------+
  • use window functions

    from pyspark.sql import Window
    from pyspark.sql.functions import col, rand, rowNumber

    w = Window.partitionBy(col("teamId")).orderBy(col("rnd_"))

    sampled = (df
        .withColumn("rnd_", rand())  # Add a column of random numbers
        .withColumn("rn_", rowNumber().over(w))  # Add rowNumber over the window
        .where(col("rn_") <= n)  # Take n observations
        .drop("rn_")  # drop helper columns
        .drop("rnd_"))

    sampled.show()

    ## +------+---+--------------------+
    ## |teamId| x1|                  x2|
    ## +------+---+--------------------+
    ## |     1|  f|  0.8563875814036598|
    ## |     1|  g|    0.81921738561455|
    ## |     1|  i|  0.8173912535268248|
    ## |     2|  h| 0.10862995810038856|
    ## |     2|  c|  0.3864428179837973|
    ## |     2|  a|  0.6695356657072442|
    ## |     3|  b|0.012329360826023095|
    ## |     3|  a|  0.6450777858109182|
    ## |     3|  e|  0.5142305439671874|
    ## +------+---+--------------------+

but I am afraid both will be rather expensive. If the size of the individual groups is balanced and relatively large, I would simply use DataFrame.randomSplit.
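For instance, a rough sketch of that approach; the 0.1/0.9 weights and the seed are arbitrary assumptions, and the number of rows kept per group is only approximate:

    # A rough sketch: with balanced, reasonably large groups, splitting the
    # whole DataFrame keeps roughly the same fraction from every team.
    # The 0.1/0.9 weights and the seed are arbitrary assumptions.
    sampled, rest = df.randomSplit([0.1, 0.9], seed=42)
    sampled.show()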

If the number of groups is relatively small, it is possible to try something else:

    from pyspark.sql.functions import count, udf
    from pyspark.sql.types import BooleanType
    from operator import truediv

    counts = (df
        .groupBy(col("teamId"))
        .agg(count("*").alias("n"))
        .rdd.map(lambda r: (r.teamId, r.n))
        .collectAsMap())

    # This defines fraction of observations from a group which should
    # be taken to get n values
    counts_bd = sc.broadcast({k: truediv(n, v) for (k, v) in counts.items()})

    to_take = udf(lambda k, rnd: rnd <= counts_bd.value.get(k), BooleanType())

    sampled = (df
        .withColumn("rnd_", rand())
        .where(to_take(col("teamId"), col("rnd_")))
        .drop("rnd_"))

    sampled.show()

    ## +------+---+--------------------+
    ## |teamId| x1|                  x2|
    ## +------+---+--------------------+
    ## |     1|  d| 0.14815204548854788|
    ## |     1|  f|  0.8563875814036598|
    ## |     1|  g|    0.81921738561455|
    ## |     2|  a|  0.6695356657072442|
    ## |     2|  d| 0.37620872770129155|
    ## |     2|  g| 0.06233470405822805|
    ## |     3|  b|0.012329360826023095|
    ## |     3|  h|  0.9022527556458557|
    ## +------+---+--------------------+

In Spark 1.5+ you can replace the udf with a call to the sampleBy method:

    df.sampleBy("teamId", counts_bd.value)

It won't give you an exact number of observations, but it should be good enough most of the time, as long as the number of observations per group is large enough to get proper samples. You can also use sampleByKey on an RDD in a similar way.
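For completeness, a sketch of that sampleByKey variant; it assumes the counts_bd fractions broadcast above and keys each row by teamId before sampling:

    # A sketch, assuming the per-group fractions broadcast in counts_bd above.
    sampled_rdd = (df.rdd
        .map(lambda row: (row.teamId, row))    # key rows by group
        .sampleByKey(False, counts_bd.value)   # approximate, without replacement
        .values())

    sampled_df = sqlContext.createDataFrame(sampled_rdd)
    sampled_df.show()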

Answered by zero323