PyTorch Lightning with OneCycleLR Scheduler

PyTorch Lightning 將許多訓練過程中的細節封裝起來,讓使用者可以專注在模型的設計與訓練上。而在訓練過程中,我們常常會使用到一些學習率調整的方法,例如:OneCycleLR Scheduler。然而,當我們在 PyTorch Lightning 中使用 OneCycleLR Scheduler 時,會遇到一個尷尬的問題。

假設我們有一個簡單的模型,其 main.py 如下:

def main():
    cli = LightningCLI(aModel, aDataModule)

if __name__ == '__main__':
    main()

而我的模型裡定義了一個 scheduler:

class MLP(L.LightningModule):
    ...
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), self.hparams.lr)
        return {
            "optimizer": optimizer,
            "lr_scheduler": optim.lr_scheduler.OneCycleLR(
                optimizer, 
                max_lr=self.hparams.lr, 
                steps_per_epoch=len(train_loader)),
        }

問題來了,這 train_loader 要從哪裡獲得呢?
main.py 中,已經將模型和資料模組封裝在 LightningCLI 中,兩者透過 LightningCLI 獲取參數,因此無法直接在模型取得 train_loader。當然,你也可以不要用 LightningCLI,但是我就是想用呀!!!
欸嘿,別急,PyTorch Lightning 也有發現這個問題,只是它的解法藏的比較深而已,在 Optimization 文件中有提到,Trainer 有個屬性叫做 estimated_stepping_batches,它會估算訓練時,呼叫 optimizer.step() 的次數,包含考慮 gradient accumulation,因此,只要將此參數指定到 OneCycleLRtotal_steps 參數即可。

def configure_optimizers(self):
    optimizer = optim.AdamW(self.parameters(), self.hparams.lr)
    return {
        "optimizer": optimizer,
        "lr_scheduler": optim.lr_scheduler.OneCycleLR(
            optimizer, 
            max_lr=self.hparams.lr, 
            total_steps=self.trainer.estimated_stepping_batches),
    }

References