I am trying to refactor a ScalaTest FunSuite test to avoid the boilerplate code that initializes and destroys the Spark session.
The problem is that I need to import implicit functions. With the before/after approach only variables (var fields) can be used, but an import requires a value (a val field).
The idea is to have a new, clean Spark session for every test execution.
I tried something like this:
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalatest.{BeforeAndAfter, FunSuite}

class SimpleWithBeforeTest extends FunSuite with BeforeAndAfter {
  var spark: SparkSession = _
  var sc: SparkContext = _
  implicit var sqlContext: SQLContext = _

  before {
    spark = SparkSession.builder
      .master("local")
      .appName("Spark session for testing")
      .getOrCreate()
    sc = spark.sparkContext
    sqlContext = spark.sqlContext
  }

  after {
    spark.sparkContext.stop()
  }

  test("Import implicits inside the test 1") {
    import sqlContext.implicits._ // error: Cannot resolve symbol sqlContext
    // Here other stuff
  }

  test("Import implicits inside the test 2") {
    import sqlContext.implicits._
    // Here other stuff
  }
}
But on the line import sqlContext.implicits._ I get this error:
Cannot resolve symbol sqlContext
How can I resolve this problem, or how should I implement the test class?
You can also use spark-testing-base, which handles pretty much all of this boilerplate for you.
Here is a blog post by its creator explaining how to use it.
And here is a simple example from their wiki:
import com.holdenkarau.spark.testing.DatasetSuiteBase
import org.scalatest.FunSuite

class test extends FunSuite with DatasetSuiteBase {
  test("simple test") {
    // copy sqlContext into a local val so it can be used in an import
    val sqlCtx = sqlContext
    import sqlCtx.implicits._

    val input1 = sc.parallelize(List(1, 2, 3)).toDS
    assertDatasetEquals(input1, input1) // equal

    val input2 = sc.parallelize(List(4, 5, 6)).toDS
    intercept[org.scalatest.exceptions.TestFailedException] {
      assertDatasetEquals(input1, input2) // not equal
    }
  }
}
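If adding spark-testing-base is not an option, the same goal (a fresh session per test, without vars) can be sketched with ScalaTest's loan-fixture pattern, swapped in here as an alternative to before/after. This is a sketch, not the library's approach: the helper name withSpark and the suite name are hypothetical, and it assumes spark-sql and ScalaTest 3.0.x on the test classpath (newer ScalaTest releases move FunSuite to org.scalatest.funsuite.AnyFunSuite). Because the session reaches the test body as a lambda parameter, it is a stable identifier and can be imported from directly.

```scala
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

class SparkLoanFixtureTest extends FunSuite {

  // Hypothetical loan-fixture helper: builds a fresh local SparkSession,
  // lends it to the test body, and always stops it afterwards.
  def withSpark(testCode: SparkSession => Any): Unit = {
    val spark = SparkSession.builder
      .master("local[*]")
      .appName("Spark session for testing")
      .getOrCreate()
    try {
      testCode(spark)
    } finally {
      spark.stop()
    }
  }

  test("Import implicits from a loaned session") {
    withSpark { spark =>
      // `spark` is a lambda parameter (a stable identifier), so this import compiles
      import spark.implicits._
      val ds = Seq(1, 2, 3).toDS()
      assert(ds.count() == 3)
    }
  }
}
```

Starting and stopping a local session for every test is slow; if strict isolation between tests is not needed, sharing one session across the suite is the usual trade-off.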
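Inside the test, define a new immutable value (a val), assign the SparkSession var to it, and then import the implicits from that val. Scala only allows imports from stable identifiers such as vals, not from vars.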
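import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FlatSpec}

class MyCassTest extends FlatSpec with BeforeAndAfter {
  var spark: SparkSession = _

  before {
    val sparkConf: SparkConf = new SparkConf()
    spark = SparkSession
      .builder()
      .config(sparkConf)
      .master("local[*]")
      .getOrCreate()
  }

  after {
    spark.stop()
  }

  "myFunction()" should "return 1.0 bla bla bla" in {
    val session = spark // stable val copy of the var, so it can be imported from
    import session.implicits._
    // assert ...
  }
}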
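Why copying the var into a val works can be shown without Spark at all. The sketch below is plain Scala 2 with no dependencies. Session is a hypothetical stand-in for SparkSession that, like the real class, exposes its conversions through a nested implicits object. Importing from the var directly is rejected (scalac reports that a stable identifier is required; the "Cannot resolve symbol" message in the question is most likely the IDE's rendering of the same problem), while importing from a local val copy compiles and uses whichever session the var held at that moment.

```scala
// Hypothetical stand-in for SparkSession: it exposes an `implicits` object
// whose members are brought into scope with an import.
class Session(val id: Int) {
  object implicits {
    implicit class IntOps(n: Int) {
      def tagged: String = s"session-$id:$n"
    }
  }
}

object StableImportDemo {
  // A mutable field, like the one assigned in a `before { ... }` block.
  var current: Session = new Session(1)

  def useCurrent(): String = {
    // import current.implicits._   // does not compile: stable identifier required
    val s = current                 // copy the var into a val ...
    import s.implicits._            // ... which is a stable identifier, so this import is legal
    42.tagged
  }

  def main(args: Array[String]): Unit = {
    println(useCurrent())   // session-1:42
    current = new Session(2) // simulate `before` creating a fresh session
    println(useCurrent())   // session-2:42
  }
}
```

Because the val is taken inside the test body, each test sees the session that before created for that test.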