We will cover the following tasks in 48 minutes:
1. Importing the Data
2. Creating and Normalizing a Training Set
3. Creating the Model
4. The Model Checkpoint
5. Saving Only the Weights
6. Saving the Model After Training
Before we get to the project, let’s understand why we may need to save and restore our models. The most obvious reason is deployment: if we want to use a trained model in production, perhaps in a web or mobile app, we need to save the model after training it locally and then restore it on a server. Another reason is long training times: you may want to train a model for a while, save it, and come back the next day to restore it and continue training from where you left off. Finally, you may have trained a model that you want another developer or researcher to take a look at, so you need a convenient way to send them the architecture along with the trained weights.
Importing the Data
In this task we are going to import the data we will be working with. Instead of downloading the data to disk before starting the project, I decided it would be worth looking at how to download the data from a URL. Most of the time, you will have to download data from somewhere, so this is a small exercise towards that.
We will be working with the popular Iris dataset, created by R. A. Fisher, a very influential statistician. It has 4 features and 3 classes of 50 instances each, where each class refers to a type of iris plant.
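As a small illustration, here is a minimal sketch of loading the dataset directly from a URL with pandas. The UCI archive URL and the column names are assumptions for this sketch, not files shipped with the project.

    import pandas as pd

    # The UCI archive hosts Iris as a headerless CSV (URL is an assumption)
    URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
    # Column names follow the standard dataset description
    cols = ['sepal_length', 'sepal_width', 'petal_length',
            'petal_width', 'species']

    # pandas can read directly from a URL, so no manual download is needed
    df = pd.read_csv(URL, names=cols)
    print(df.head())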
Creating and Normalizing a Training Set
In this task, we will create a training set: the examples in one NumPy array and the labels in another. The first four columns are features and the last one is our label. You can get the NumPy arrays from a pandas data frame by simply accessing its values attribute. But before that, we should do a couple of things:
1. Shuffle the data so that the examples are not ordered by class. In the dataset, by default, the first 50 examples belong to one class, the next 50 to the second class, and the last 50 to the third class, so we want to shuffle this order.
2. Separate the features from the labels. The last column is the label and the first four are the features. This is easy to do in pandas, as the sketch below shows.
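Here is a minimal sketch of these two steps, plus the normalization the task title mentions. The random seed and the label-to-integer mapping are assumptions for illustration.

    import numpy as np

    # 1. Shuffle: sample(frac=1) returns all rows in random order
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)

    # Map string labels to integer class ids (mapping is an assumption)
    df['species'] = df['species'].map({
        'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2
    })

    # 2. Split: the first four columns are features, the last is the label
    x_train = df.iloc[:, :4].values.astype(np.float32)
    y_train = df.iloc[:, 4].values

    # Normalize each feature to zero mean and unit variance
    x_train = (x_train - x_train.mean(axis=0)) / x_train.std(axis=0)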
Creating the Model
In this task, we create a function that returns a model. This model is going to be fairly small and of a fixed architecture: just 2 hidden layers, each with only 8 hidden units. We will use the relu activation function for the hidden layers and softmax activation for the output layer to get the probability scores for the three classes.
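A minimal sketch of such a function, assuming tf.keras. The optimizer and loss choices are assumptions, since the text only specifies the architecture and activations.

    import tensorflow as tf

    def create_model():
        # Two hidden layers of 8 units each, softmax over the 3 classes
        model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
            tf.keras.layers.Dense(8, activation='relu'),
            tf.keras.layers.Dense(3, activation='softmax')
        ])
        # Labels are integer ids, so sparse categorical cross-entropy
        # (loss and optimizer are assumptions, not from the text)
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        return model

    model = create_model()
    model.summary()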
The Model Checkpoint
Finally, we get to the part where we save our models! The first method we will look at is a Keras callback called ModelCheckpoint. This callback is used to save checkpoints during training and at the end of training. By default, it saves the model after every epoch to the given file path, which may be more often than necessary, perhaps even in most cases. You can pass the period argument to set how many epochs need to complete before the model is saved.
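A minimal usage sketch, assuming the model and training arrays from the earlier tasks; the file path, epoch count, and period value are illustrative. Note that in recent TensorFlow releases the period argument is deprecated in favour of save_freq.

    import os
    import tensorflow as tf

    os.makedirs('checkpoints', exist_ok=True)

    # Save a full checkpoint every 2 epochs instead of every epoch;
    # {epoch:02d} in the path is filled in by the callback
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        'checkpoints/epoch_{epoch:02d}.h5',
        period=2
    )

    model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint])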
Saving Only the Weights
You don’t have to save the entire model using ModelCheckpoint. You can save just the trained weights for the checkpoints instead, which saves some space since you won’t need to store the model architecture along with the weights. For this to work when you restore the weights, however, you will need to instantiate a model of the same architecture first. Let’s see how to save just the weights: all we have to do is set the save_weights_only argument to True.
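A sketch of the weights-only variant, assuming the same create_model function from before; the file names are illustrative.

    # Store only the weights at each checkpoint
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        'checkpoints/weights_{epoch:02d}.h5',
        save_weights_only=True
    )
    model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint])

    # Restoring: rebuild the same architecture, then load the weights
    restored = create_model()
    restored.load_weights('checkpoints/weights_10.h5')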
Saving the Model After Training
So far, we have seen how to use the ModelCheckpoint callback to save models during training. What if you wanted to save a model after training? This is again quite straightforward. Let’s take a look!
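A minimal sketch of saving and restoring after training; the file names are illustrative, and the HDF5 format shown is one of the formats tf.keras supports.

    # Save the entire model: architecture, weights and optimizer state
    model.save('trained_model.h5')

    # Later (or on another machine), restore it in a single call
    loaded = tf.keras.models.load_model('trained_model.h5')

    # Alternatively, save just the weights after training
    model.save_weights('trained_weights.h5')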
About the Host (Amit Yadav)
I have been writing code since 1993, when I was 11, and my first passion project was a piece of database management software that I wrote for a local hospital. More recently, I wrote an award-winning education chatbot for a multi-billion-revenue company. I solved a recurring problem for my client, who wanted to make basic cyber safety and privacy education accessible to their users; this bot enabled my client to reach their customers with personalised, real-time education. Over the last year, I’ve continued my interest in this field by constantly learning and growing in Machine Learning, NLP and Deep Learning. I'm very excited to share this variety of experience and learning with you through Rhyme.com.