
How to check if all records for a given key are in the same partition already?

Tags:

apache-spark

I'd like to avoid repartitioning data set by key as much as possible and know if all records for a given key are in the same partition already.

Is there a built-in function in Spark that would give me the answer?

asked Dec 29 '16 by Jacek Laskowski



2 Answers

It's not built in, but if you assume a specific partitioner it is easy enough to implement your own function:

import org.apache.spark.rdd.RDD
import org.apache.spark.Partitioner
import scala.reflect.ClassTag

def checkDistribution[K : ClassTag, V : ClassTag](
    rdd: RDD[(K, V)], partitioner: Partitioner): Boolean =
  // If a partitioner is set we compare partitioners directly
  rdd.partitioner.map(_ == partitioner).getOrElse {
    // Otherwise check for the correct number of partitions
    rdd.partitions.size == partitioner.numPartitions &&
    // and check that the distribution matches the partitioner
    rdd.keys.mapPartitionsWithIndex((i, iter) =>
      Iterator(iter.forall(x => partitioner.getPartition(x) == i))
    ).fold(true)(_ && _)
  }

A few tests:

import org.apache.spark.HashPartitioner

val rdd = sc.range(0, 20, 5).map((_, None))
  • Not partitioned, invalid distribution:

    checkDistribution(rdd, new HashPartitioner(10))
    
    Boolean = false
    
  • Partitioned, invalid partitioner:

    checkDistribution(
      rdd.partitionBy(new HashPartitioner(5)),
      new HashPartitioner(10)
    )
    
    Boolean = false
    
  • Partitioned, valid partitioner:

    checkDistribution(
      rdd.partitionBy(new HashPartitioner(10)),
      new HashPartitioner(10)
    )
    
    Boolean = true
    
  • Not partitioned, valid distribution:

    checkDistribution(
      rdd.partitionBy(new HashPartitioner(10)).map(identity),
      new HashPartitioner(10)
    )
    
    Boolean = true
    

Without assuming a particular partitioner, the only option that comes to mind requires a shuffle, so it is unlikely to be an improvement.

def checkDistribution[K : ClassTag, V : ClassTag](rdd: RDD[(K, V)]): Boolean =
  // Tag every key with the index of the partition it lives in
  rdd.keys.mapPartitionsWithIndex((i, iter) => iter.map((_, i)))
    .combineByKey(
      (x: Int) => Seq(x),
      // Within one partition all indices are identical, so duplicates can be dropped
      (x: Seq[Int], y: Int) => x,
      // Across partitions, collect one index per partition holding the key
      (x: Seq[Int], y: Seq[Int]) => x ++ y)  // Should be more or less OK
    .values
    // A key is co-located iff it was seen in exactly one partition
    .mapPartitions(iter => Iterator(iter.forall(_.size == 1)))
    .fold(true)(_ && _)

One possible improvement is to use the same logic to automatically define a Partitioner for the data: if you collectAsMap before taking values and all the Seqs are of size 1, the mapping you collected defines a valid partitioner which guarantees no network traffic.
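A minimal sketch of that idea, assuming the set of distinct keys fits in driver memory (collectAsMap materializes the whole mapping); MapPartitioner and partitionerFor are illustrative names, not Spark API:

import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag

// Hypothetical Partitioner backed by an explicit key -> partition mapping
class MapPartitioner[K](mapping: Map[K, Int], override val numPartitions: Int)
  extends Partitioner {
  def getPartition(key: Any): Int =
    mapping.getOrElse(key.asInstanceOf[K], 0)  // unseen keys fall back to partition 0
}

def partitionerFor[K : ClassTag, V : ClassTag](rdd: RDD[(K, V)]): Option[Partitioner] = {
  val locations = rdd.keys
    .mapPartitionsWithIndex((i, iter) => iter.map((_, i)))
    .combineByKey(
      (x: Int) => Seq(x),
      (x: Seq[Int], y: Int) => x,
      (x: Seq[Int], y: Seq[Int]) => x ++ y)
    .collectAsMap()

  // Valid only if every key lives in exactly one partition
  if (locations.values.forall(_.size == 1))
    Some(new MapPartitioner(locations.map { case (k, v) => (k, v.head) }.toMap,
      rdd.partitions.size))
  else
    None
}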

answered Jan 03 '23 by zero323


This is not 100% what you asked for, but you can check it using spark_partition_id. Basically, add a column holding the partition id:

withColumn("pid", spark_partition_id())

and then do:

df.groupBy(/* the key columns you want to check */)
  .agg(max($"pid").as("pidmax"), min($"pid").as("pidmin"))
  .filter($"pidmax" =!= $"pidmin")
  .count()

The count tells you how many keys have records spread across more than one partition; zero means every key is already confined to a single partition. Note that this is relatively cheap, being a simple aggregation.
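Put together, a minimal sketch, assuming a DataFrame df with a key column named "key" (both names are placeholders for your own data):

import org.apache.spark.sql.functions.{spark_partition_id, max, min}
import spark.implicits._  // assumes a SparkSession named spark

// Number of keys whose records span more than one partition
val misplaced = df
  .withColumn("pid", spark_partition_id())
  .groupBy($"key")
  .agg(max($"pid").as("pidmax"), min($"pid").as("pidmin"))
  .filter($"pidmax" =!= $"pidmin")
  .count()

// true iff every key's records already sit in a single partition
val colocated = misplaced == 0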

I don't believe there is a generic way, because if we read from a generic source (e.g. a file), we don't necessarily know how the source was originally partitioned.

It would be nice if there were something like a "get current partitioner" call that surfaced explicit partitioners (e.g. after an explicit repartition command, or when reading Parquet that was written using partitionBy) as an approximation, though.
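For the RDD API, the closest existing thing is the partitioner field, which is only populated when Spark knows the partitioner explicitly; even a layout-preserving map drops it, which is why it is only an approximation:

import org.apache.spark.HashPartitioner

val byKey = sc.range(0, 20).map((_, None)).partitionBy(new HashPartitioner(10))
byKey.partitioner                 // Some(HashPartitioner) -- known explicitly
byKey.map(identity).partitioner   // None, even though the layout is unchanged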

answered Jan 03 '23 by Assaf Mendelson