Python: Generate a unique batch from given dataset

I'm applying a CNN to classify a given dataset.

My function:

def batch_generator(dataset, input_shape = (256, 256), batch_size = 32):
    dataset_images = []
    dataset_labels = []
    for i in range(0, len(dataset)):
        dataset_images.append(cv2.resize(cv2.imread(dataset[i], cv2.IMREAD_COLOR), 
                     input_shape, interpolation = cv2.INTER_AREA))
        dataset_labels.append(labels[dataset[i].split('/')[-2]])
    return dataset_images, dataset_labels

This function is supposed to be called on every epoch and should return a unique batch of size 'batch_size', containing the images (each resized to 256x256) and the corresponding labels from the 'labels' dictionary.

The input 'dataset' contains the paths to all the images, so I open each one and resize it to 256x256. Can someone help me extend this code so that it returns the desired batches?

Ashar asked Nov 06 '22 23:11


1 Answer

As @jodag suggests, using DataLoaders is a good idea.

I have a snippet that I use for some of my CNNs in PyTorch:

from torch.utils.data import Dataset, DataLoader
import torch
class Data(Dataset):
    """
    Constructs a Dataset to be parsed into a DataLoader
    """
    def __init__(self,X,y):
        X = torch.from_numpy(X).float()

        # Swap dimensions 1 and 2 to fit the input layout of my network.
        # This needs X to have at least 3 dimensions; drop this line if X is 2-D.
        X = torch.transpose(X,1,2)

        y = torch.from_numpy(y).float()
        self.X,self.y = X,y

    def __getitem__(self, i):
        return self.X[i],self.y[i]

    def __len__(self):
        return self.X.shape[0]

def create_data_loader(X,y,batch_size,**kwargs):
    """
    Creates a data-loader for the data X and y

    params:
    -------

    X: np.array
        - numpy array of size n x k, where n is the number of samples and k is the number of features

    y: np.array
        - numpy array of size n

    batch_size: int
        - number of samples per batch

    kwargs:
        - Additional keyword-arguments for "DataLoader"

    return
    ------

    dl: torch.utils.data.DataLoader object
    """

    data = Data(X, y)

    dl = DataLoader(data, batch_size=batch_size,num_workers=0,**kwargs)
    return dl

It is used like this:

from create_data_loader import create_data_loader

# DataLoader defaults to shuffle=False, so request shuffling explicitly for training
train_data_loader = create_data_loader(X_train, y_train, batch_size=32, shuffle=True)
# Keep the validation samples in their original order, e.g. for cross-validation
val_data_loader = create_data_loader(X_val, y_val, batch_size=32, shuffle=False)


for x_train, y_train in train_data_loader:
    logit = net(x_train)  # forward pass on the inputs; y_train goes into the loss
    # ...
    # ...
    net.eval()
    for x_val, y_val in val_data_loader:
        logit = net(x_val)
        classes_pred = logit.argmax(dim=1)
        # Cast the boolean comparison to float before averaging
        print(f"Val accuracy: {(y_val == classes_pred).float().mean()}")
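
If you would rather keep the plain function from the question, the same idea can be sketched without PyTorch: shuffle the indices once per epoch and yield non-overlapping slices, so every sample appears in exactly one batch. This is only a rough sketch under a few assumptions: the names iter_batches, load_fn and seed are made up for illustration, and the label lookup (parent folder name as key into the labels dictionary) is copied from the question.

```python
import random


def iter_batches(paths, labels, batch_size=32, seed=None, load_fn=None):
    """Yield (items, batch_labels) pairs that cover `paths` exactly once.

    paths:    list of image paths such as ".../<class_name>/<file>"
    labels:   dict mapping the parent folder name to a label (as in the question)
    load_fn:  hypothetical hook that turns a path into an image; in the
              question this would wrap cv2.imread + cv2.resize. If None,
              the path itself is returned.
    """
    rng = random.Random(seed)
    order = list(range(len(paths)))
    rng.shuffle(order)  # a fresh random order on every call, i.e. every epoch

    # Consecutive slices of the shuffled order never overlap, so no sample
    # is repeated within an epoch. The last batch may be smaller.
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        items = [load_fn(paths[i]) if load_fn else paths[i] for i in idx]
        batch_labels = [labels[paths[i].split('/')[-2]] for i in idx]
        yield items, batch_labels
```

Calling iter_batches(dataset, labels, batch_size=32, load_fn=...) once per epoch and looping over the result gives a new partition of the dataset each time. Images are loaded only when their batch is requested, so the whole dataset does not have to be held in memory at once.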
CutePoison answered Nov 12 '22 18:11