In the official PyTorch C++ examples on GitHub, you can see a strange definition of a class:
class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {...}
My understanding is that this defines a class CustomDataset which "inherits from" or "extends" torch::data::datasets::Dataset<CustomDataset>. This is weird to me, since the class we're creating inherits from another class which is parameterized by the class we're creating... How does this even work? What does it mean? This seems to me like an Integer class inheriting from vector<Integer>, which seems absurd.
This is the curiously recurring template pattern, or CRTP for short. A major advantage of this technique is that it enables so-called static polymorphism, meaning that functions in torch::data::datasets::Dataset can call into functions of CustomDataset without needing to make those functions virtual (and thus deal with the runtime mess of virtual method dispatch and so on). You can also perform compile-time metaprogramming, such as compile-time enable_ifs that depend on the properties of the custom dataset type.
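To make the pattern concrete, here is a minimal, self-contained sketch of CRTP-based static polymorphism. The Base/CustomThing names are invented for illustration and are not part of the PyTorch API; the point is only that the base class calls a method of the derived class through a static_cast, with no virtual functions involved.

#include <iostream>

// Base is parameterized by the class that derives from it (CRTP).
template <typename Derived>
struct Base {
  // Calls a method of the derived class without any virtual dispatch:
  // the cast is legal because, by construction, *this really is a Derived.
  void greet() {
    static_cast<Derived&>(*this).name();
  }
};

// The derived class passes itself as the template argument.
struct CustomThing : Base<CustomThing> {
  void name() { std::cout << "CustomThing\n"; }
};

int main() {
  CustomThing t;
  t.greet();  // prints "CustomThing"; resolved entirely at compile time
}

Because greet() is resolved at compile time, a missing or misnamed name() in the derived class shows up as a compile-time error rather than as a runtime failure of virtual dispatch.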
In the case of PyTorch, BaseDataset (the superclass of Dataset) uses this technique heavily to support operations such as mapping and filtering:
template <typename TransformType>
MapDataset<Self, TransformType> map(TransformType transform) & {
  return datasets::map(static_cast<Self&>(*this), std::move(transform));
}
Note the static cast of this to the derived type (legal as long as CRTP is properly applied); datasets::map constructs a MapDataset object which is also parametrized by the dataset type, allowing the MapDataset implementation to statically call methods such as get_batch (or encounter a compile-time error if they do not exist).
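To see how this looks from the user's side, here is a hedged sketch of a typical custom dataset written against the C++ frontend data API (the tensor shapes and the dataset size are placeholders invented for this example); calling .map() on it goes through the CRTP member shown above, so the resulting MapDataset is parameterized by CustomDataset:

#include <torch/torch.h>

// A minimal custom dataset; shapes and size are illustrative placeholders.
class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
 public:
  // Return one (data, target) example for the given index.
  torch::data::Example<> get(size_t index) override {
    return {torch::randn({3}), torch::tensor(static_cast<int64_t>(index))};
  }

  // The number of examples in the dataset.
  torch::optional<size_t> size() const override {
    return 100;
  }
};

// map() statically casts *this to CustomDataset& and builds a
// MapDataset<CustomDataset, Stack<...>> -- no virtual dispatch is needed.
auto make_dataset() {
  return CustomDataset().map(torch::data::transforms::Stack<>());
}

If CustomDataset were missing a method that MapDataset expects, the failure would surface as a compile-time error, exactly as described above.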
Furthermore, since MapDataset receives the custom dataset type as a type parameter, compile-time metaprogramming is possible:
/// The implementation of `get_batch()` for the stateless case, which simply
/// applies the transform to the output of `get_batch()` from the dataset.
template <
    typename D = SourceDataset,
    typename = torch::disable_if_t<D::is_stateful>>
OutputBatchType get_batch_impl(BatchRequestType indices) {
  return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
}

/// The implementation of `get_batch()` for the stateful case. Here, we follow
/// the semantics of `Optional.map()` in many functional languages, which
/// applies a transformation to the optional's content when the optional
/// contains a value, and returns a new optional (of a different type) if the
/// original optional returned by `get_batch()` was empty.
template <typename D = SourceDataset>
torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
    BatchRequestType indices) {
  if (auto batch = dataset_.get_batch(std::move(indices))) {
    return transform_.apply_batch(std::move(*batch));
  }
  return nullopt;
}
Notice that the conditional enable depends on SourceDataset, which we only have available because the dataset is parametrized with this CRTP pattern.
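As a standalone illustration of that idea, here is a minimal sketch using only the standard library (the StatelessSource/StatefulSource types and the is_stateful flag are invented for this example, not PyTorch code), showing how enable_if-style dispatch selects an implementation based on a compile-time property of the type parameter:

#include <iostream>
#include <type_traits>

struct StatelessSource { static constexpr bool is_stateful = false; };
struct StatefulSource  { static constexpr bool is_stateful = true; };

template <typename SourceDataset>
struct Wrapper {
  // Selected at compile time when SourceDataset::is_stateful is false.
  template <typename D = SourceDataset,
            std::enable_if_t<!D::is_stateful, int> = 0>
  void get_batch_impl() {
    std::cout << "stateless path\n";
  }

  // Selected at compile time when SourceDataset::is_stateful is true.
  template <typename D = SourceDataset,
            std::enable_if_t<D::is_stateful, int> = 0>
  void get_batch_impl() {
    std::cout << "stateful path\n";
  }
};

int main() {
  Wrapper<StatelessSource>{}.get_batch_impl();  // prints "stateless path"
  Wrapper<StatefulSource>{}.get_batch_impl();   // prints "stateful path"
}

MapDataset does essentially the same thing with torch::enable_if_t/torch::disable_if_t keyed off SourceDataset::is_stateful, as in the snippet quoted above.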