일단 전이 학습에 대한 한글 자료는 네이버에 많다. 여기서 전이 학습을 찾아보면 십중 팔구 PyTorch의 ants와 bees를 구별하는 글을 번역한 것을 볼 수 있다. 실제로 자기 컴퓨터에 맞게 수정한 글들도 있지만, 거의 복불 같은 느낌이다.
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
일단 이 코드를 기반으로 DenseNet에 대하여 해보면 오류가 발생한다. 처음에 한글로 된 자료를 찾지 못하여 영어 자료를 찾다가 다음의 글에서 해결책을 찾을 수 있었다. 그 전에 무엇이 문제였는지부터 알아보자.
https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
우선 ResNet 모델을 보면, 마지막 부분에 (fc)라는 부분이 있을 것이다.
import torchvision
model_ft = torchvision.models.resnet18(pretrained=True)
print(model_ft)
----
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
그래서 ResNet 전이학습에서는 이 부분을 연결하기 위해서 다음의 코드가 들어간다.
num_ftrs = model_ft.fc.in_features
그럼 DenseNet를 확인해 본다. 마지막에 classifier라고 되어 있는 것을 볼 수 있다.
import torchvision
model_ft = torchvision.models.densenet161(pretrained=True)
----
)
(norm5): BatchNorm2d(2208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(classifier): Linear(in_features=2208, out_features=1000, bias=True)
)
그렇다면 classifier가 추가되어야 할 것이다. 위에서 언급한 링크에는 다음과 같이 예시되어 있다. 인터넷 자료에는 다른 classifier가 있기는 한데, 일단 공식(?) 자료는 간단하게 다음의 한 줄을 예로 들어 놓았다.
model.classifier = nn.Linear(1024, num_classes)
이제 실행해보면 될 것이다.