This recipe provides options to save and reload an entire model or just the parameters of the model. While reloading this recipe copies the parameter from 1 net to another net. There are 3 main functions involved in saving and loading a model in pytorch.
1. torch.save: This saves a serialized object to disk. It uses python's pickle utility for serialization. Models, tensors and dictionaries can be saved using this function.
2. torch.load: torch.load: Uses pickle's unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into.
3. torch.nn.Module.load_state_dict: Loads a model's parameter dictionary using a deserialized state_dict. The learnable parameters (i.e. weights and biases) of an torch.nn.Module model are contained in the model's parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor.
What is PyTorch ?
Pytorch is a Python-based scientific computing package that uses the power of graphics processing units and can replace the numpy library. It is also a very popular deep learning research platform built for flexibility and speed. You can use other Python packages such as NumPy, SciPy to extend PyTorch functionalities.
What is Deep Learning Model ?
Deep learning is a subset of machine learning. Deep learning uses neural networks to make predictions. A neural network takes inputs, which are then processed using hidden layers using weights that are adjusted during training. The model then outputs a prediction.