Freezing models in Tensorflow

This is not my original post; I have made some minor updates to it for the latest version.

Freeze Tensorflow models and serve on web

In this tutorial, we shall learn how to freeze a trained Tensorflow model and serve it on a webserver. You can do this for any network you have trained, but we shall use the dog/cat classification model trained in an earlier tutorial and serve it on a Python Flask webserver.
So you trained a new model using Tensorflow and now you want to show it off to your friends and colleagues on a website for a hackathon or a new startup idea. Let’s see how to achieve that.

1. What is freezing a Tensorflow model?

Neural networks are computationally very expensive. Let's look at the architecture of AlexNet, which is a relatively simple neural network:
[Figure: AlexNet network architecture]
Let’s calculate the number of variables required for prediction: 
         Layer   Weights              Biases   Parameters
         conv1   (11*11)*3*96         96            34944
         conv2   (5*5)*96*256         256          614656
         conv3   (3*3)*256*384        384          885120
         conv4   (3*3)*384*384        384         1327488
         conv5   (3*3)*384*256        256          884992
         fc1     (6*6)*256*4096       4096       37752832
         fc2     4096*4096            4096       16781312
         fc3     4096*1000            1000        4097000
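The arithmetic in the table can be checked with a few lines of plain Python. This is a small sketch: each conv layer contributes (kernel_h * kernel_w * in_channels) * out_channels weights plus one bias per output channel, and each fully connected layer contributes in_features * out_features weights plus out_features biases. The float32 size estimate assumes 4 bytes per parameter.

```python
# Recompute the AlexNet parameter counts from the table above.
LAYERS = {
    # name: (number of weights, number of biases)
    "conv1": (11 * 11 * 3 * 96, 96),
    "conv2": (5 * 5 * 96 * 256, 256),
    "conv3": (3 * 3 * 256 * 384, 384),
    "conv4": (3 * 3 * 384 * 384, 384),
    "conv5": (3 * 3 * 384 * 256, 256),
    "fc1": (6 * 6 * 256 * 4096, 4096),
    "fc2": (4096 * 4096, 4096),
    "fc3": (4096 * 1000, 1000),
}


def parameter_counts(layers):
    """Return {layer_name: weights + biases} for every layer."""
    return {name: w + b for name, (w, b) in layers.items()}


counts = parameter_counts(LAYERS)
total = sum(counts.values())

for name, n in counts.items():
    print(f"{name:>5}: {n:>10,}")
print(f"total: {total:>10,}")
# Every parameter stored as float32 takes 4 bytes.
print(f"float32 weights: {total * 4 / 1e6:.1f} MB")
```

The total comes to 62,378,344 parameters, roughly 250 MB of float32 values, before any training-only state is counted.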
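That is more than 60 million parameters (about 62.4 million) needed to compute a prediction for a single image. Training adds more on top of that: the graph also holds the gradient and optimizer operations used for backpropagation, and depending on the optimizer the checkpoint can store extra per-weight state (Adam, for example, keeps two moment estimates for every weight). None of this is needed when you deploy your model on a webserver, so why carry all this load? Freezing is the process of identifying everything required for inference (the graph structure and the trained weights), converting the weights into constants inside the graph, and saving it all in a single file that is easy to use. By contrast, a model saved during training with tf.train.Saver is spread over four files: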
  1. model-ckpt.meta: This contains the complete graph as a serialized MetaGraphDef protocol buffer. It includes the GraphDef that describes the data flow, annotations for variables, input pipelines and other relevant information.
  2. model-ckpt.data-00000-of-00001: This contains the values of all saved variables (weights, biases, optimizer state, the global step, etc.).
  3. model-ckpt.index: Metadata. It's an immutable table (tensorflow::table::Table) in which each key is the name of a Tensor and its value is a serialized BundleEntryProto describing that Tensor's metadata.
  4. checkpoint: A small text file that records the path of the most recent checkpoint (and of other recent ones) in the directory.
So, in summary, when deploying to a webserver we want to get rid of unnecessary metadata, gradient operations and training-only variables, and encapsulate everything else in a single file. This single file (with a .pb extension) is called a "frozen graph def". It's essentially a serialized GraphDef protocol buffer, with the variables replaced by constants, written to disk.
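To see these checkpoint files for yourself, the sketch below saves a small hypothetical model (input "x", output "y_pred", plus a gradient-descent training op) with a Saver using the prefix model-ckpt, then lists the directory and counts the gradient ops in the graph. The toy model and its names are assumptions standing in for the dog/cat classifier; the tf.compat.v1 calls are used so that the sketch also runs on TensorFlow 2.x.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical toy model standing in for the trained classifier.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 4], name="x")
    y_true = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="y_true")
    w = tf.Variable(tf.random.normal([4, 2]), name="weights")
    b = tf.Variable(tf.zeros([2]), name="biases")
    y_pred = tf.identity(tf.nn.softmax(tf.matmul(x, w) + b), name="y_pred")
    loss = tf.reduce_mean(tf.square(y_pred - y_true))
    # The optimizer adds gradient ops ("gradients/...") to the graph.
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)
    init = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()

ckpt_dir = tempfile.mkdtemp()
with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    # Writes model-ckpt.meta, .index, .data-* and the "checkpoint" file.
    saver.save(sess, os.path.join(ckpt_dir, "model-ckpt"))

print(sorted(os.listdir(ckpt_dir)))

grad_ops = [n.name for n in graph.as_graph_def().node
            if n.name.startswith("gradients/")]
print(f"{len(grad_ops)} gradient ops are stored in the meta graph")
```

The directory listing should show the four files described above, and the gradient ops are exactly the kind of training-only baggage that freezing removes.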
In the next section, we shall learn how we can freeze the trained model. 

2. Freezing the graph: 

We have a trained model and we want to selectively choose and save the variables we will need for inference. You can download the model from here. Here are the steps to do this:
  1. Restore the model: load the graph from the .meta file and restore the weights inside a session. Then convert the graph to a graph_def.
  2. Choose which outputs we want from the network. Most of the time you will only choose the prediction node, but you can choose several output nodes, in which case everything needed to compute any of them is kept in the frozen graph. In our case we only want y_pred, since we want the predictions.
  3. Use the convert_variables_to_constants function in graph_util, passing it the session, the graph_def and the names of the output nodes we want to keep. It replaces every variable those outputs depend on with a constant holding its current value, and drops the nodes the outputs don't need.
  4. Finally, serialize the output graph_def and write it to the file system.
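The four steps above can be sketched as a single function. This is a minimal sketch assuming a TensorFlow version that provides tf.compat.v1 (late 1.x or 2.x); the freeze_graph helper, the file names and the toy model used in the demo (input "x", output "y_pred") are hypothetical stand-ins for the trained dog/cat model.

```python
import os
import tempfile

import tensorflow as tf


def freeze_graph(ckpt_dir, output_node_names, output_path):
    """Restore the latest checkpoint in ckpt_dir and write a frozen .pb file."""
    ckpt_prefix = tf.compat.v1.train.latest_checkpoint(ckpt_dir)
    graph = tf.Graph()
    with graph.as_default():
        # Step 1: rebuild the graph from the .meta file.
        saver = tf.compat.v1.train.import_meta_graph(
            ckpt_prefix + ".meta", clear_devices=True)
    with tf.compat.v1.Session(graph=graph) as sess:
        # Step 1 (cont.): restore the weights and get the graph_def.
        saver.restore(sess, ckpt_prefix)
        input_graph_def = graph.as_graph_def()
        # Steps 2-3: keep only what the chosen outputs need and turn the
        # variables they read into constants.
        frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names)
    # Step 4: serialize the graph_def and write it to disk.
    with tf.io.gfile.GFile(output_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())
    return frozen_graph_def


# Demo: save a hypothetical toy model first so there is something to freeze.
ckpt_dir = tempfile.mkdtemp()
toy = tf.Graph()
with toy.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None, 4], name="x")
    y_true = tf.compat.v1.placeholder(tf.float32, [None, 2], name="y_true")
    w = tf.Variable(tf.random.normal([4, 2]), name="weights")
    y_pred = tf.identity(tf.nn.softmax(tf.matmul(x, w)), name="y_pred")
    loss = tf.reduce_mean(tf.square(y_pred - y_true))
    tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)
    init = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session(graph=toy) as sess:
    sess.run(init)
    saver.save(sess, os.path.join(ckpt_dir, "model-ckpt"))

frozen_path = os.path.join(ckpt_dir, "frozen_model.pb")
frozen = freeze_graph(ckpt_dir, ["y_pred"], frozen_path)
print(sorted({node.op for node in frozen.node}))
```

In the frozen graph the variable ops have become Const nodes, and the training placeholder y_true and the gradient ops are gone because y_pred does not depend on them.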
Look at the size of the model: it has shrunk significantly, from 25 MB to 8.2 MB.

3. Using the frozen Model:

Now, let’s see how we shall use this frozen model for prediction.
Step-1: Load the frozen file and parse it to get the deserialized graph_def.
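A minimal sketch of Step-1, assuming tf.compat.v1 is available: read the raw bytes of the .pb file and parse them into a GraphDef message. The load_graph_def name and the tiny demo graph (input "x", output "y_pred", weights already constants) are hypothetical.

```python
import os
import tempfile

import tensorflow as tf


def load_graph_def(frozen_path):
    """Read a frozen .pb file and parse it into a GraphDef."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(frozen_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


# Demo: write a tiny hypothetical frozen graph whose weights are constants.
demo = tf.Graph()
with demo.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None, 2], name="x")
    tf.identity(tf.matmul(x, tf.constant([[1.0], [2.0]])), name="y_pred")
out_dir = tempfile.mkdtemp()
tf.io.write_graph(demo.as_graph_def(), out_dir, "frozen_model.pb", as_text=False)

graph_def = load_graph_def(os.path.join(out_dir, "frozen_model.pb"))
print([node.name for node in graph_def.node])
```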

 
Step-2: Now we import the graph_def into a new graph using the tf.import_graph_def function.
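As a sketch of Step-2 (assuming tf.compat.v1, where the function is spelled tf.compat.v1.import_graph_def), the example below imports a graph_def into a fresh graph and looks its tensors up by name. An in-memory graph with hypothetical nodes "x" and "y_pred" stands in for the graph_def parsed in Step-1. With the default name argument, every imported node gets an "import/" prefix, so the tensors are found as "import/x:0" and "import/y_pred:0".

```python
import numpy as np
import tensorflow as tf

# Stand-in for the graph_def parsed in Step-1 (weights already constants).
src = tf.Graph()
with src.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf.constant([[1.0, 0.0], [0.0, 2.0]], name="weights")
    tf.identity(tf.matmul(x, w), name="y_pred")
graph_def = src.as_graph_def()

# Import into a new graph; node names get the default "import/" prefix.
graph = tf.Graph()
with graph.as_default():
    tf.compat.v1.import_graph_def(graph_def)

x_in = graph.get_tensor_by_name("import/x:0")
y_out = graph.get_tensor_by_name("import/y_pred:0")

with tf.compat.v1.Session(graph=graph) as sess:
    result = sess.run(y_out, feed_dict={x_in: np.array([[3.0, 4.0]], np.float32)})
print(result)
```

Here [3, 4] multiplied by the constant weight matrix gives [[3, 8]].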

 
tf.import_graph_def takes a few useful parameters:
input_map: A dictionary mapping input names in the restored graph_def to Tensors in the current graph, which then replace those inputs.
return_elements: A list of names of Tensors/Operations in the graph_def that import_graph_def should return. Tensor names (such as "y_pred:0") return Tensors; operation names (such as "y_pred") return Operations.
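The two parameters can be combined as in the following sketch (assuming tf.compat.v1). The source graph and the replacement tensor new_input are hypothetical: input_map wires new_input in place of the placeholder "x:0", return_elements hands back the output tensor directly, and name="frozen" sets the prefix of the imported nodes.

```python
import numpy as np
import tensorflow as tf

# Hypothetical frozen graph: y_pred = x * 10.
src = tf.Graph()
with src.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="x")
    tf.identity(x * 10.0, name="y_pred")
graph_def = src.as_graph_def()

graph = tf.Graph()
with graph.as_default():
    # A tensor from the new graph (e.g. a preprocessing step) that replaces "x".
    new_input = tf.constant([[1.0, 2.0]], name="new_input")
    # "y_pred:0" names a Tensor; passing "y_pred" would return the Operation.
    (y_pred,) = tf.compat.v1.import_graph_def(
        graph_def,
        input_map={"x:0": new_input},
        return_elements=["y_pred:0"],
        name="frozen",
    )

with tf.compat.v1.Session(graph=graph) as sess:
    # No feed_dict needed: "x" has been replaced by the constant new_input.
    result = sess.run(y_pred)
print(y_pred.name, result)
```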

 
Now the complete graph, with its values, has been restored, and we can use it to predict just as we did earlier. The complete code is available in our github repo.
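Putting Step-1 and Step-2 together, an end-to-end prediction could look like the sketch below. It assumes tf.compat.v1, an input placeholder named "x" (only y_pred is named in this post, so "x" is an assumption), and images that are already resized and normalized the same way as during training. The demo model is a hypothetical two-class stand-in written to a .pb file.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def load_frozen_graph(frozen_path, prefix="frozen"):
    """Step-1 and Step-2: parse the .pb file and import it into a new graph."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(frozen_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def, name=prefix)
    return graph


def predict(frozen_path, images, input_name="x:0", output_name="y_pred:0"):
    """Run the frozen model on a float32 batch shaped like its input."""
    prefix = "frozen"
    graph = load_frozen_graph(frozen_path, prefix)
    x = graph.get_tensor_by_name(f"{prefix}/{input_name}")
    y_pred = graph.get_tensor_by_name(f"{prefix}/{output_name}")
    with tf.compat.v1.Session(graph=graph) as sess:
        return sess.run(y_pred, feed_dict={x: images})


# Demo: a hypothetical 2-class model with constant weights, saved as .pb.
demo = tf.Graph()
with demo.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None, 8, 8, 3], name="x")
    features = tf.reduce_mean(x, axis=[1, 2])  # mean colour of each image
    weights = tf.constant([[1.0, -1.0], [0.5, 0.5], [-1.0, 1.0]])
    tf.identity(tf.nn.softmax(tf.matmul(features, weights)), name="y_pred")
pb_dir = tempfile.mkdtemp()
tf.io.write_graph(demo.as_graph_def(), pb_dir, "frozen_model.pb", as_text=False)

images = np.random.rand(5, 8, 8, 3).astype(np.float32)
probs = predict(os.path.join(pb_dir, "frozen_model.pb"), images)
print(probs.argmax(axis=1))  # index of the most likely class per image
```

This version reloads the graph on every call, which is fine for a script but not for a webserver.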

 

4. Deploying to a webserver: 

Finally, we shall deploy this model to a Python webserver. The webserver will allow a user to upload an image and will then generate a prediction, i.e. whether it's a dog or a cat. We use the Flask webserver, which can be installed with pip: pip install flask

Now we put all the code discussed above into a webapp.py file, after creating a Flask app. Let's first create the Flask app:
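A minimal sketch of the start of webapp.py: the Flask application object plus an index page with an upload form. The form's markup, the "/predict" target and the field name "file" are assumptions; whatever names you pick, the prediction endpoint has to use the same ones.

```python
from flask import Flask

app = Flask(__name__)

# Hypothetical upload form; the field name "file" must match the endpoint.
UPLOAD_FORM = """<!doctype html>
<title>Dog or cat?</title>
<form method="post" action="/predict" enctype="multipart/form-data">
  <input type="file" name="file">
  <input type="submit" value="Upload">
</form>
"""


@app.route("/")
def index():
    """Serve the upload form."""
    return UPLOAD_FORM
```

With Flask installed, a file like this can be started with FLASK_APP=webapp.py flask run.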

The frozen weights are 8 MB in our case. If the webserver had to load the weights for every request before running inference, each request would take much longer. So we create the graph, load the weights and keep them in memory, so that we can serve each incoming request quickly.

The loading function is called once, at the start of the webserver, so that loading happens only once and the model is accessible to every request.
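One way to load the model once and keep it in memory is a small wrapper that imports the frozen graph and holds a Session open, created at module level when webapp.py starts. This is a sketch assuming tf.compat.v1; the FrozenModel class, the tensor names "x:0"/"y_pred:0" and the demo graph are hypothetical. In webapp.py the path would point at your own frozen dog/cat model instead of the demo file.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class FrozenModel:
    """Loads a frozen graph once and keeps one Session open for all requests."""

    def __init__(self, frozen_path, input_name="x:0", output_name="y_pred:0"):
        graph_def = tf.compat.v1.GraphDef()
        with tf.io.gfile.GFile(frozen_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        self.graph = tf.Graph()
        with self.graph.as_default():
            # name="" keeps the original node names (no "import/" prefix).
            self.x, self.y_pred = tf.compat.v1.import_graph_def(
                graph_def, return_elements=[input_name, output_name], name="")
        # The session, and the constant weights in its graph, stay in memory.
        self.sess = tf.compat.v1.Session(graph=self.graph)

    def predict(self, images):
        """Run inference on a float32 batch; reuses the open session."""
        return self.sess.run(self.y_pred, feed_dict={self.x: images})


# Demo: a hypothetical 2-class frozen graph standing in for the real model.
demo = tf.Graph()
with demo.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None, 8, 8, 3], name="x")
    logits = tf.matmul(tf.reduce_mean(x, axis=[1, 2]), tf.ones([3, 2]))
    tf.identity(tf.nn.softmax(logits), name="y_pred")
pb_dir = tempfile.mkdtemp()
tf.io.write_graph(demo.as_graph_def(), pb_dir, "frozen_model.pb", as_text=False)

# In webapp.py: create the model once, at module level, when the server starts.
MODEL = FrozenModel(os.path.join(pb_dir, "frozen_model.pb"))
probs = MODEL.predict(np.random.rand(2, 8, 8, 3).astype(np.float32))
print(probs)
```

Because both demo logits are equal, every row comes out as [0.5, 0.5]. Session.run is safe to call from several threads, so the single shared session can serve concurrent requests.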

Finally, we create an endpoint for our webserver that allows a user to upload an image and run a prediction on it:
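A possible shape for the upload endpoint is sketched below. Several parts are assumptions rather than the original code: images are decoded with Pillow, resized to 128x128 and scaled to [0, 1] (match whatever preprocessing your model was trained with), the label order is ("cat", "dog"), and the form field is named "file". The predictor is passed in as a function so that the route can be exercised without TensorFlow; in webapp.py you would pass the predict method of the model loaded at startup, while the demo uses a hypothetical fake_predict.

```python
import io

import numpy as np
from flask import Flask, jsonify, request
from PIL import Image


def preprocess(file_storage, image_size):
    """Decode an upload into a [1, size, size, 3] float32 batch in [0, 1]."""
    image = Image.open(file_storage.stream).convert("RGB")
    image = image.resize((image_size, image_size))
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    return pixels[np.newaxis, ...]


def create_app(predict_fn, labels=("cat", "dog"), image_size=128):
    """Build a Flask app whose /predict endpoint runs predict_fn on uploads."""
    app = Flask(__name__)

    @app.route("/predict", methods=["POST"])
    def predict():
        upload = request.files.get("file")
        if upload is None or upload.filename == "":
            return jsonify(error="no image uploaded"), 400
        probs = np.asarray(predict_fn(preprocess(upload, image_size)))[0]
        return jsonify(
            label=labels[int(np.argmax(probs))],
            probabilities={name: float(p) for name, p in zip(labels, probs)},
        )

    return app


def fake_predict(batch):
    """Hypothetical stand-in predictor: always 80% dog."""
    return np.array([[0.2, 0.8]], dtype=np.float32)


app = create_app(fake_predict)

# Usage: post an in-memory PNG through Flask's test client.
buffer = io.BytesIO()
Image.new("RGB", (32, 32), "white").save(buffer, format="PNG")
buffer.seek(0)
response = app.test_client().post(
    "/predict",
    data={"file": (buffer, "test.png")},
    content_type="multipart/form-data",
)
print(response.status_code, response.get_json())
```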

Once this is done, we can upload an image using the web UI, and after clicking upload we can see the prediction results.
Hopefully, this post helps you show off your newly trained model. Maybe you could build the hotDog vs not-hotDog model and raise millions of dollars for your cool new startup. However, the recommended way to deploy a Tensorflow model in production is the Tensorflow Serving infrastructure, which we shall cover in a future post. The code shared in this post can be downloaded from our github repo.
