 

How to plot geolocated RGB data faster using Python basemap

I'm having trouble plotting an RGB image with latitude and longitude data using Python's Basemap module. I can make the plots I want, but they are very slow. Basemap plots single-channel data much faster than RGB data, and plotting an RGB image on its own (without coordinates) is also fast; it is the lat/lon data that complicates things. I've looked at the solution to this question:

How to plot an irregular spaced RGB image using python and basemap?

which is how I got to where I am now. It essentially comes down to the following: to plot RGB data with Basemap's pcolormesh method, you have to pass the RGB values through the color keyword (colorTuple below), which maps the colors onto the mesh point by point. Since the array is on the order of 2000x1000, this takes a while. A snippet of what I mean is below (full working code further down):

if one_channel:
    m.pcolormesh(lons, lats, img[:,:,0], latlon=True)
else:
    # This is the part that is slow, but I don't know how to
    # accurately plot the data otherwise.

    # Drop the last column and flatten to an (N, 3) array of RGB triples
    # in row-major order
    mesh_rgb = img[:, :-1, :]
    colorTuple = mesh_rgb.reshape((mesh_rgb.shape[0] * mesh_rgb.shape[1]), 3)

    # The data array passed here doesn't matter; the colors come from color
    m.pcolormesh(lons, lats, img[:,:,0], latlon=True, color=colorTuple)
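The reshape flattens the image in row-major (C) order, so each row of the color array lines up with one mesh cell in the same order pcolormesh draws them. A minimal NumPy-only sketch of that correspondence, assuming lon/lat arrays of shape (rows, cols) so that a flat-shaded mesh has (rows - 1) * (cols - 1) cells; the function name quad_face_colors is hypothetical and, unlike the snippet, it drops the last row as well so the color count matches the cell count exactly:

```python
import numpy as np

def quad_face_colors(img):
    """Return one RGB color per pcolormesh cell, in row-major order.

    Assumes img has shape (rows, cols, 3) and the lon/lat arrays passed to
    pcolormesh have shape (rows, cols), giving (rows - 1) * (cols - 1) cells.
    Cell (i, j) takes the color of pixel (i, j).
    """
    cells = img[:-1, :-1, :]      # drop the last row and column
    return cells.reshape(-1, 3)   # cell (i, j) -> flat index i * (cols - 1) + j

# Tiny example: a 3 x 4 image gives a mesh of 2 x 3 = 6 cells
rows, cols = 3, 4
img = np.arange(rows * cols * 3, dtype=float).reshape(rows, cols, 3)
colors = quad_face_colors(img)
print(colors.shape)                            # (6, 3)
print(np.array_equal(colors[5], img[1, 2]))    # True: 1 * (cols - 1) + 2 == 5
```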

Plotting just one channel takes about 10 seconds. Plotting the RGB data takes 3-4 minutes. Given that there is only three times as much data, I feel there must be a better way, especially because RGB data plots just as fast as single-channel data when making ordinary rectangular images.

So, my question is: is there any way to make this faster, either with another plotting module (Bokeh, for instance) or by changing the color mapping? I've tried imshow with carefully chosen map boundaries, but since it simply stretches the image over the full extent of the map, it isn't accurate enough for mapping the data.
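One alternative, which is not what the accepted answer does, is to resample the swath onto a regular lon/lat grid once with a nearest-neighbour lookup (here scipy.spatial.cKDTree); the result is a plain rectangle that imshow can place with a correct extent instead of stretching it. A minimal sketch, assuming lons, lats and img have the shapes used in the question; the function name regrid_nearest and its default grid size are hypothetical, and distances are measured in raw degrees, which is only an approximation away from the equator:

```python
import numpy as np
from scipy.spatial import cKDTree

def regrid_nearest(lons, lats, img, nx=400, ny=300, max_dist=None):
    """Resample a swath RGB image onto a regular lon/lat grid (nearest neighbour).

    Returns (grid_rgb, extent), with extent = [lon_min, lon_max, lat_min, lat_max]
    for use with imshow(..., extent=extent, origin='lower').
    """
    # Build a tree over the swath pixel centres, in degree space
    tree = cKDTree(np.column_stack([lons.ravel(), lats.ravel()]))
    glon = np.linspace(lons.min(), lons.max(), nx)
    glat = np.linspace(lats.min(), lats.max(), ny)
    gx, gy = np.meshgrid(glon, glat)  # shape (ny, nx), latitude increasing by row
    dist, idx = tree.query(np.column_stack([gx.ravel(), gy.ravel()]))
    rgb = img.reshape(-1, img.shape[-1])[idx].reshape(ny, nx, -1)
    if max_dist is not None:
        # Blank out grid cells that are far from every swath pixel
        rgb[(dist > max_dist).reshape(ny, nx)] = 0.0
    extent = [glon[0], glon[-1], glat[0], glat[-1]]
    return rgb, extent

# Synthetic check: a regular 3 x 4 "swath" regridded onto the same grid
lons, lats = np.meshgrid(np.linspace(0.0, 3.0, 4), np.linspace(10.0, 12.0, 3))
img = np.random.default_rng(1).random((3, 4, 3))
rgb, extent = regrid_nearest(lons, lats, img, nx=4, ny=3)
```

Drawing the result with plt.imshow(rgb, extent=extent, origin='lower') places it correctly on plain lon/lat axes; putting it on the Cassini projection used in the question would need a further reprojection step that is not shown here.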

Below is a stripped-down version of my code that works as an example, given the required modules:

from pyhdf.SD import SD,SDC
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap

def get_hdf_attr(infile,dataset,attr):

    f = SD(infile,SDC.READ)
    data = f.select(dataset)
    index = data.attr(attr).index()
    attr_out = data.attr(index).get()
    f.end()

    return attr_out

def get_hdf_dataset(infile,dataset):

    f = SD(infile,SDC.READ)
    data = f.select(dataset)[:]
    f.end()

    return data

class make_rgb:

    def __init__(self,file_name):

        sds_250 = get_hdf_dataset(file_name, 'EV_250_Aggr1km_RefSB')
        scales_250 = get_hdf_attr(file_name, 'EV_250_Aggr1km_RefSB', 'reflectance_scales')
        offsets_250 = get_hdf_attr(file_name, 'EV_250_Aggr1km_RefSB', 'reflectance_offsets')

        sds_500 = get_hdf_dataset(file_name, 'EV_500_Aggr1km_RefSB')
        scales_500 = get_hdf_attr(file_name, 'EV_500_Aggr1km_RefSB', 'reflectance_scales')
        offsets_500 = get_hdf_attr(file_name, 'EV_500_Aggr1km_RefSB', 'reflectance_offsets')

        data_shape = sds_250.shape

        along_track = data_shape[1]
        cross_track = data_shape[2]

        rgb = np.zeros((along_track, cross_track, 3))

        rgb[:, :, 0] = (sds_250[0, :, :] - offsets_250[0]) * scales_250[0]
        rgb[:, :, 1] = (sds_500[1, :, :] - offsets_500[1]) * scales_500[1]
        rgb[:, :, 2] = (sds_500[0, :, :] - offsets_500[0]) * scales_500[0]

        rgb[rgb > 1] = 1.0
        rgb[rgb < 0] = 0.0

        lin = np.array([0, 30, 60, 120, 190, 255]) / 255.0
        nonlin = np.array([0, 110, 160, 210, 240, 255]) / 255.0

        scale = interp1d(lin, nonlin, kind='quadratic')

        self.img = scale(rgb)

    def plot_image(self):

        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111)

        ax.set_yticks([])
        ax.set_xticks([])
        plt.imshow(self.img, interpolation='nearest')
        plt.show()

    def plot_geo(self,geo_file,one_channel=False):

        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111)

        # Select the MOD03 geolocation datasets by name rather than by index
        lats = get_hdf_dataset(geo_file, 'Latitude')
        lons = get_hdf_dataset(geo_file, 'Longitude')

        lat_0 = np.mean(lats)
        lat_range = [np.min(lats), np.max(lats)]

        lon_0 = np.mean(lons)
        lon_range = [np.min(lons), np.max(lons)]

        map_kwargs = dict(projection='cass', resolution='l',
                          llcrnrlat=lat_range[0], urcrnrlat=lat_range[1],
                          llcrnrlon=lon_range[0], urcrnrlon=lon_range[1],
                          lat_0=lat_0, lon_0=lon_0)

        m = Basemap(**map_kwargs)

        if one_channel:
            m.pcolormesh(lons, lats, self.img[:,:,0], latlon=True)
        else:
            # This is the part that is slow, but I don't know how to
            # accurately plot the data otherwise.
            mesh_rgb = self.img[:, :-1, :]
            colorTuple = mesh_rgb.reshape((mesh_rgb.shape[0] * mesh_rgb.shape[1]), 3)
            m.pcolormesh(lons, lats, self.img[:,:,0], latlon=True,color=colorTuple)

        m.drawcoastlines()
        m.drawcountries()

        plt.show()

if __name__ == '__main__':

    # https://ladsweb.nascom.nasa.gov/archive/allData/6/MOD021KM/2015/183/
    data_file = 'MOD021KM.A2015183.1005.006.2015183195350.hdf'

    # https://ladsweb.nascom.nasa.gov/archive/allData/6/MOD03/2015/183/
    geo_file = 'MOD03.A2015183.1005.006.2015183192656.hdf'

    # Very Fast
    make_rgb(data_file).plot_image()

    # Also Fast, takes about 10 seconds
    make_rgb(data_file).plot_geo(geo_file,one_channel=True)

    # Much slower, takes several minutes
    make_rgb(data_file).plot_geo(geo_file)
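The contrast stretch in make_rgb does not depend on the plotting, so it can be checked in isolation. A minimal sketch reusing the same control points; np.clip gives the same result as the two masked assignments in the class, and the function name enhance is hypothetical. Because a quadratic spline passes through its knots, inputs at the control points map onto the corresponding output values:

```python
import numpy as np
from scipy.interpolate import interp1d

# Control points copied from make_rgb above (0-255 scale, normalised to 0-1)
LIN = np.array([0, 30, 60, 120, 190, 255]) / 255.0
NONLIN = np.array([0, 110, 160, 210, 240, 255]) / 255.0
_STRETCH = interp1d(LIN, NONLIN, kind='quadratic')

def enhance(rgb):
    """Clip reflectances to [0, 1], then apply the non-linear brightness stretch."""
    return _STRETCH(np.clip(rgb, 0.0, 1.0))

# Out-of-range values are clipped first; 30/255 is a knot and maps to 110/255
out = enhance(np.array([[-0.2, 30 / 255.0, 1.5]]))
```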
asked Oct 17 '22 19:10 by tmwilson26


1 Answer

I solved this by appending 1.0 to every entry of colorTuple, turning it into an RGBA array. Stepping through the pcolormesh function, I found that it was calling the color converter to convert the RGB array to RGBA four separate times, taking about 50 seconds each time. If you give it an RGBA array to begin with, it skips this conversion and produces the plot in a reasonable time. The added line of code is shown below:

if one_channel:
    m.pcolormesh(lons, lats, img[:,:,0], latlon=True)
else:
    mesh_rgb = img[:, :-1, :]
    colorTuple = mesh_rgb.reshape((mesh_rgb.shape[0] * mesh_rgb.shape[1]), 3)

    # ADDED THIS LINE: append an alpha column of 1.0 so the (N, 3) RGB
    # array becomes an (N, 4) RGBA array
    colorTuple = np.insert(colorTuple, 3, 1.0, axis=1)

    # The data array passed here doesn't matter; the colors come from color
    m.pcolormesh(lons, lats, img[:,:,0], latlon=True, color=colorTuple)
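The np.insert call appends a constant alpha column. An equivalent that states the intent a little more explicitly is to stack a column of ones onto the array; the helper name to_rgba below is hypothetical. How much time this saves likely depends on the matplotlib version, since newer releases may already convert (N, 3) float arrays in a vectorized way.

```python
import numpy as np

def to_rgba(color_tuple, alpha=1.0):
    """Append a constant alpha column to an (N, 3) RGB array, giving (N, 4) RGBA."""
    color_tuple = np.asarray(color_tuple, dtype=float)
    alpha_col = np.full((color_tuple.shape[0], 1), alpha)
    return np.hstack([color_tuple, alpha_col])

rgb = np.random.default_rng(0).random((5, 3))
rgba = to_rgba(rgb)
# Same array as the answer's np.insert(colorTuple, 3, 1.0, axis=1)
same = np.array_equal(rgba, np.insert(rgb, 3, 1.0, axis=1))
```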
answered Oct 20 '22 23:10 by tmwilson26