Difference between Dataset and TensorDataset in PyTorch

What is the difference between "torch.utils.data.TensorDataset" and "torch.utils.data.Dataset"? The docs are not clear about this, and I could not find any answers on Google.

asked Nov 17 '25 03:11 by Moran Reznik


1 Answer

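Dataset is an abstract class that you subclass to define your own (custom) dataset types: you implement __getitem__ and, in most cases, __len__. TensorDataset, by contrast, is a ready-to-use subclass of Dataset that wraps one or more tensors and indexes them together along the first dimension.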

You can define your custom dataset in the following way:

import torch

class CustomDataset(torch.utils.data.Dataset):

  def __init__(self, instances):
    # Dataset defines no __init__ of its own, so do not forward arguments to super()
    super().__init__()
    # Store the samples; any indexable container (list, tensor, ...) works
    self.instances = instances

  def __getitem__(self, idx):
    # Return the sample at position idx
    return self.instances[idx]

  def __len__(self):
    # Number of samples in the dataset
    return len(self.instances)
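
As a minimal sketch of when subclassing Dataset pays off, the example below builds samples on demand instead of storing them, then batches them with a DataLoader. The SquaresDataset class and its contents are hypothetical and only serve as an illustration; they are not part of PyTorch.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical dataset: sample i is the pair (i, i**2), computed on demand."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples the DataLoader will iterate over
        return self.n

    def __getitem__(self, idx):
        # Build the sample lazily instead of reading it from stored tensors
        x = torch.tensor(float(idx))
        return x, x ** 2


squares = SquaresDataset(10)
loader = DataLoader(squares, batch_size=4, shuffle=False)

# Each batch is a pair of 1-D tensors stacked by the default collate function
batches = [(xs, ys) for xs, ys in loader]
```

Any per-sample logic (reading a file, applying a transform, tokenizing text) can go in __getitem__ in the same way, which is something TensorDataset does not offer.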

If you just want to create a dataset that contains tensors for input features and labels, then use the TensorDataset directly:

from torch.utils.data import TensorDataset

dataset = TensorDataset(input_features, labels)

Note that input_features and labels must have the same size along the first dimension, because dataset[i] returns the tuple (input_features[i], labels[i]).
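
A short sketch of that behaviour, using made-up tensors (the shapes and values are assumptions chosen for illustration). In the PyTorch versions I am aware of, the size check is an assert in the TensorDataset constructor, so a mismatch surfaces as an AssertionError; treat that detail as version-dependent.

```python
import torch
from torch.utils.data import TensorDataset

input_features = torch.randn(5, 3)   # 5 samples, 3 features each (made-up data)
labels = torch.arange(5)             # one label per sample

dataset = TensorDataset(input_features, labels)

# Indexing returns one slice per wrapped tensor, taken along the first dimension
features_2, label_2 = dataset[2]

# Tensors whose first dimensions differ are rejected at construction time
try:
    TensorDataset(input_features, torch.arange(4))
    mismatch_rejected = False
except AssertionError:
    mismatch_rejected = True
```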

answered Nov 18 '25 21:11 by OSainz


