MACHINE LEARNING/PYTORCH

Pytorch 모델 저장하기 & 불러오기

24_bean 2022. 8. 5. 13:54

모델을 저장하고 불러올 때는 3가지의 핵심 함수가 있습니다.

 

1. torch.save() : 직렬화된 객체를 디스크에 저장합니다. 이 함수는 Python 의 pickle을 사용하여 직렬화하고 객체를 저장합니다.

2. torch.load() : 객체 파일을 역직렬화하여 메모리에 올립니다. 이 함수는 데이터를 장치에 불러올 때도 사용합니다.

3. torch.nn.Module.load_state_dict : 역직렬화된 state_dict를 사용하여 모델의 매개변수를 불러옵니다.


state_dict

이름에서도 알 수 있듯이 해당 모델의 state(상태)를 dict(딕셔너리) 형태로 가지고 있는 객체입니다.

만약 예를 든다면 다음과 같이 나올 수 있습니다.

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

저장하기 :

torch.save(model.state_dict(), PATH)
# .pt 혹은 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.

 

불러오기 : 

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

전체 모델을 저장하거나 불러올 수도 있습니다.

저장하기 :

torch.save(model, PATH)

불러오기 : 

# 모델 클래스는 어딘가에 반드시 선언되어 있어야 합니다
model = torch.load(PATH)
model.eval()

일반 체크포인트(checkpoint) 를 저장하거나 불러올 수도 있습니다.

저장하기 :

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
            
# 일반적으로 checkpoint를 저장할 때는 .tar 확장자를 사용합니다.

불러오기 : 

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

일반적으로, 이런 체크포인트는 종종 모델만 저장하는 것보다 2~3배 커지게 됩니다.


여기까지는 기본적으로 모델을 저장하고 불러오는 방법에 대해 정리했습니다.

 

여러개의 모델을 하나의 파일에 저장하거나, 

다른 모델의 매개변수를 가져와 사용하거나(warmup),

장치 간 모델을 주고받는 경우 등 다양한 사용법이 추가적으로 있으니 조금 더 자세히 알고싶다면 레퍼런스에 있는 주소를 따라가주시기 바랍니다.


Reference : pytorch 공식 문서 https://pytorch.org/tutorials/beginner/saving_loading_models.html