How can I prune the weights of a CNN (convolutional neural network) model that are below a threshold value (let's say prune all weights which are <= 1)?
How can we achieve that for a weight file saved in .pth format in PyTorch?
PyTorch since 1.4.0 provides model pruning out of the box; see the official tutorial. As there is currently no threshold-based pruning method in PyTorch, you have to implement it yourself, though it's kinda easy once you get the overall idea.
Below is code performing the pruning:
import torch
from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        # Keep only connections whose absolute value exceeds the threshold
        return torch.abs(tensor) > self.threshold
Explanation:
- PRUNING_TYPE can be one of global, structured, unstructured. global acts across the whole module (e.g. remove the 20% of weights with the smallest values), structured acts on whole channels/modules. We need unstructured, as we would like to modify each connection in a specific parameter tensor (say weight or bias).
- __init__ - pass here whatever you want or need to make it work, normal stuff.
- compute_mask - the mask to be used to prune the specific tensor. In our case all parameters whose absolute value is below the threshold should be zeroed; I did it with the absolute value as it makes more sense. default_mask is not needed here, but is left as a named parameter as that's what the API requires at the moment.

Moreover, inheriting from prune.BasePruningMethod defines methods to apply the mask to each parameter, make pruning permanent etc. See the base class docs for more info.
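To get a feel for compute_mask, here is a quick sanity check on a toy tensor (the values are made up purely for illustration):

pruner = ThresholdPruning(threshold=1.0)
t = torch.tensor([-2.0, -0.5, 0.0, 0.7, 3.0])
print(pruner.compute_mask(t, default_mask=torch.ones_like(t)))
# tensor([ True, False, False, False,  True])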
Nothing too fancy, you can put anything you want here:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Linear(50, 30)
        self.second = torch.nn.Linear(30, 10)

    def forward(self, inputs):
        return self.second(torch.relu(self.first(inputs)))
module = MyModule()
You can also load your module via module = torch.load('checkpoint.pth') if you need to; it doesn't matter.
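Note that torch.load only gives you a module object if the whole model was saved that way; more commonly a .pth file contains just a state_dict. In that case the loading would look like this (a sketch; "checkpoint.pth" is a placeholder path):

module = MyModule()
module.load_state_dict(torch.load("checkpoint.pth"))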
We should define which parameter of our module (and whether it's weight or bias) should be pruned, like this:
parameters_to_prune = ((module.first, "weight"), (module.second, "weight"))
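If you also wanted to prune the biases, you would simply add those entries as well, e.g.:

parameters_to_prune = (
    (module.first, "weight"),
    (module.first, "bias"),
    (module.second, "weight"),
    (module.second, "bias"),
)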
Now, we can globally apply our unstructured pruning to all the defined parameters (threshold is passed as a kwarg to __init__ of ThresholdPruning):
prune.global_unstructured(
    parameters_to_prune, pruning_method=ThresholdPruning, threshold=0.1
)
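To quantify the effect, you can compute the per-layer sparsity from the generated masks (a small sketch; the weight_mask attribute is explained below):

for submodule, name in parameters_to_prune:
    mask = getattr(submodule, name + "_mask")
    sparsity = 100.0 * float((mask == 0).sum()) / mask.numel()
    print(f"{sparsity:.2f}% of {name} pruned in {submodule}")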
weight attribute

To see the effect, simply check the weights of the first submodule:
print(module.first.weight)
It is the weight with our pruning technique applied, but please notice it's not a torch.nn.Parameter anymore! Now it is simply an attribute of our module, hence it is not the tensor the optimizer updates during training.
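You can verify this claim directly; both checks below should print False after pruning:

print(isinstance(module.first.weight, torch.nn.Parameter))
print("weight" in dict(module.first.named_parameters()))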
weight_mask

We can check the created mask via module.first.weight_mask to see that everything was done correctly (it will be binary in this case).
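For example, listing the distinct values of the mask should show only two of them (zeros and ones, or False and True, depending on the mask's dtype):

print(torch.unique(module.first.weight_mask))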
weight_orig

Applying pruning creates a new torch.nn.Parameter with the original weights, named name + "_orig", in this case weight_orig; let's see:
print(module.first.weight_orig)
This parameter is the one actually used during training and evaluation now! After applying pruning via the methods described above, forward_pre_hooks are added which "switch" the original weight to weight_orig; more precisely, before each forward pass they recompute weight as weight_orig with weight_mask applied.
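You can peek at the registered hook if you are curious; on recent PyTorch versions prune.global_unstructured applies the per-tensor masks via prune.CustomFromMask, so inspecting the (private, shown purely for illustration) hook dict gives something like:

print(module.first._forward_pre_hooks)
# e.g. OrderedDict([(0, <torch.nn.utils.prune.CustomFromMask object at 0x...>)])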
Thanks to this approach you can define and apply your pruning at any point of training or inference without "destroying" the original weights.
If you wish to make the pruning permanent, simply issue:
prune.remove(module.first, "weight")
And now our module.first.weight is once again a parameter with its entries appropriately pruned, module.first.weight_mask is removed and so is module.first.weight_orig. It's what you are probably after.
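A quick check that the removal worked; the helper tensors are gone and weight is a trainable parameter again:

print(isinstance(module.first.weight, torch.nn.Parameter))  # True
print(hasattr(module.first, "weight_orig"))                 # False
print(hasattr(module.first, "weight_mask"))                 # False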
You can iterate over children to make the pruning permanent:
for child in module.children():
    prune.remove(child, "weight")
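If some children were never pruned (or have no weight parameter at all), prune.remove will raise an error for them; a defensive variant could look like this (a sketch using prune.is_pruned, assuming weight is the only pruned parameter):

for child in module.children():
    if prune.is_pruned(child):
        prune.remove(child, "weight")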
You could define parameters_to_prune using the same logic:
parameters_to_prune = [(child, "weight") for child in module.children()]
Or if you want only convolution layers to be pruned (or anything else really):
parameters_to_prune = [
    (child, "weight")
    for child in module.children()
    if isinstance(child, torch.nn.Conv2d)
]
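Note that children() only yields direct submodules; for nested models you may want to walk the whole tree with modules() instead (a sketch):

parameters_to_prune = [
    (m, "weight")
    for m in module.modules()
    if isinstance(m, torch.nn.Conv2d)
]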
A few caveats:
- Check your results after pruning; it may happen that your threshold was too high and now all your weights are zero, rendering the results meaningless.
- weight_orig is the parameter taking part in forward calls unless you want to finally change to the pruned version (a simple call to remove).