How can I prune the weights of a CNN (convolutional neural network) model that are below a threshold value (say, prune all weights that are <= 1)?
How can we achieve that for a weight file saved in .pth format in PyTorch?
Since version 1.4.0, PyTorch provides model pruning out of the box; see the official tutorial.
As PyTorch currently has no built-in threshold pruning method, you have to implement it yourself, though it's fairly easy once you get the overall idea.
Below is the code that performs the pruning:
import torch
from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        # Keep (mask = True) only the entries whose magnitude exceeds the threshold
        return torch.abs(tensor) > self.threshold
Explanation:
- PRUNING_TYPE can be one of "global", "structured" or "unstructured". global acts across the whole module (e.g. remove the 20% of weights with the smallest values), while structured acts on whole channels/modules. We need unstructured, as we want to modify each individual connection in a specific parameter tensor (say weight or bias); prune.global_unstructured also accepts only methods of this type.
- __init__: pass here whatever you need to make it work; normal stuff.
- compute_mask: returns the mask used to prune a specific tensor. In our case, all parameters whose magnitude is at or below the threshold should become zero. I used the absolute value, as it makes more sense for weights. default_mask is not needed here, but is kept as a named parameter because that's what the API currently requires.
Moreover, inheriting from prune.BasePruningMethod provides the methods that apply the mask to each parameter, make pruning permanent, etc. See the base class docs for more info.
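To see what compute_mask actually produces, here is a minimal sketch that calls it directly on a small hand-written tensor and then applies the method to a single layer through the classmethod apply inherited from prune.BasePruningMethod. The tensor values and the threshold of 0.5 are arbitrary choices for illustration, and the class is repeated so the block runs on its own.

```python
import torch
from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        # Keep only entries whose magnitude is strictly above the threshold
        return torch.abs(tensor) > self.threshold


# Call compute_mask directly on a toy tensor (values chosen for illustration)
t = torch.tensor([[0.2, -0.7], [0.5, -0.1]])
mask = ThresholdPruning(threshold=0.5).compute_mask(t, default_mask=torch.ones_like(t))
print(mask)  # values: [[False, True], [False, False]]

# Prune one layer's weight via the inherited classmethod `apply`
layer = torch.nn.Linear(2, 2)
with torch.no_grad():
    layer.weight.copy_(t)
ThresholdPruning.apply(layer, "weight", threshold=0.5)
print(layer.weight)  # only the -0.7 entry survives
```

Note that 0.5 itself is pruned, because the mask keeps only magnitudes strictly greater than the threshold, which matches the "<= threshold is pruned" requirement from the question.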
Next we need a model to prune. Nothing too fancy; you can put anything you want here:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Linear(50, 30)
        self.second = torch.nn.Linear(30, 10)

    def forward(self, inputs):
        return self.second(torch.relu(self.first(inputs)))


module = MyModule()
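If you need to, you can also load your module via module = torch.load('checkpoint.pth'); for the pruning itself it doesn't matter. This only works if the whole model was pickled (and recent PyTorch versions additionally require weights_only=False for it). If your .pth file holds just a state_dict (the weights), create the model first and load the weights into it with module.load_state_dict(torch.load('checkpoint.pth')).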
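Since the question is about a weight file saved in .pth format, here is a minimal sketch under the assumption that the file holds a state_dict (the usual output of torch.save(model.state_dict(), ...)) rather than a whole pickled model. To keep the example self-contained it first writes such a file to a temporary directory; in your code you would point torch.load at your own checkpoint. MyModule is repeated from above, and checkpoint.pth is only a placeholder name.

```python
import os
import tempfile

import torch


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Linear(50, 30)
        self.second = torch.nn.Linear(30, 10)

    def forward(self, inputs):
        return self.second(torch.relu(self.first(inputs)))


# Simulate an existing weight file: save a state_dict to a temporary .pth path
tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "checkpoint.pth")  # placeholder file name
torch.save(MyModule().state_dict(), path)

# Rebuild the architecture first, then load the saved weights into it
module = MyModule()
state_dict = torch.load(path, map_location="cpu")
module.load_state_dict(state_dict)
```

After this, module can be pruned exactly as described next.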
We have to define which parameters of our module (and whether it's the weight or the bias) should be pruned, like this:
parameters_to_prune = ((module.first, "weight"), (module.second, "weight"))
Now we can apply our unstructured pruning globally to all defined parameters (threshold is passed as a keyword argument to the __init__ of ThresholdPruning):
prune.global_unstructured(
    parameters_to_prune, pruning_method=ThresholdPruning, threshold=0.1
)
weight attribute
To see the effect, simply check the weights of the first submodule:
print(module.first.weight)
This is the weight with our pruning technique applied, but please notice it's not a torch.nn.Parameter anymore! It is now simply an attribute of our module, recomputed from the original weights and the mask, so it won't appear in module.parameters() and an optimizer won't update it directly.
weight_mask
We can check the created mask via module.first.weight_mask to see that everything was done correctly (it will be binary in this case).
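A quick way to confirm that the global threshold pruning did what we expect is to check that the mask contains only zeros and ones, that the kept entries are exactly the original weights (which PyTorch keeps as weight_orig) with magnitude above the threshold, and to report the resulting sparsity. This is a minimal sketch with randomly initialized weights and an arbitrary seed; the sparsity you get depends entirely on your weights and on the threshold of 0.1 used above.

```python
import torch
from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) > self.threshold


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Linear(50, 30)
        self.second = torch.nn.Linear(30, 10)

    def forward(self, inputs):
        return self.second(torch.relu(self.first(inputs)))


torch.manual_seed(0)  # arbitrary seed, only for reproducibility
module = MyModule()
threshold = 0.1
parameters_to_prune = ((module.first, "weight"), (module.second, "weight"))
prune.global_unstructured(parameters_to_prune, pruning_method=ThresholdPruning, threshold=threshold)

for name, layer in (("first", module.first), ("second", module.second)):
    mask = layer.weight_mask
    # The mask is binary: every entry is either 0 or 1
    assert set(mask.unique().tolist()) <= {0.0, 1.0}
    # Kept entries are exactly those whose original magnitude exceeds the threshold
    assert torch.equal(mask.bool(), layer.weight_orig.abs() > threshold)
    # The pruned weight is a plain tensor attribute, not a Parameter
    assert not isinstance(layer.weight, torch.nn.Parameter)
    sparsity = (layer.weight == 0).float().mean().item()
    print(f"{name}: {sparsity:.1%} of weights pruned")
```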
weight_orig
Applying pruning creates a new torch.nn.Parameter holding the original weights, named name + "_orig" (in this case weight_orig). Let's see:
print(module.first.weight_orig)
This parameter is the one that will be used (and updated) during training and evaluation! Applying pruning via the methods described above also registers a forward_pre_hook which, before every forward call, recomputes weight as weight_orig * weight_mask.
Thanks to this approach, you can define and apply your pruning at any point of training or inference without "destroying" the original weights.
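Here is a minimal sketch of this mechanism on a single layer. It uses ThresholdPruning.apply (the classmethod inherited from prune.BasePruningMethod) instead of global pruning, and the layer size, the threshold of 0.2 and the single SGD step are arbitrary choices for illustration. It shows that weight_orig is what the optimizer updates and that weight is recomputed from weight_orig * weight_mask on the next forward call.

```python
import torch
from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) > self.threshold


torch.manual_seed(0)  # arbitrary seed
layer = torch.nn.Linear(4, 3)
ThresholdPruning.apply(layer, "weight", threshold=0.2)

# The trainable tensor is now weight_orig; weight itself is a plain attribute
print([name for name, _ in layer.named_parameters()])

# One optimization step: gradients reach weight_orig through the pruned weight
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
loss = layer(torch.randn(8, 4)).pow(2).sum()
loss.backward()
optimizer.step()

# The forward pre-hook recomputes weight = weight_orig * weight_mask on the next call
layer(torch.randn(1, 4))
print(torch.allclose(layer.weight, layer.weight_orig * layer.weight_mask))  # True
```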
If you wish to make the pruning permanent, simply issue:
prune.remove(module.first, "weight")
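Now module.first.weight is once again a parameter, with its entries appropriately pruned, while module.first.weight_mask and module.first.weight_orig are removed. This is probably what you are after.
You can iterate over the children to make the pruning permanent for all of them (note that prune.remove raises a ValueError for any submodule whose weight was never pruned):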
for child in module.children():
    prune.remove(child, "weight")
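Putting the pieces together for the original question, here is a minimal end-to-end sketch: load a state_dict from a .pth file, zero every Conv2d/Linear weight with magnitude <= 1, make the pruning permanent and save the result. The helper prune_checkpoint, the tiny CNN, the file names and the random weights (scaled so that some exceed 1) are all hypothetical stand-ins for your own model and checkpoint.

```python
import os
import tempfile

import torch
from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) > self.threshold


def prune_checkpoint(model, path_in, path_out, threshold):
    """Hypothetical helper: load a state_dict, zero weights with |w| <= threshold, save it."""
    model.load_state_dict(torch.load(path_in, map_location="cpu"))
    parameters_to_prune = [
        (m, "weight")
        for m in model.modules()
        if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))
    ]
    prune.global_unstructured(
        parameters_to_prune, pruning_method=ThresholdPruning, threshold=threshold
    )
    # Make pruning permanent so the state_dict has plain "weight" keys again
    for m, name in parameters_to_prune:
        prune.remove(m, name)
    torch.save(model.state_dict(), path_out)
    return model


def make_model():
    # Hypothetical tiny CNN standing in for your own architecture
    return torch.nn.Sequential(
        torch.nn.Conv2d(1, 4, kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.Flatten(),
        torch.nn.Linear(4 * 6 * 6, 2),
    )


torch.manual_seed(0)
original = make_model()
with torch.no_grad():
    for p in original.parameters():
        p.copy_(torch.randn_like(p) * 2.0)  # scaled so some |w| > 1

tmpdir = tempfile.mkdtemp()
path_in = os.path.join(tmpdir, "model.pth")
path_out = os.path.join(tmpdir, "model_pruned.pth")
torch.save(original.state_dict(), path_in)

pruned = prune_checkpoint(make_model(), path_in, path_out, threshold=1.0)
reloaded = torch.load(path_out, map_location="cpu")
print(sorted(reloaded.keys()))  # ['0.bias', '0.weight', '3.bias', '3.weight']
```

Because prune.remove is called before saving, the resulting file loads into a plain, unpruned model without any weight_orig or weight_mask keys.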
You could define parameters_to_prune using the same logic:
parameters_to_prune = [(child, "weight") for child in module.children()]
Or, if you want only convolution layers to be pruned (or anything else, really):
parameters_to_prune = [
    (child, "weight")
    for child in module.children()
    if isinstance(child, torch.nn.Conv2d)
]
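One caveat for real CNNs: module.children() yields only the direct submodules, so convolution layers nested inside containers such as torch.nn.Sequential are skipped. A minimal sketch with a hypothetical SmallCNN shows the difference; module.modules() walks the whole module tree and finds the nested layers.

```python
import torch


class SmallCNN(torch.nn.Module):
    """Hypothetical CNN whose conv layers live inside a nested Sequential."""

    def __init__(self):
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8, 16, kernel_size=3),
        )
        self.classifier = torch.nn.Linear(16, 10)


model = SmallCNN()

# children() yields only `features` and `classifier`, so no Conv2d is found
direct = [(c, "weight") for c in model.children() if isinstance(c, torch.nn.Conv2d)]
# modules() recurses into containers, so both nested Conv2d layers are found
nested = [(m, "weight") for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
print(len(direct), len(nested))  # 0 2
```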
A few things to keep in mind:
- Always check that the pruning did what you expect (e.g. if the threshold was too high, all your weights become zero, rendering results meaningless).
- The original weights are kept in weight_orig and are not overwritten by forward calls unless you finally switch to the pruned version (a simple call to prune.remove).