
Spark reading from Postgres JDBC table slow

I am trying to load about 1M rows from a PostgreSQL database into Spark. With Spark the load takes about 10 s, while running the same query through the psycopg2 driver takes about 2 s. I am using PostgreSQL JDBC driver version 42.0.0.

def _loadFromPostGres(name):
    url_connect = "jdbc:postgresql:"+dbname
    properties = {"user": "postgres", "password": "postgres"}
    df = SparkSession.builder.getOrCreate().read.jdbc(url=url_connect, table=name, properties=properties)
    return df

df = _loadFromPostGres("""
    (SELECT "seriesId", "companyId", "userId", "score"
    FROM user_series_game
    WHERE "companyId"=655124304077004298) as user_series_game""")

print measure(lambda : len(df.collect()))

The output is -

--- 10.7214591503 seconds ---
1076131

Using psycopg2 -

import psycopg2
conn = psycopg2.connect(conn_string)
cur = conn.cursor()

def _exec():
    cur.execute("""(SELECT "seriesId", "companyId", "userId", "score" 
        FROM user_series_game 
        WHERE "companyId"=655124304077004298)""")
    return cur.fetchall()
print measure(lambda : len(_exec()))
cur.close()
conn.close()

The output is -

--- 2.27961301804 seconds ---
1076131

The measure function -

import time

def measure(func):
    start_time = time.time()
    x = func()
    print("--- %s seconds ---" % (time.time() - start_time))
    return x

Kindly help me find the cause of this problem.


Edit 1

I did a few more benchmarks. Using Scala and JDBC -

import java.sql._;
import scala.collection.mutable.ArrayBuffer;

def exec() {

val url = ("jdbc:postgresql://prod.caumccqvmegm.ap-southeast-1.rds.amazonaws.com/prod"+ 
    "?tcpKeepAlive=true&prepareThreshold=-1&binaryTransfer=true&defaultRowFetchSize=10000")

val conn = DriverManager.getConnection(url,"postgres","postgres");

val sqlText = """SELECT "seriesId", "companyId", "userId", "score" 
        FROM user_series_game 
        WHERE "companyId"=655124304077004298"""

val t0 = System.nanoTime()

val stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

val rs = stmt.executeQuery()

val list = new ArrayBuffer[(Long, Long, Long, Double)]()

while (rs.next()) {
    val seriesId = rs.getLong("seriesId")
    val companyId = rs.getLong("companyId")
    val userId = rs.getLong("userId")
    val score = rs.getDouble("score")
    list.append((seriesId, companyId, userId, score))
}

val t1 = System.nanoTime()

println("Elapsed time: " + (t1 - t0) * 1e-9 + "s")

println(list.size)

rs.close()
stmt.close()
conn.close()
}

exec()

The output was -

Elapsed time: 1.922102285s
1143402

When I did collect() in Spark + Scala -

import org.apache.spark.sql.SparkSession

def exec2() {

    val spark = SparkSession.builder().getOrCreate()

    val url = ("jdbc:postgresql://prod.caumccqvmegm.ap-southeast-1.rds.amazonaws.com/prod"+ 
    "?tcpKeepAlive=true&prepareThreshold=-1&binaryTransfer=true&defaultRowFetchSize=10000")

    val sqlText = """(SELECT "seriesId", "companyId", "userId", "score" 
        FROM user_series_game 
        WHERE "companyId"=655124304077004298) as user_series_game"""

    val t0 = System.nanoTime()

    val df = spark.read
          .format("jdbc")
          .option("url", url)
          .option("dbtable", sqlText)
          .option("user", "postgres")
          .option("password", "postgres")
          .load()

    val list = df.collect()

    val t1 = System.nanoTime()

    println("Elapsed time: " + (t1 - t0) * 1e-9 + "s")

    print (list.size)
}

exec2()

The output was -

Elapsed time: 1.486141076s
1143445

So about 4x extra time is spent in Python serialisation. I understand there will be some penalty, but this seems too much.
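A hedged way to check that hypothesis (my own addition, not something measured above): time an action that stays entirely on the JVM, such as count(), against collect(), which has to turn every row into Python objects.

# Assumes the same `df` and `measure` defined earlier.
# count() runs on the executors and returns a single number to Python,
# so its timing should be close to the Scala collect() figure.
print measure(lambda : df.count())

# collect() serialises every row from the JVM into Python Row objects;
# if the serialisation hypothesis is right, that is where the extra
# seconds go.
print measure(lambda : len(df.collect()))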

asked Apr 21 '17 by Abhijit Bhole

1 Answer

The reason is actually quite simple, and there are two things going on at the same time.

First, a quick perspective on how psycopg2 works.

psycopg2 works like any other library for connecting to an RDBMS: it sends the query to your Postgres engine, and the engine returns the data to you. Straightforward, like this:

Conn -> Query -> ReturnData -> FetchData
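Mapped onto the psycopg2 snippet from the question (same code, only annotated with those steps; conn_string is assumed to be defined as in the question):

import psycopg2

conn = psycopg2.connect(conn_string)                  # Conn
cur = conn.cursor()
cur.execute('SELECT "seriesId", "companyId", "userId", "score" '
            'FROM user_series_game '
            'WHERE "companyId"=655124304077004298')   # Query -> ReturnData
rows = cur.fetchall()                                 # FetchData, one single process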

When you use Spark, things are a little different, in two ways. Spark is not a program running in a single thread; it works as a distributed system, even when you run it on a local machine. Spark has the basic concepts of a Driver (master) and Workers.

The Driver receives the request to execute the query against Postgres, but the Driver does not fetch the data itself; instead, each Worker requests its share of the data from your Postgres.

If you look at the documentation you will see a note like this:

Don’t create too many partitions in parallel on a large cluster; otherwise Spark might crash your external database systems.

This note means that each Worker is responsible for requesting the data from your Postgres. Starting that process is a small overhead, nothing really big, but there is still an overhead in distributing the read across the Workers.
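As a sketch of what that looks like in the question's PySpark code (the partition column, bounds, partition count and fetch size below are assumptions for illustration, not values taken from the question), the read can be spread explicitly across the Workers:

from pyspark.sql import SparkSession

# Assumes url_connect and name as defined in _loadFromPostGres above.
# Without column/lowerBound/upperBound/numPartitions, the JDBC source
# reads everything through a single partition, i.e. a single Worker.
df = SparkSession.builder.getOrCreate().read.jdbc(
    url=url_connect,
    table=name,
    column="userId",            # assumed numeric column to split on
    lowerBound=0,               # assumed id range, for illustration only
    upperBound=10000000,
    numPartitions=8,            # 8 Workers each read one id range
    properties={"user": "postgres",
                "password": "postgres",
                "fetchsize": "10000"})   # rows per round trip to Postgres

Each partition becomes one query with its own WHERE clause on the split column, executed by a Worker in parallel with the others.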

Second point: the collect in this part of your code:

print measure(lambda : len(df.collect()))

The collect function sends a command to all of your Workers to ship their data to the Driver, to be stored in the Driver's memory. It is like a reduce: it creates a shuffle in the middle of the process, the step where data is sent to other workers. In the case of collect, each Worker sends its data to your Driver.
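That also suggests a hedged mitigation (my suggestion, not part of the original point): if only a summary of the ~1M rows is needed, compute it on the Workers and collect just the small result, so almost nothing has to travel to the Driver or cross into Python.

# Sketch: aggregate on the executors, then collect only the tiny result.
# One row per seriesId crosses the network and the Python boundary,
# instead of ~1M serialised Row objects.
agg = df.groupBy("seriesId").count()
print measure(lambda : len(agg.collect()))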

So the steps of Spark's JDBC read in your code are:

(Workers)Conn -> (Workers)Query -> (Workers)FetchData -> (Driver) Request the Data -> (Workers) Shuffle -> (Driver) Collect

There is also a bunch of other work Spark does on top of this, such as building the query plan and the DataFrame itself.

That is why you get a faster response from your plain Python code than from Spark.

answered by Thiago Baldim