GroupNorm is considerably slower and consumes more GPU memory than BatchNorm in PyTorch

Tags:

pytorch

I use GroupNorm in PyTorch instead of BatchNorm and keep everything else (the network architecture) unchanged. On the ImageNet dataset with a ResNet-50 architecture, GroupNorm is 40% slower than BatchNorm and consumes 33% more GPU memory. I am really confused, because GroupNorm shouldn't need more computation than BatchNorm. The details are listed below.

For details of Group Normalization, see the paper: https://arxiv.org/pdf/1803.08494.pdf

For BatchNorm, one minibatch takes 12.8 seconds and uses 7.51 GB of GPU memory;

For GroupNorm, one minibatch takes 17.9 seconds and uses 10.02 GB of GPU memory.
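For reference, here is a minimal sketch of how such per-minibatch numbers can be measured; the batch size and input shape below are illustrative assumptions, not my exact setup:

import time
import torch
import torchvision.models as models

model = models.resnet50().cuda()
x = torch.randn(64, 3, 224, 224, device="cuda")  # assumed batch size / resolution

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.time()

out = model(x)
out.sum().backward()  # include the backward pass, as in training

torch.cuda.synchronize()  # CUDA is asynchronous; sync before reading the clock
print(f"time per minibatch: {time.time() - start:.2f} s")
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GB")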

I use the following code to convert all the BatchNorm layers to GroupNorm layers.

import torch.nn as nn

def convert_bn_model_to_gn(module, num_groups=16):
    """
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with :class:`torch.nn.GroupNorm`.

    Args:
        module: your network module
        num_groups: num_groups of GN
    """
    mod = module
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        mod = nn.GroupNorm(num_groups, module.num_features,
                           eps=module.eps, affine=module.affine)
        if module.affine:
            # Carry over the learned affine parameters from the BN layer.
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
    for name, child in module.named_children():
        mod.add_module(name, convert_bn_model_to_gn(
            child, num_groups=num_groups))
    return mod
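For example, applied to a torchvision ResNet-50 (a minimal sketch; torchvision and the layer path are my assumptions, not part of the setup above):

import torchvision.models as models

model = models.resnet50()
model = convert_bn_model_to_gn(model, num_groups=16)
# All ResNet-50 channel counts (64, 128, ..., 2048) are divisible by 16,
# so num_groups=16 is a valid choice for every converted layer.
print(model.layer1[0].bn1)  # GroupNorm(16, 64, eps=1e-05, affine=True)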
asked Nov 07 '22 by zbh2047


1 Answer

Yes, you are right: GN does use more resources than BN. I'm guessing this is because it has to calculate the mean and variance for every group of channels in every sample, whereas BN computes its per-channel statistics once over the whole batch.
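To illustrate, here is a rough sketch of the statistics each layer computes (shapes only, with illustrative sizes; not an optimized implementation):

import torch

N, C, H, W = 64, 256, 56, 56  # illustrative batch/feature-map sizes
G = 16                        # number of groups
x = torch.randn(N, C, H, W)

# BatchNorm: one mean/var per channel, reduced over the whole batch -> C statistics
bn_mean = x.mean(dim=(0, 2, 3))                 # shape (C,)
bn_var = x.var(dim=(0, 2, 3), unbiased=False)   # shape (C,)

# GroupNorm: one mean/var per group *per sample* -> N * G statistics
xg = x.view(N, G, C // G, H, W)
gn_mean = xg.mean(dim=(2, 3, 4))                # shape (N, G)
gn_var = xg.var(dim=(2, 3, 4), unbiased=False)  # shape (N, G)

With N = 64 and G = 16, that is 1024 sets of statistics per layer versus 256 for BatchNorm, which fits the overhead you are seeing.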

But the advantage of GN is that you can lower your batch size all the way down to 2 without losing accuracy, as stated in the paper, so you can make up for the computational overhead.

answered Nov 29 '22 by RevoGen