如何在 PyTorch Lightning 共享 DataModule 和 Model 之間的參數

PyTorch Lightning 將許多訓練過程中的細節封裝起來,讓使用者可以專注在模型的設計與訓練上。而尤其搭配 LightningCLI 使用時,更是方便。然而,當我們在 PyTorch Lightning 中,想要於多個模組間共享參數時,你會發現似乎做不到?

前一篇文章中,我們提到了在使用 OneCycleLR Scheduler 時,如何取得 dataloader 的長度,但是不可能 PyTorchLightning 都只能透過這種特定的參數去取得吧?。因此,這次我們將討論更廣泛的問題,即如何在 LightningCLI 中,共享參數。

假設我們有一個簡單的模型,其 main.py 如下:

def main():
    cli = LightningCLI(aModel, aDataModule)

if __name__ == '__main__':
    main()

而我的模型裡定義了一個 num_classes 參數,而這個參數需要透過 DataModule 獲得:

class MyModel(L.LightningModule):
    def __init__(
        self, 
        lr: float, 
        backbone: str = "resnet18",
        num_classes: int | None = None
    ):
        super().__init__()
        self.save_hyperparameters()

        model = create_model(backbone, num_classes)
        ...

class MyDataModule(L.LightningDataModule):
    def __init__(
        self, 
        data_path: str | Path,
        num_classes: int | None = None,
    ):
        super().__init__()

        self.num_classes = load_data(data_path)
        self.save_hyperparameters()

問題來了,我要怎麼把 num_classesDataModule 傳到 Model 呢? 在 main.py 中,已經將模型和資料模組封裝在 LightningCLI 中,兩者透過 LightningCLI 獲取參數,因此無法直接在模型取得 MyDataModule 的資料。
所以,PyTorch Lightning 提供了共享參數的方法,首先要修改 main.py 中的 LightningCLI,添加 add_arguments_to_parser 函式。此外,加上 parser.link_arguments("data.ARG", "model.ARG", apply_on="instantiate"),這行程式的意義在於將兩個參數連結起來,並且指定在 datamodule 初始後再同步:

class MyCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

def main():
    cli = MyCLI(
        MyModel, 
        MyDataModule
    )

if __name__ == '__main__':
    main()

然而,如果你按照 PyTorch Lightning 的官方文件,只做了這樣的修改,你可能會發現訓練時,噴了一堆錯誤:

RuntimeError: Error while merging hparams: the keys ['num_classes'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.

原因未知,可能是一種 bug 吧,又或是兩邊都儲存 hparams 可能導致同步出錯,總之,問題出在 save_hyperparameters() 上。因此,我們需要在 MyModel 中,將 num_classessave_hyperparameters() 中移除:

class MyModel(L.LightningModule):
    def __init__(
        self, 
        lr: float, 
        backbone: str = "resnet18",
        num_classes: int | None = None
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["num_classes"])

        model = create_model(backbone, num_classes)
        ...

這樣,就可以在 MyModel MyDataModule 之間共享參數了。

References