PyTorch does not ship EMA (exponential moving average) of model weights as a built-in feature, but the official documentation publishes example code showing how to use it. Writing this post for the record.
torch.optim — PyTorch 1.12 documentation
torch.optim is a package implementing various optimization algorithms.
Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can be also easily integrated in the
future.
>>> # EMA update rule: new_average = 0.1 * old_average + 0.9 * current_parameter
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: \
...     0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
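For context, here is a minimal sketch (my own example, not from the PyTorch documentation) of how such an ema_model would typically be updated inside a training loop. The model, optimizer, data, and step count are hypothetical placeholders; only AveragedModel and its update_parameters method come from torch.optim.swa_utils.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                                   # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # hypothetical optimizer
loss_fn = nn.MSELoss()

# Same averaging rule as the documentation example above:
# new_average = 0.1 * old_average + 0.9 * current_parameter
ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: \
    0.1 * averaged_model_parameter + 0.9 * model_parameter
ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

for step in range(100):                                    # hypothetical training loop
    x, y = torch.randn(32, 10), torch.randn(32, 1)         # fake data for illustration
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    ema_model.update_parameters(model)                     # fold the latest weights into the EMA copy

# At evaluation time, call ema_model directly; its forward pass uses the averaged weights.

Note that on the first call to update_parameters the averaged copy is simply initialized from the current weights, and the avg_fn is applied on subsequent calls. If the model contains BatchNorm layers, the same swa_utils module also provides torch.optim.swa_utils.update_bn for recomputing their statistics for the averaged model.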