I have a pyspark df with the following schema:
root
|-- array_bytes: binary (nullable = true)
I would like to be able to convert this to an image array. I can accomplish this in Pandas with the following code:
# Convert the Spark DataFrame into a pandas DataFrame so the bytes column can
# be processed with a plain Python function.
df_pandas = df.toPandas()
def bytes_to_array(byte_data, shape=(224, 224, 3)):
    """Decode a raw byte buffer into a uint8 image array.

    Args:
        byte_data: Bytes-like object holding one byte per array element.
        shape: Shape of the returned array. Defaults to ``(224, 224, 3)``,
            the size this snippet was originally hard-coded for.

    Returns:
        A ``np.uint8`` array of the requested shape.

    Raises:
        ValueError: If the number of bytes does not match ``shape``.
    """
    # Interpret every byte as one unsigned 8-bit value, then fold the flat
    # vector into the requested dimensions.
    flat = np.frombuffer(byte_data, dtype=np.uint8)
    return flat.reshape(shape)
# Apply bytes_to_array to each value of the 'array_bytes' column and store the
# results in a new 'image_array' column.
df_pandas['image_array'] = df_pandas['array_bytes'].apply(bytes_to_array)
I can't seem to find a way to do this in PySpark. Here's what I've tried:
def convert_binary_to_array(binary_data: bytes, shape: tuple = (224, 224, 3)) -> np.ndarray:
    """Decode a raw byte buffer into a uint8 image array.

    Args:
        binary_data: Raw bytes, one byte per array element.
        shape: Shape of the returned array. Defaults to ``(224, 224, 3)``,
            the size this snippet was originally hard-coded for.

    Returns:
        A ``np.uint8`` array of the requested shape.

    Raises:
        ValueError: If the number of bytes does not match ``shape``.
    """
    # Read the buffer as unsigned bytes, then reshape the flat vector.
    flat = np.frombuffer(binary_data, dtype=np.uint8)
    return flat.reshape(shape)
def convert_binary_in_df(df, binary_column: str = 'binary_data'):
    """Add an "image_data" column built by a UDF over ``binary_column``.

    Args:
        df: DataFrame containing the binary column.
        binary_column: Name of the column passed to the UDF.

    Returns:
        ``df`` with an additional "image_data" column.
    """
    def convert_binary_udf(byte_data):
        # Spark cannot store an ndarray, so hand back nested Python lists.
        image = convert_binary_to_array(byte_data)
        return image.tolist()

    # NOTE: the declared return type has two levels of nesting, while the
    # (224, 224, 3) array produces a list nested three levels deep. This is
    # the version the question reports as yielding nulls in "image_data".
    declared_type = ArrayType(ArrayType(IntegerType()))

    # register and apply udf
    to_image_data = udf(convert_binary_udf, declared_type)
    return df.withColumn("image_data", to_image_data(binary_column))
# Run the conversion on the 'array_bytes' binary column.
df_converted = convert_binary_in_df(df, binary_column='array_bytes')
However, image_data just ends up being full of nulls. I'm not very strong in PySpark and haven't been able to figure out what's wrong. Thanks in advance for any help.
I put together a working example that takes an image, puts it through this conversion into a DataFrame, then reconstructs the image from the DataFrame to check that it looks the same:
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType
import numpy as np
import matplotlib.pyplot as plt
# an image with a central red square and blue background
image = np.zeros((224, 224, 3), dtype=np.uint8)
# Set every pixel to [0, 0, 255] (the blue background).
image[:, :] = [0, 0, 255]
# Overwrite rows and columns 72..151 (an 80x80 block) with [255, 0, 0] (red).
image[72:152, 72:152] = [255, 0, 0]
# Show the source image so it can be compared with the recreated one later.
plt.imshow(image)
plt.title("Original Image")
plt.axis('off')
plt.show()
def convert_binary_to_array(binary_data: bytes, shape: tuple = (224, 224, 3)) -> np.ndarray:
    """Turn a raw byte buffer into a uint8 array of the given shape.

    Args:
        binary_data: Raw bytes, one byte per array element.
        shape: Dimensions of the result. The default ``(224, 224, 3)``
            matches the image used in this example.

    Returns:
        A ``np.uint8`` array with the requested shape.

    Raises:
        ValueError: If the byte count is incompatible with ``shape``.
    """
    # Each byte becomes one unsigned 8-bit element before reshaping.
    values = np.frombuffer(binary_data, dtype=np.uint8)
    return values.reshape(shape)
def convert_binary_in_df(df, binary_column: str = 'binary_data'):
    """Append an "image_data" column holding the decoded image as nested lists.

    Args:
        df: DataFrame containing the binary column.
        binary_column: Name of the column passed to the UDF.

    Returns:
        ``df`` with an additional "image_data" column typed as a three-level
        nested array of integers.
    """
    def convert_binary_udf(byte_data):
        # Convert the ndarray to plain nested lists for Spark.
        decoded = convert_binary_to_array(byte_data)
        return decoded.tolist()

    # Three ArrayType levels, one per dimension of the (224, 224, 3) array.
    pixel_grid_type = ArrayType(ArrayType(ArrayType(IntegerType())))
    decode_column = udf(convert_binary_udf, pixel_grid_type)
    return df.withColumn("image_data", decode_column(binary_column))
# Create (or reuse) the Spark session for this round-trip demo.
spark = SparkSession.builder.appName("ImageArrayConversion").getOrCreate()
# Serialize the image array to its raw bytes.
binary_image_data = image.tobytes()
# One-row DataFrame whose single column "array_bytes" holds those bytes.
df = spark.createDataFrame([(binary_image_data,)], ["array_bytes"])
df_converted = convert_binary_in_df(df, binary_column='array_bytes')
# truncate=False prints the full cell contents instead of a shortened preview.
df_converted.show(truncate=False)
# Collect the rows and take the nested-list image from the first one.
image_data_from_df = df_converted.collect()[0]['image_data']
# Reconstruct the image back
recreated_image = np.array(image_data_from_df, dtype=np.uint8)
# Display the round-tripped image for comparison with the original.
plt.imshow(recreated_image)
plt.title("Recreated Image from DF")
plt.axis('off')
plt.show()
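The only fix I had to make in your code was to change the return type of the UDF to ArrayType(ArrayType(ArrayType(IntegerType()))). The reshaped array is three-dimensional (height × width × channels), so the nested list returned by tolist() has three levels of nesting, and the declared type needs three ArrayType levels to match. With only two levels, the declared type does not match the returned data, and the column comes back null.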

If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With