
Create multiple Spark DataFrames from RDD based on some key value (pyspark)

I have some text files containing JSON objects (one object per line). Example:

{"a": 1, "b": 2, "table": "foo"}
{"c": 3, "d": 4, "table": "bar"}
{"a": 5, "b": 6, "table": "foo"}
...

I want to parse the contents of text files into Spark DataFrames based on the table name. So in the example above, I would have a DataFrame for "foo" and another DataFrame for "bar". I have made it as far as grouping the lines of JSON into lists inside of an RDD with the following (pyspark) code:

text_rdd = sc.textFile(os.path.join("/path/to/data", "*"))
tables_rdd = text_rdd.groupBy(lambda x: json.loads(x)['table'])

This produces an RDD of key/value tuples with the following structure:

RDD[("foo", ['{"a": 1, "b": 2, "table": "foo"}', ...],
    ("bar", ['{"c": 3, "d": 4, "table": "bar"}', ...]]

How do I break this RDD into a DataFrame for each table key?

edit: I tried to clarify above that a single file contains multiple lines, each carrying information for a table. I know that I can call .collectAsMap on the "groupBy" RDD that I have created, but that would consume a sizeable amount of RAM on my driver. My question is: is there a way to break the "groupBy" RDD into multiple DataFrames without using .collectAsMap?
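
For reference, the driver-side approach I want to avoid would look roughly like this (a sketch only; it pulls every grouped line onto the driver before building one DataFrame per table):

grouped = tables_rdd.collectAsMap()  # {table_name: iterable of raw JSON lines}, all held on the driver
frames = {
    name: spark.read.json(sc.parallelize(list(lines)))  # re-distribute and parse each group
    for name, lines in grouped.items()
}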

asked by conrosebraugh on Sep 11 '25


1 Answer

You can split it efficiently into Parquet partitions. First, convert the RDD into a DataFrame:

text_rdd = sc.textFile(os.path.join("/path/to/data", "*"))
df = spark.read.json(text_rdd)
df.printSchema()
    root
     |-- a: long (nullable = true)
     |-- b: long (nullable = true)
     |-- c: long (nullable = true)
     |-- d: long (nullable = true)
     |-- table: string (nullable = true)

Now we can write it:

df.write.partitionBy('table').parquet([output directory name])

If you list the content of [output directory name], you'll see as many partitions as there are distinct values of table:

hadoop fs -ls [output directory name]

    _SUCCESS
    table=bar/
    table=foo/
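
Each partition directory can then be read back as its own DataFrame, which gives you one DataFrame per table as asked. A minimal sketch (the directory name is a placeholder, matching the write above):

foo_df = spark.read.parquet("[output directory name]/table=foo")  # rows where table == "foo"
bar_df = spark.read.parquet("[output directory name]/table=bar")  # rows where table == "bar"

Note that the table column itself is encoded in the directory name rather than stored in the Parquet files.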

If you want to keep only each table's own columns, you can build a mapping from table name to column names like this (assuming the full list of columns appears whenever the table appears in the file):

import ast
from pyspark.sql import Row

# Parse each line (ast.literal_eval works here because these JSON lines are also
# valid Python dict literals), keep only the table name and its sorted column names,
# deduplicate, and collect the small result to the driver as a pandas DataFrame.
table_cols = spark.createDataFrame(text_rdd.map(lambda l: ast.literal_eval(l)).map(lambda l: Row(
        table = l["table"], 
        keys = sorted(l.keys())
    ))).distinct().toPandas()
table_cols = table_cols.set_index("table")
table_cols.to_dict()["keys"]

    {u'bar': [u'c', u'd', u'table'], u'foo': [u'a', u'b', u'table']}
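
One possible way to use that mapping is to read each partition back and select only that table's own columns (a sketch; the directory name is again a placeholder and per_table_dfs is just an illustrative name):

cols_by_table = table_cols.to_dict()["keys"]
per_table_dfs = {}
for table, cols in cols_by_table.items():
    part_df = spark.read.parquet("[output directory name]/table={}".format(table))
    # 'table' lives in the partition path, not in the data files, so drop it from the selection
    per_table_dfs[table] = part_df.select([c for c in cols if c != "table"])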
answered by MaFF on Sep 13 '25