
libtorch (PyTorch C++) weird class syntax

In the official PyTorch C++ examples on GitHub, you can see a strange class definition:

class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {...}

My understanding is that this defines a class CustomDataset which "inherits from" or "extends" torch::data::datasets::Dataset&lt;CustomDataset&gt;. This is weird to me, since the class we're creating inherits from another class which is parameterized by the class we're creating... How does this even work? What does it mean? This seems to me like an Integer class inheriting from vector&lt;Integer&gt;, which seems absurd.

asked Apr 20 '20 by JacKeown


1 Answer

This is the curiously recurring template pattern, or CRTP for short. A major advantage of this technique is that it enables so-called static polymorphism: functions in torch::data::datasets::Dataset can call into functions of CustomDataset without those functions being virtual (and thus without the runtime cost of virtual method dispatch). You can also perform compile-time metaprogramming, such as compile-time enable_ifs that depend on properties of the custom dataset type.

In the case of PyTorch, BaseDataset (the superclass of Dataset) uses this technique heavily to support operations such as mapping and filtering:

  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) & {
    return datasets::map(static_cast<Self&>(*this), std::move(transform));
  }

Note the static cast of `this` to the derived type (legal as long as CRTP is properly applied); datasets::map constructs a MapDataset object which is also parameterized by the dataset type, allowing the MapDataset implementation to statically call methods such as get_batch (or produce a compile-time error if they do not exist).

Furthermore, since MapDataset receives the custom dataset type as a type parameter, compile-time metaprogramming is possible:

  /// The implementation of `get_batch()` for the stateless case, which simply
  /// applies the transform to the output of `get_batch()` from the dataset.
  template <
      typename D = SourceDataset,
      typename = torch::disable_if_t<D::is_stateful>>
  OutputBatchType get_batch_impl(BatchRequestType indices) {
    return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
  }

  /// The implementation of `get_batch()` for the stateful case. Here, we follow
  /// the semantics of `Optional.map()` in many functional languages, which
  /// applies a transformation to the optional's content when the optional
  /// contains a value, and returns a new optional (of a different type)  if the
  /// original optional returned by `get_batch()` was empty.
  template <typename D = SourceDataset>
  torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
      BatchRequestType indices) {
    if (auto batch = dataset_.get_batch(std::move(indices))) {
      return transform_.apply_batch(std::move(*batch));
    }
    return nullopt;
  }

Notice that the conditional enable is dependent on SourceDataset, which we only have available because the dataset is parametrized with this CRTP pattern.

answered Sep 30 '22 by nanofarad