【PyTorch】How to Save an ML Model
Here, I'll explain how to save a machine learning model in PyTorch and how to load it again.
1. Create Model
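First, create a model. Here, we load a pretrained ResNet-50 from the timm library.
import timm
import torch
# Load a pretrained ResNet-50 from timm (weights are downloaded on first use)
model = timm.create_model('resnet50', pretrained=True)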
2. Save
PyTorch offers two ways to save a model:
- Saving the Model State Dictionary
- Saving the Entire Model
Let's look at each one.
2.1 Saving the Model State Dictionary
This saves only the model's learned parameters (weights and biases), along with buffers such as BatchNorm running statistics. It is useful for fine-tuning or resuming training:
torch.save(model.state_dict(), 'model_weights.pth') # Saves only the model's parameters
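When resuming training, the model's state dictionary alone is usually not enough, because the optimizer (for example, Adam's moment estimates) also carries state. A common pattern is to bundle both state dictionaries plus an epoch counter into one dict. The sketch below assumes a small stand-in model instead of the timm ResNet-50 (to avoid a download); the file name `checkpoint.pth` and the dictionary key names are arbitrary choices for illustration, not a PyTorch API.

```python
import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # Hypothetical stand-in for timm.create_model('resnet50')
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))

model = build_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so the optimizer has internal state worth saving
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Bundle everything needed to resume; the key names are only a convention
checkpoint = {
    "epoch": 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")

# Later: rebuild the model and optimizer, then restore their states
model2 = build_model()
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
ckpt = torch.load("checkpoint.pth", weights_only=True)
model2.load_state_dict(ckpt["model_state_dict"])
optimizer2.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
model2.train()  # back to training mode before continuing
```

Because the checkpoint contains only tensors and plain containers, it can be loaded with `weights_only=True`.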
2.2 Saving the Entire Model
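This pickles the whole model object, so you can load it as is without redefining its architecture. Note, however, that pickle stores a reference to the model's class rather than its source code, so the same class definition (here, from timm) must still be importable when you load the file: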
torch.save(model, 'entire_model.pth') # Saves the entire model
3. Load
3.1 State Dictionary
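We must first re-create the model architecture, then load the saved weights into it.
new_model = timm.create_model('resnet50')  # Re-create the same architecture (randomly initialized weights)
new_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))  # Load the saved parameters
new_model.eval()  # Switch to inference mode before making predictions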
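To check that the weights were restored correctly, you can compare the outputs of the original and reloaded models on the same input. The sketch below assumes a small `nn.Sequential` in place of the timm ResNet-50 so it runs without downloading pretrained weights; `build_model` is a hypothetical helper. Both models are put in `eval()` mode first, so Dropout is disabled and BatchNorm uses its running statistics, which makes the comparison deterministic.

```python
import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # Hypothetical stand-in for timm.create_model('resnet50')
    return nn.Sequential(
        nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2)
    )

model = build_model()
torch.save(model.state_dict(), "model_weights.pth")

# Re-create the architecture, then load the saved parameters into it
new_model = build_model()
new_model.load_state_dict(torch.load("model_weights.pth", weights_only=True))

# eval() disables dropout and makes BatchNorm use its running statistics
model.eval()
new_model.eval()
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), new_model(x))
print(same)  # True
```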
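3.2 Entire Model
We can load the model as is, without redefining its architecture. Because this unpickles arbitrary Python objects, recent PyTorch versions (2.6 and later) require passing weights_only=False explicitly; only do this with files you trust.
loaded_model = torch.load('entire_model.pth', weights_only=False)  # Loads the whole model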
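A minimal round trip for the entire-model approach is sketched below, with a small `nn.Sequential` standing in for ResNet-50 (an assumption to avoid downloading weights). It also confirms that the loaded object is a ready-to-use module that produces the same outputs as the original.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the timm ResNet-50; its classes must be importable at load time
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
torch.save(model, "entire_model.pth")

# weights_only=False permits full unpickling (needed for a whole nn.Module); trusted files only
loaded_model = torch.load("entire_model.pth", weights_only=False)
loaded_model.eval()  # switch to inference mode before use

x = torch.randn(3, 4)
with torch.no_grad():
    print(torch.allclose(model(x), loaded_model(x)))  # True
```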
4. Which Should You Use?
4.1 Flexibility
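Saving the state dictionary is more robust across different versions of PyTorch and of your own code, because the file holds only named tensors rather than pickled references to your model classes. It also allows more flexibility when you modify the model architecture, for example when reusing pretrained weights with a new classifier head.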
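As an illustration of that flexibility, the sketch below loads saved weights into a model whose classifier head has a different number of outputs. `load_state_dict(..., strict=False)` tolerates missing or unexpected keys, but it still raises an error on shape mismatches, so the sketch first drops entries whose shapes differ. The small backbone-plus-head model and the `build_model` helper are hypothetical stand-ins for a real network such as ResNet-50.

```python
import torch
import torch.nn as nn

def build_model(num_classes: int) -> nn.Module:
    # Hypothetical backbone + classifier head; only the head depends on num_classes
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, num_classes))

old_model = build_model(num_classes=10)
saved = old_model.state_dict()

new_model = build_model(num_classes=3)  # different output size
target = new_model.state_dict()

# Keep only tensors whose names and shapes match the new architecture
compatible = {k: v for k, v in saved.items() if k in target and v.shape == target[k].shape}
result = new_model.load_state_dict(compatible, strict=False)

print(sorted(compatible))           # ['0.bias', '0.weight']
print(sorted(result.missing_keys))  # ['2.bias', '2.weight']
```

The backbone layer receives the saved weights, while the new head keeps its fresh initialization and can then be fine-tuned.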
4.2 Security
Saving the entire model is simpler, since you don't need to redefine the model architecture when loading it. However, loading it relies on Python's pickle module to reconstruct arbitrary objects, and unpickling a file from an untrusted source can execute arbitrary code.
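One practical safeguard is `torch.load(..., weights_only=True)`, which only unpickles tensors, primitive types, and standard containers. The sketch below, using a plain `nn.Linear` as a stand-in model and hypothetical file names, shows that this mode accepts a state dictionary but rejects a pickled module. The exact exception class may differ between PyTorch versions, so the sketch catches `Exception` broadly.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "weights.pth")
torch.save(model, "full.pth")

# A state dictionary holds only named tensors, so the restricted loader accepts it
state = torch.load("weights.pth", weights_only=True)
print(sorted(state))  # ['bias', 'weight']

# A whole nn.Module needs arbitrary classes to be unpickled, so it is refused
try:
    torch.load("full.pth", weights_only=True)
    rejected = False
except Exception:  # typically pickle.UnpicklingError
    rejected = True
print(rejected)
```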
4.3 Conclusion
Unless you have a specific reason to save the entire model, it's generally recommended to save only the state dictionary. This approach is safer and more flexible, especially for larger projects or when collaborating with others where model definitions might change.
5. Summary
In this article, I explained how to save and load a PyTorch model.
Thank you for reading.