
Spark reading data from MySQL in parallel

I'm trying to read data from MySQL and write it back to a Parquet file in S3 with specific partitions, as follows:

df = sqlContext.read.format('jdbc')\
    .options(driver='com.mysql.jdbc.Driver',
             url="jdbc:mysql://<host>:3306/<db>?user=<usr>&password=<pass>",
             dbtable='tbl',
             numPartitions=4)\
    .load()


from pyspark.sql.functions import to_date

df2 = df.withColumn('updated_date', to_date(df.updated_at))
df2.write.parquet(path='s3n://parquet_location', mode='append', partitionBy=['updated_date'])

My problem is that it opens only one connection to MySQL (instead of 4) and it doesn't write to Parquet until it has fetched all the data from MySQL. Because my table in MySQL is huge (100M rows), the process failed with an OutOfMemory error.

Is there a way to configure Spark to open more than one connection to MySQL and to write partial data to Parquet?

Lior Baber, asked Jan 28 '16




2 Answers

For Spark >= 2.0, I've created a class with the following methods:

...
private val dbUrl =
s"""jdbc:mysql://${host}:${port}/${db_name}
    |?zeroDateTimeBehavior=convertToNull
    |&read_buffer_size=100M""".stripMargin.replace("\n", "")

def run(sqlQuery: String): DataFrame = {
println(sqlQuery)
Datapipeline.spark.read
  .format("jdbc")
  .option("driver", "com.mysql.jdbc.Driver")
  .option("url", dbUrl)
  .option("user", user)
  .option("password", pass)
  .option("dbtable", s"($sqlQuery) as tmp")
  .load()
}
...
def getBounds(table: String, whereClause: String, partitionColumn: String): Array[Int] = {
val sql = s"select min($partitionColumn) as min, max($partitionColumn) as max from $table${
  if (whereClause.length > 0) s" where $whereClause"
}"
val df = run(sql).collect()(0)

Array(df.get(0).asInstanceOf[Int], df.get(1).asInstanceOf[Int])
}

def getTableFields(table: String): String = {
val sql =
  s"""
     |SELECT *
     |FROM information_schema.COLUMNS
     |WHERE table_name LIKE '$table'
     |  AND TABLE_SCHEMA LIKE '${db_name}'
     |ORDER BY ORDINAL_POSITION
   """.stripMargin
run(sql).collect().map(r => r.getAs[String]("COLUMN_NAME")).mkString(", ")
}

/**
* Returns a DataFrame partitioned by <partitionColumn> into the number of partitions provided in
* <numPartitions> for a <table> with an optional WHERE clause
* @param table - a table name
* @param whereClause - WHERE clause without the "WHERE" keyword
* @param partitionColumn - column name used for partitioning, should be numeric
* @param numPartitions - number of partitions
* @return - a DataFrame
*/
def run(table: String, whereClause: String, partitionColumn: String, numPartitions: Int): DataFrame = {
val bounds = getBounds(table, whereClause, partitionColumn)

val fields = getTableFields(table)
val dfs: Array[DataFrame] = new Array[DataFrame](numPartitions)

val lowerBound = bounds(0)
val partitionRange: Int = ((bounds(1) - bounds(0)) / numPartitions)

for (i <- 0 to numPartitions - 2) {
  dfs(i) = run(
    s"""select $fields from $table
        | where $partitionColumn >= ${lowerBound + (partitionRange * i)} and $partitionColumn < ${lowerBound + (partitionRange * (i + 1))}${
  if (whereClause.length > 0)
    s" and $whereClause"
  else ""
}
     """.stripMargin.replace("\n", ""))
}

dfs(numPartitions - 1) = run(s"select $fields from $table where $partitionColumn >= ${lowerBound + (partitionRange * (numPartitions - 1))}${
  if (whereClause.length > 0)
    s" and $whereClause"
  else ""
}".replace("\n", ""))

dfs.reduceLeft((res, df) => res.union(df))

}

The last run method creates the necessary number of partitions. When you call an action, Spark will create as many parallel tasks as there are partitions defined for the DataFrame returned by the run method.

Enjoy.
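For comparison with the question's PySpark code, a rough sketch of the same range-splitting idea might look like the following. The helper name read_in_ranges, the bounds subquery and the use of unionAll are illustrative assumptions, not part of the answer above:

from functools import reduce

def read_in_ranges(sqlContext, url, table, partition_column, num_partitions):
    # One small JDBC query to fetch min/max of the (numeric) partition column.
    bounds = (sqlContext.read.format('jdbc')
              .options(driver='com.mysql.jdbc.Driver', url=url,
                       dbtable='(select min({c}) as mn, max({c}) as mx from {t}) as b'
                               .format(c=partition_column, t=table))
              .load()
              .collect()[0])
    lower, upper = bounds[0], bounds[1]
    step = max((upper - lower) // num_partitions, 1)

    # Build one DataFrame per value range; each range is read by its own task.
    dfs = []
    for i in range(num_partitions):
        lo = lower + i * step
        upper_clause = '' if i == num_partitions - 1 \
            else ' and {c} < {hi}'.format(c=partition_column, hi=lo + step)
        query = '(select * from {t} where {c} >= {lo}{u}) as p{i}'.format(
            t=table, c=partition_column, lo=lo, u=upper_clause, i=i)
        dfs.append(sqlContext.read.format('jdbc')
                   .options(driver='com.mysql.jdbc.Driver', url=url, dbtable=query)
                   .load())
    return reduce(lambda a, b: a.unionAll(b), dfs)  # union() in Spark 2.x

Called as read_in_ranges(sqlContext, url, 'tbl', 'id', 4), the returned DataFrame has four partitions, so an action such as the Parquet write in the question runs four MySQL reads in parallel.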

Orka, answered Sep 20 '22


You should set these properties:

partitionColumn, 
lowerBound, 
upperBound, 
numPartitions

as documented here: http://spark.apache.org/docs/latest/sql-programming-guide.html#jdbc-to-other-databases
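Applied to the question's snippet, that could look roughly like the sketch below. The choice of id as the partition column and the bound values are assumptions; with all four options set, Spark splits the read into numPartitions range queries and opens one JDBC connection per task:

df = sqlContext.read.format('jdbc')\
    .options(driver='com.mysql.jdbc.Driver',
             url="jdbc:mysql://<host>:3306/<db>?user=<usr>&password=<pass>",
             dbtable='tbl',
             partitionColumn='id',     # assumed numeric column, e.g. an auto-increment primary key
             lowerBound='1',           # assumed smallest value of partitionColumn
             upperBound='100000000',   # assumed largest value of partitionColumn
             numPartitions=4)\
    .load()

Each of the four partitions is then written out by its own task, so the Parquet write does not have to wait for the whole table to be fetched by a single task.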

mgaido, answered Sep 18 '22