pytorchでレイヤーをフリーズさせる
画像分野で深層学習を用いる際はImageNetのような大規模データセットで学習した学習済みモデルを使用して転移学習、ファインチューニングを行うことが一般的です。
このファインチューニングを行う際に、最初の数エポックはモデルの最終層以外の重みを固定し最終層だけ学習、その後に全レイヤーの重みを学習するといったテクニックがあります。
一般的にニューラルネットワークの入力層に近いほど学習データの抽象的な特徴を学習していると言われていますが、
このテクニックを使うことにより、学習初期にいきなり入力層に近いレイヤーの重みが更新されることで抽象的な特徴抽出機能が崩壊することを防ぐ効果があります。
pytorchでは以下のようにパラメータをrequires_grad=False
することによってbackward()
の際に重みが更新されないようにする(freeze)ことができます。
from torchvision.models import resnet34 model = resnet34(pretrained=True) # freeze all layers for param in model.parameters(): param.requires_grad = False
上記の例だと最終層も含めてすべてのレイヤーがfreezeされてしまうので最終層はrequires_grad=True
のままにしておきます。
from torchvision.models import resnet34 model = resnet34(pretrained=True) # freeze layers except last layer for param in model.parameters(): param.requires_grad = False last_layer = list(model.children())[-1] print(f'except last layer: {last_layer}') for param in last_layer.parameters(): param.requires_grad = True
このようにしてやることで最終層だけ除いてレイヤーの重みを固定することができます。
この状態で数エポック学習を回して最終的に全レイヤーを学習する際はrequires_grad=True
に戻してあげればOKです。
from torchvision.models import resnet34 model = resnet34(pretrained=True) # unfreeze all layers for param in model.parameters(): param.requires_grad = True