
Patching Batch Norm#

Created On: Jan 03, 2023 | Last Updated On: Jun 11, 2025

What's happening?#

Batch Norm requires in-place updates to running_mean and running_var, using statistics computed from the input. functorch does not support an in-place update to a regular tensor that takes in a batched tensor (i.e. regular.add_(batched) is not allowed). So when vmapping over a batch of inputs to a single module, we end up with an error.
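As a minimal sketch that reproduces the failure (the module and shapes here are illustrative):

import torch
from torch.func import vmap

net = torch.nn.BatchNorm1d(3)   # tracks running stats by default
xs = torch.randn(5, 4, 3)       # 5 examples, each a (4, 3) mini-batch

# Each vmapped call computes batched statistics, but running_mean and
# running_var are regular tensors, so the in-place update is rejected.
vmap(net)(xs)                   # expected to raise a RuntimeError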

How to fix#

One of the best-supported fixes is to switch BatchNorm for GroupNorm. Options 1 and 2 cover this.

All of these options assume that you don't need running stats. If you're using a module, this means it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves running batch norm with vmap in evaluation mode, please file an issue.

Option 1: Change the BatchNorm#

If you want to switch to GroupNorm, anywhere that you have BatchNorm, replace it with:

GroupNorm(G, C)

Here C is the same C as in the original BatchNorm, and G is the number of groups to break C into. G must divide C (C % G == 0); as a fallback, you can set G == C, meaning each channel will be treated separately.
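As a sketch of the swap (the tiny network and shapes are illustrative, not from the original):

import torch
from torch.nn import Conv2d, GroupNorm, Sequential
from torch.func import vmap

C, G = 64, 8                           # G must divide C
net = Sequential(Conv2d(3, C, 3), GroupNorm(G, C))

xs = torch.randn(5, 2, 3, 16, 16)      # 5 examples, each a (2, 3, 16, 16) batch
out = vmap(net)(xs)                    # GroupNorm keeps no running stats, so this works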

If you must use BatchNorm and you've built the module yourself, you can change the module to not use running stats. In other words, anywhere there's a BatchNorm module, set the track_running_stats flag to False:

BatchNorm2d(64, track_running_stats=False)
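With track_running_stats=False the module keeps no running buffers at all, so there is nothing for vmap to update in place. A quick illustrative check:

from torch.nn import BatchNorm2d

bn = BatchNorm2d(64, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # both None; batch statistics are used instead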

Option 2: torchvision parameter#

Some torchvision models, like resnet and regnet, accept a norm_layer parameter, which typically defaults to BatchNorm2d.

Instead, you can set it to GroupNorm:

import torchvision
from torch.nn import GroupNorm

g = 32  # number of groups; must divide each layer's channel count
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(g, c))

Here, once again, c % g == 0 must hold; as a fallback, set g = c.

If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:

import torchvision
from functools import partial
from torch.nn import BatchNorm2d

torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
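Either way, the patched model can then be vmapped over a batch of inputs. A sketch, with illustrative shapes:

import torch
import torchvision
from functools import partial
from torch.nn import BatchNorm2d
from torch.func import vmap

model = torchvision.models.resnet18(
    norm_layer=partial(BatchNorm2d, track_running_stats=False))
xs = torch.randn(3, 2, 3, 224, 224)    # 3 examples, each a (2, 3, 224, 224) batch
out = vmap(model)(xs)                  # no running stats left to update in place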

Option 3: functorch's patching#

functorch has added some functionality to allow for quick, in-place patching of a module to not use running stats. Changing the norm layer is more fragile, so we have not offered that option. If you have a net where you want BatchNorm to not use running stats, you can run replace_all_batch_norm_modules_ to update the module in place:

from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
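For example (a sketch; the model choice is illustrative), you can patch a torchvision model in place and then vmap it:

import torch
import torchvision
from torch.func import replace_all_batch_norm_modules_, vmap

net = torchvision.models.resnet18()
replace_all_batch_norm_modules_(net)   # patches every BatchNorm to stop using running stats
xs = torch.randn(3, 2, 3, 224, 224)
out = vmap(net)(xs)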

Option 4: eval mode#

When run under eval mode, running_mean and running_var are only read, not updated, so there is no in-place write for vmap to trip over. Therefore, vmap can support this mode:

model.eval()
vmap(model)(x)
model.train()