Pytorch 모델 저장하기(torach.save(model, PATH)), 모델 불러오기(torch.load(PATH)), 불러온 모델 이어서 학습하기

2022. 9. 28. 01:04Pytorch

torch.save(model, PATH)
# 모델 클래스는 어딘가에 반드시 선언되어 있어야 합니다
model = torch.load(PATH)
model.eval()

이는 단순히 저장된 모델을 불러와서 prediction을 하기 위함이다.

 

 

모델을 저장할때는 모델말고도 파라미터, optimizer, loss등을 함께 저장할 수 있는데, 이 경우 확장자는 .tar 이다.

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

이때 모델을 불러오기 전, 모델은 초기화시켜주어야 하고 optimizer 또한 초기화 시켜줘야한다.

그리고 초기화된 변수에 다시 Dictionary 구조로 저장된 모델 및 파라터를 할당한다.

 

이후 model.eval()를 통해 prediction을 수행하거나

model.train()을 통해 추가 학습을 진행할 수 있다.

 

출처 : https://tutorials.pytorch.kr/beginner/saving_loading_models.html

 

모델 저장하기 & 불러오기

Author: Matthew Inkawhich, 번역: 박정환,. 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다 읽는 것도 좋은 방법이지만, 필요한 사용 예의 코드만 참고하

tutorials.pytorch.kr

사용법 : https://stackoverflow.com/questions/49941426/attributeerror-collections-ordereddict-object-has-no-attribute-eval

 

AttributeError: 'collections.OrderedDict' object has no attribute 'eval'

I have a model file which looks like this OrderedDict([('inp.conv1.conv.weight', (0 ,0 ,0 ,.,.) = -1.5073e-01 6.4760e-02 1.9156e-01 1.2175e-01 3.5886e-02 1.39...

stackoverflow.com