Lately I have been planning to migrate my standalone Python ML code to Spark. The ML pipeline in spark.ml
turns out to be quite handy, with a streamlined API for chaining algorithm stages and running hyper-parameter grid search.
Still, I found its support for one important feature poorly covered in the existing documentation: caching of intermediate results. This feature matters when the pipeline involves computation-intensive stages.
For example, in my case I use a huge sparse matrix to compute multiple moving averages over time-series data in order to form input features. The structure of the matrix is determined by a hyper-parameter, and this step turns out to be a bottleneck for the entire pipeline because I have to construct the matrix at runtime.
During parameter search, I usually have other parameters to examine besides this "structure parameter". So if I could reuse the huge matrix whenever the "structure parameter" is unchanged, I would save tons of time. For this reason, I deliberately structured my code to cache and reuse these intermediate results.
So my question is: can Spark's ML pipeline handle this kind of intermediate caching automatically, or do I have to write code for it myself? If the latter, is there any best practice to learn from?
P.S. I have looked into the official documentation and some other material, but none of it seems to discuss this topic.
Caching stores the computed partitions of a DataFrame, Dataset, or RDD so that later actions can reuse them instead of recomputing them from scratch. If a cached partition is lost further down the execution (for example, when an executor fails), Spark recomputes just that partition from the lineage rather than rerunning the whole job.
cache() is a lazy Apache Spark operation that you call on a DataFrame, Dataset, or RDD when you want to perform more than one action on it. It marks the data to be kept in the memory of your cluster's executors so that subsequent actions can reuse it.
The Spark DataFrame or Dataset cache() method defaults to the storage level `MEMORY_AND_DISK`, because recomputing the in-memory columnar representation of the underlying table is expensive. Note that this differs from the default level of `RDD.cache()`, which is `MEMORY_ONLY`.
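As a rough illustration, here is a minimal sketch of cache()/persist() on a DataFrame (the input path is hypothetical):

import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

val spark = SparkSession.builder().appName("cache-demo").getOrCreate()
val df = spark.read.parquet("/path/to/features")  // hypothetical input

// cache() is lazy: nothing is stored until the first action runs.
df.cache()               // equivalent to df.persist(StorageLevel.MEMORY_AND_DISK)
df.count()               // first action materializes the cached blocks
df.describe().show()     // later actions reuse them instead of re-reading the files

// Free the executors' storage memory once the data is no longer needed.
df.unpersist()

// persist() takes an explicit storage level if you want the RDD-style default.
df.persist(StorageLevel.MEMORY_ONLY)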
The CACHE TABLE statement caches the contents of a table or the output of a query with the given storage level, which reduces scanning of the original files in future queries.
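For example, a minimal sketch through the SQL interface, assuming a temporary view named "events" (the name is made up):

df.createOrReplaceTempView("events")
spark.sql("CACHE TABLE events")                    // cache the view's contents
spark.sql("SELECT COUNT(*) FROM events").show()    // served from the cache
spark.sql("UNCACHE TABLE events")                  // drop the cached data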
So I ran into the same problem, and the way I solved it was to implement my own PipelineStage that caches the input Dataset and returns it as is.
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// A pass-through stage whose only job is to cache whatever dataset
// reaches this point of the pipeline.
class Cacher(val uid: String) extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("CacherTransformer"))

  // Cache the incoming dataset and pass it on unchanged.
  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF.cache()

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  // The schema is left untouched.
  override def transformSchema(schema: StructType): StructType = schema
}
To use it, you would then do something like this:
new Pipeline().setStages(Array(stage1, new Cacher(), stage2))
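For a slightly more concrete picture, here is a sketch with hypothetical stages (HashingTF standing in for an expensive feature step, LogisticRegression as the learner), placing the Cacher right after the expensive step:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.HashingTF

// Hypothetical stages: an expensive featurization step and an iterative learner.
val featurizer = new HashingTF().setInputCol("tokens").setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(50)

// The Cacher sits right after the expensive step, so downstream work that
// scans the data repeatedly can reuse the cached features instead of
// triggering the featurizer again.
val pipeline = new Pipeline().setStages(Array(featurizer, new Cacher(), lr))

// trainingData: a DataFrame with "tokens" and "label" columns (hypothetical).
val model = pipeline.fit(trainingData)

One thing to keep in mind: the Cacher never unpersists the DataFrame it caches, so for very large intermediate results you may want to call unpersist() yourself (or extend the stage to do so) once fitting is done.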