If I cache a Spark Dataframe and then overwrite the reference, will the original data frame still be cached?

Suppose I have a function that generates a (py)spark data frame and caches it in memory as its last operation.

def gen_func(inputs):
    df = ...  # do stuff to build the data frame
    df.cache()
    df.count()
    return df

Per my understanding, Spark's caching works as follows:

  1. When cache/persist is called on a data frame and followed by an action (e.g. count()), the data frame is computed from its DAG and cached in memory, attached to the object that refers to it.
  2. As long as a reference to that object exists, possibly within other functions/scopes, the df will continue to be cached, and all DAGs that depend on the df will use the in-memory cached data as a starting point.
  3. If all references to the df are deleted, Spark marks the cached data as eligible for garbage collection. It may not be garbage collected immediately, causing some short-term memory pressure (and in particular, memory leaks if you generate cached data frames and throw them away too quickly), but eventually it will be cleared up.
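
For concreteness, the lifecycle described in points 1-3 looks roughly like this (a minimal sketch; the variable names are only illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.range(10)   # build some data frame
df.cache()             # mark it for caching (lazy; nothing is materialized yet)
df.count()             # first action materializes the cache
# ... downstream computations on df read from the in-memory copy ...
df.unpersist()         # explicitly release the cached blocks instead of waiting for GC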

My question is: suppose I use gen_func to generate a data frame, but then overwrite the original data frame reference (perhaps with a filter or a withColumn):

df = gen_func(inputs)
df = df.filter("some_col = some_val")

In Spark, RDDs/DataFrames are immutable, so the reassigned df after the filter and the df before the filter refer to two entirely different objects. In this case, the reference to the original df that was cached and counted has been overwritten. Does that mean that the cached data frame is no longer available and will be garbage collected? Does that mean that the new post-filter df will compute everything from scratch, despite being generated from a previously cached data frame?

I am asking this because I was recently fixing some out-of-memory issues in my code, and it seems to me that caching might be the problem. However, I do not yet fully understand the safe ways to use cache, or how one might accidentally invalidate cached data. What is missing in my understanding? Am I deviating from best practice in doing the above?

asked Feb 17 '20 by Andrew Ma

1 Answer

I've done a couple of experiments, shown below. Apparently the dataframe, once cached, remains cached (as shown by getPersistentRDDs and by the query plan - InMemoryTableScan etc.), even if all Python references were overwritten or deleted altogether using del, and even with garbage collection explicitly invoked.

Experiment 1:

def func():
    data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
    data.cache()
    data.count()
    return data

sc._jsc.getPersistentRDDs()

df = func()
sc._jsc.getPersistentRDDs()

df2 = df.filter('col1 != 2')
del df
import gc
gc.collect()
sc._jvm.System.gc()
sc._jsc.getPersistentRDDs()

df2.select('*').explain()

del df2
gc.collect()
sc._jvm.System.gc()
sc._jsc.getPersistentRDDs()

Results:

>>> def func():
...     data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
...     data.cache()
...     data.count()
...     return data
...
>>> sc._jsc.getPersistentRDDs()
{}

>>> df = func()
>>> sc._jsc.getPersistentRDDs()
{71: JavaObject id=o234}

>>> df2 = df.filter('col1 != 2')
>>> del df
>>> import gc
>>> gc.collect()
93
>>> sc._jvm.System.gc()
>>> sc._jsc.getPersistentRDDs()
{71: JavaObject id=o240}

>>> df2.select('*').explain()
== Physical Plan ==
*(1) Filter (isnotnull(col1#174L) AND NOT (col1#174L = 2))
+- *(1) ColumnarToRow
   +- InMemoryTableScan [col1#174L], [isnotnull(col1#174L), NOT (col1#174L = 2)]
         +- InMemoryRelation [col1#174L], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Project [_1#172L AS col1#174L]
                  +- *(1) Scan ExistingRDD[_1#172L]

>>> del df2
>>> gc.collect()
85
>>> sc._jvm.System.gc()
>>> sc._jsc.getPersistentRDDs()
{71: JavaObject id=o250}

Experiment 2:

def func():
    data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
    data.cache()
    data.count()
    return data

sc._jsc.getPersistentRDDs()

df = func()
sc._jsc.getPersistentRDDs()

df = df.filter('col1 != 2')
import gc
gc.collect()
sc._jvm.System.gc()
sc._jsc.getPersistentRDDs()

df.select('*').explain()

del df
gc.collect()
sc._jvm.System.gc()
sc._jsc.getPersistentRDDs()

Results:

>>> def func():
...     data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
...     data.cache()
...     data.count()
...     return data
...
>>> sc._jsc.getPersistentRDDs()
{}

>>> df = func()
>>> sc._jsc.getPersistentRDDs()
{86: JavaObject id=o317}

>>> df = df.filter('col1 != 2')
>>> import gc
>>> gc.collect()
244
>>> sc._jvm.System.gc()
>>> sc._jsc.getPersistentRDDs()
{86: JavaObject id=o323}

>>> df.select('*').explain()
== Physical Plan ==
*(1) Filter (isnotnull(col1#220L) AND NOT (col1#220L = 2))
+- *(1) ColumnarToRow
   +- InMemoryTableScan [col1#220L], [isnotnull(col1#220L), NOT (col1#220L = 2)]
         +- InMemoryRelation [col1#220L], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Project [_1#218L AS col1#220L]
                  +- *(1) Scan ExistingRDD[_1#218L]

>>> del df
>>> gc.collect()
85
>>> sc._jvm.System.gc()
>>> sc._jsc.getPersistentRDDs()
{86: JavaObject id=o333}

Experiment 3 (control experiment, to show that unpersist works)

def func():
    data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
    data.cache()
    data.count()
    return data

sc._jsc.getPersistentRDDs()

df = func()
sc._jsc.getPersistentRDDs()

df2 = df.filter('col1 != 2')
df2.select('*').explain()

df.unpersist()
df2.select('*').explain()

Results:

>>> def func():
...     data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
...     data.cache()
...     data.count()
...     return data
...
>>> sc._jsc.getPersistentRDDs()
{}

>>> df = func()
>>> sc._jsc.getPersistentRDDs()
{116: JavaObject id=o398}

>>> df2 = df.filter('col1 != 2')
>>> df2.select('*').explain()
== Physical Plan ==
*(1) Filter (isnotnull(col1#312L) AND NOT (col1#312L = 2))
+- *(1) ColumnarToRow
   +- InMemoryTableScan [col1#312L], [isnotnull(col1#312L), NOT (col1#312L = 2)]
         +- InMemoryRelation [col1#312L], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Project [_1#310L AS col1#312L]
                  +- *(1) Scan ExistingRDD[_1#310L]

>>> df.unpersist()
DataFrame[col1: bigint]
>>> sc._jsc.getPersistentRDDs()
{}

>>> df2.select('*').explain()
== Physical Plan ==
*(1) Project [_1#310L AS col1#312L]
+- *(1) Filter (isnotnull(_1#310L) AND NOT (_1#310L = 2))
   +- *(1) Scan ExistingRDD[_1#310L]

To answer the OP's question:

Does that mean that the cached data frame is no longer available and will be garbage collected? Does that mean that the new post-filter df will compute everything from scratch, despite being generated from a previously cached data frame?

The experiments suggest the answer to both is no. The dataframe remains cached, it is not garbage collected, and, according to the query plan, the new dataframe is computed using the cached (but no longer directly referenceable) dataframe.

Some helpful functions for managing the cache (if you don't want to do it through the Spark UI) are:

sc._jsc.getPersistentRDDs(), which shows the currently cached RDDs/dataframes, and

spark.catalog.clearCache(), which clears all cached RDDs/dataframes.
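
For example, a minimal sketch of both calls together (assuming the usual spark session and sc context from a PySpark shell):

sc._jsc.getPersistentRDDs()   # non-empty map while something is cached
spark.catalog.clearCache()    # evicts every cached RDD/dataframe
sc._jsc.getPersistentRDDs()   # now returns an empty map: {}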

Am I deviating from best practice in doing the above?

I am in no position to judge you on this, but as one of the comments suggested, avoid reassigning df, because dataframes are immutable. Imagine you were coding in Scala and had defined df as a val: df = df.filter(...) would then be impossible. Python can't enforce that per se, but I think the best practice is to avoid overwriting any dataframe variable, so that you can always call df.unpersist() afterwards once you no longer need the cached results.
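
As a rough sketch of that pattern, reusing gen_func from the question (variable names are illustrative):

base_df = gen_func(inputs)                            # cached inside gen_func
filtered_df = base_df.filter("some_col = some_val")   # derived df, still reads from the cache
# ... work with filtered_df ...
base_df.unpersist()                                   # release the cached data explicitly when done

This way the cached dataframe always stays reachable for an explicit unpersist().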

answered Sep 21 '22 by mck