모델을 저장하고 불러올 때는 3가지의 핵심 함수가 있습니다.
1. torch.save() : 직렬화된 객체를 디스크에 저장합니다. 이 함수는 Python 의 pickle을 사용하여 직렬화하고 객체를 저장합니다.
2. torch.load() : 객체 파일을 역직렬화하여 메모리에 올립니다. 이 함수는 데이터를 장치에 불러올 때도 사용합니다.
3. torch.nn.Module.load_state_dict : 역직렬화된 state_dict를 사용하여 모델의 매개변수를 불러옵니다.
state_dict
이름에서도 알 수 있듯이 해당 모델의 state(상태)를 dict(딕셔너리) 형태로 가지고 있는 객체입니다.
만약 예를 든다면 다음과 같이 나올 수 있습니다.
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
저장하기 :
torch.save(model.state_dict(), PATH)
# .pt 혹은 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.
불러오기 :
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
전체 모델을 저장하거나 불러올 수도 있습니다.
저장하기 :
torch.save(model, PATH)
불러오기 :
# 모델 클래스는 어딘가에 반드시 선언되어 있어야 합니다
model = torch.load(PATH)
model.eval()
일반 체크포인트(checkpoint) 를 저장하거나 불러올 수도 있습니다.
저장하기 :
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
# 일반적으로 checkpoint를 저장할 때는 .tar 확장자를 사용합니다.
불러오기 :
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
일반적으로, 이런 체크포인트는 종종 모델만 저장하는 것보다 2~3배 커지게 됩니다.
여기까지는 기본적으로 모델을 저장하고 불러오는 방법에 대해 정리했습니다.
여러개의 모델을 하나의 파일에 저장하거나,
다른 모델의 매개변수를 가져와 사용하거나(warmup),
장치 간 모델을 주고받는 경우 등 다양한 사용법이 추가적으로 있으니 조금 더 자세히 알고싶다면 레퍼런스에 있는 주소를 따라가주시기 바랍니다.
Reference : pytorch 공식 문서 https://pytorch.org/tutorials/beginner/saving_loading_models.html
반응형