 

How to prune weights less than a threshold in PyTorch?

How can I prune the weights of a CNN (convolutional neural network) model that are less than a threshold value (let's say, prune all weights that are <= 1)?

How can we achieve that for a weight file saved in .pth format in PyTorch?

Asked May 06 '20 by MSD Paul




1 Answer

Since version 1.4.0, PyTorch provides model pruning out of the box; see the official tutorial.

As there is currently no threshold-based pruning method in PyTorch, you have to implement it yourself, though it's fairly easy once you get the overall idea.

Threshold Pruning method

Below is the code implementing the pruning method:

import torch
from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        # Keep (mask=True) every entry whose absolute value exceeds the
        # threshold; everything at or below it gets pruned (mask=False).
        return torch.abs(tensor) > self.threshold

Explanation:

  • PRUNING_TYPE can be one of global, structured, unstructured. global acts across the whole module (e.g. remove the 20% of weights with the smallest values), while structured acts on whole channels/modules. We need unstructured here, as we want to modify individual connections in a specific parameter tensor (say weight or bias)
  • __init__ - pass in whatever you need to make the method work; here it's just the threshold
  • compute_mask - computes the mask used to prune the specific tensor. In our case, every parameter whose absolute value is at or below the threshold should be zeroed out. I used the absolute value as it makes more sense than the raw value. default_mask is not needed here, but is kept as a named parameter because that's what the API requires atm

Moreover, inheriting from prune.BasePruningMethod gives you methods to apply the mask to each parameter, make pruning permanent, etc. See the base class docs for more info.
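For instance, each subclass inherits an apply classmethod, so you can wrap the method in a small convenience function in the style of the official tutorial (a minimal sketch; the name threshold_unstructured is mine, not part of torch.nn.utils.prune):

def threshold_unstructured(module, name, threshold):
    # `apply` is inherited from prune.BasePruningMethod; it registers the
    # mask, the `<name>_orig` parameter and the forward_pre_hook for us.
    ThresholdPruning.apply(module, name, threshold=threshold)
    return module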

Example module

Nothing too fancy; you can put anything you want here:

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Linear(50, 30)
        self.second = torch.nn.Linear(30, 10)

    def forward(self, inputs):
        return self.second(torch.relu(self.first(inputs)))


module = MyModule()

You can also load your module via module = torch.load('checkpoint.pth') if you need to; it doesn't matter for this approach.
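If your .pth file holds a state_dict (the usual saving convention) rather than a whole pickled module, loading would instead look like this (a sketch; "checkpoint.pth" is a placeholder path):

module = MyModule()
module.load_state_dict(torch.load("checkpoint.pth"))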

Prune module's parameters

We should define which parameters of our module (and whether it's weight or bias) should be pruned, like this:

parameters_to_prune = ((module.first, "weight"), (module.second, "weight"))

Now we can apply our unstructured pruning globally to all of the defined parameters (threshold is passed as a kwarg to __init__ of ThresholdPruning):

prune.global_unstructured(
    parameters_to_prune, pruning_method=ThresholdPruning, threshold=0.1
)
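Before looking at the results in detail, a quick sanity check (a small sketch) shows how much was actually pruned:

# Fraction of entries zeroed out in each pruned tensor:
for mod, name in parameters_to_prune:
    tensor = getattr(mod, name)
    sparsity = float((tensor == 0).sum()) / tensor.nelement()
    print(f"{type(mod).__name__}.{name}: {sparsity:.2%} pruned")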

Results

weight attribute

To see the effect, simply check the weights of the first submodule:

print(module.first.weight)

This is the weight with our pruning technique applied, but please notice it's not a torch.nn.Parameter anymore! It is now simply an attribute of our model, so it is no longer registered as a parameter and won't be updated directly by the optimizer during training.
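You can verify this yourself (a small sketch):

# `weight` is now a plain tensor attribute, not a Parameter:
print(isinstance(module.first.weight, torch.nn.Parameter))  # False
# The registered parameters are now `weight_orig` and `bias`:
print([name for name, _ in module.first.named_parameters()])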

weight_mask

We can check the created mask via module.first.weight_mask to verify that everything was done correctly (it will be binary in this case).
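For example (a sketch):

mask = module.first.weight_mask
# The mask is registered as a buffer and contains only zeros and ones:
print(mask.unique())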

weight_orig

Applying pruning creates a new torch.nn.Parameter holding the original weights, named name + "_orig" (in this case weight_orig). Let's see:

print(module.first.weight_orig)

This parameter is what will actually be used during training and evaluation! Applying pruning via the methods described above adds forward_pre_hooks which recompute the weight attribute from weight_orig and the mask before every forward pass.

Thanks to this approach, you can define and apply your pruning at any point during training or inference without "destroying" the original weights.
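You can confirm the relationship between the three tensors (a sketch):

# The pruned weight equals the original weights times the mask:
recomputed = module.first.weight_orig * module.first.weight_mask
print(torch.equal(module.first.weight, recomputed))  # True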

Applying pruning permanently

If you wish to apply pruning permanently, simply issue:

prune.remove(module.first, "weight")

And now our module.first.weight is once again a parameter with its entries appropriately pruned; module.first.weight_mask is removed, and so is module.first.weight_orig. This is probably what you are after.
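You can verify the cleanup (a sketch):

# `weight` is a real Parameter again, with pruned entries kept at zero;
# the pruning bookkeeping is gone:
print(isinstance(module.first.weight, torch.nn.Parameter))  # True
print(hasattr(module.first, "weight_orig"))                 # False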

You can iterate over children to make it permanent:

for child in module.children():
    prune.remove(child, "weight")

You could define parameters_to_prune using the same logic:

parameters_to_prune = [(child, "weight") for child in module.children()]

Or if you want only convolution layers to be pruned (or anything else really):

parameters_to_prune = [
    (child, "weight")
    for child in module.children()
    if isinstance(child, torch.nn.Conv2d)
]
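Note that children() only yields direct submodules; for nested models you may want modules(), which recurses through the whole tree (a sketch):

# `modules()` recurses, so nested convolution layers are found too:
parameters_to_prune = [
    (mod, "weight")
    for mod in module.modules()
    if isinstance(mod, torch.nn.Conv2d)
]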

Advantages

  • uses the "PyTorch way" of pruning, so it's easier to communicate your intent to other programmers
  • lets you define pruning on a per-tensor basis (single responsibility) instead of manually going through every parameter yourself
  • keeps you within well-defined, predefined workflows
  • pruning is not permanent, hence you can recover from it if needed. The module can be saved with its pruning masks and original weights, which leaves you room to revert an eventual mistake (e.g. the threshold was too high and now all your weights are zero, rendering the results meaningless)
  • works with the original weights during forward calls, unless you want to finally switch to the pruned version (a simple call to remove)

Disadvantages

  • IMO the pruning API could be clearer
  • you can do it more concisely (as provided by Shai)
  • might be confusing for those who don't know such functionality is provided by PyTorch (but there are tutorials and docs, so I don't think it's a major problem)
Answered Sep 29 '22 by Szymon Maszke