
How to apply exponential moving average decay for variables in pytorch?

I am reading the following paper, which uses EMA decay for its variables:
Bidirectional Attention Flow for Machine Comprehension

During training, the moving averages of all weights of the model are maintained with the exponential decay rate of 0.999.

They use TensorFlow and I found the related code of EMA.
https://github.com/allenai/bi-att-flow/blob/master/basic/model.py#L229

In PyTorch, how do I apply EMA to Variables?

asked Nov 01 '25 by jef
1 Answer

You can implement an Exponential Moving Average (EMA) of model parameters by keeping a copy of your model and updating that copy with a custom rule after each training step.

First, create a copy of your model to store the moving averages of the parameters:

import copy

model = YourModel()
ema_model = copy.deepcopy(model)
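Since the EMA copy is never trained directly (it is only written to inside `torch.no_grad()` updates), it is common to also switch off gradient tracking on its parameters and put it in eval mode. This is an optional refinement, not part of the original answer; the `torch.nn.Linear` model here is just a hypothetical stand-in:

```python
import copy

import torch

# Hypothetical small model for illustration
model = torch.nn.Linear(4, 2)
ema_model = copy.deepcopy(model)

# The EMA copy is only ever written inside torch.no_grad() updates,
# so gradient tracking on its parameters can be disabled.
for p in ema_model.parameters():
    p.requires_grad_(False)
ema_model.eval()
```

This avoids accidentally backpropagating through the EMA copy and saves a little memory.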

Then, define the EMA update function, which will update the moving averages of the model parameters after each training step:

import torch

def update_ema_variables(model, ema_model, ema_decay):
    # ema_param <- ema_decay * ema_param + (1 - ema_decay) * param
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(ema_decay).add_(param, alpha=1 - ema_decay)
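You can sanity-check the update rule on a tiny model. The sketch below restates the rule inline on a one-weight `torch.nn.Linear` (a made-up example, with a small decay so one step has a visible effect):

```python
import copy

import torch

ema_decay = 0.9  # deliberately small so a single step is visible

model = torch.nn.Linear(1, 1, bias=False)
ema_model = copy.deepcopy(model)

with torch.no_grad():
    ema_model.weight.fill_(0.0)  # EMA starts at 0
    model.weight.fill_(1.0)      # current weight is 1

    # One EMA step: ema <- decay * ema + (1 - decay) * param
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(ema_decay).add_(param, alpha=1 - ema_decay)

print(float(ema_model.weight))  # 0.9 * 0.0 + 0.1 * 1.0 = 0.1
```

As expected, the EMA weight moves only a fraction (1 - decay) of the way toward the current weight.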

Finally, call the update_ema_variables function in your training loop after each optimization step:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
ema_decay = 0.999

for epoch in range(epochs):
    for inputs, targets in data_loader:
        # Forward pass, compute loss, and update model parameters
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        # Update the EMA copy after each optimizer step
        update_ema_variables(model, ema_model, ema_decay)

With this code you maintain moving averages of your model's parameters throughout training. The ema_model holds the EMA parameters, and you can use it for evaluation or inference.

Alternatively, there are libraries that provide a simple wrapper for this, e.g. https://github.com/fadel/pytorch_ema

answered Nov 04 '25 by iyop45


