Pytorch 모델 저장하기(torach.save(model, PATH)), 모델 불러오기(torch.load(PATH)), 불러온 모델 이어서 학습하기
2022. 9. 28. 01:04ㆍPytorch
torch.save(model, PATH)
# 모델 클래스는 어딘가에 반드시 선언되어 있어야 합니다
model = torch.load(PATH)
model.eval()
이는 단순히 저장된 모델을 불러와서 prediction을 하기 위함이다.
모델을 저장할때는 모델말고도 파라미터, optimizer, loss등을 함께 저장할 수 있는데, 이 경우 확장자는 .tar 이다.
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
이때 모델을 불러오기 전, 모델은 초기화시켜주어야 하고 optimizer 또한 초기화 시켜줘야한다.
그리고 초기화된 변수에 다시 Dictionary 구조로 저장된 모델 및 파라터를 할당한다.
이후 model.eval()를 통해 prediction을 수행하거나
model.train()을 통해 추가 학습을 진행할 수 있다.
출처 : https://tutorials.pytorch.kr/beginner/saving_loading_models.html