I recently found myself in a tricky situation. Never had it been easier to save and restore a Tensorflow model than with tf.saved_model.simple_save and then tf.saved_model.loader.load. On the other hand, very little documentation exists on how these interact with the Dataset API, and in particular on how to restore a saved tf.data.Dataset’s Iterator.
This post contains standalone, deterministic code so you can easily reproduce everything: we create a model, train it, save it, restore it, and check that the inferences match.
The code runs under Python 3 and Tensorflow 1.8.
The following code generates random data for the sake of the demonstration.
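Something like the following (a minimal sketch; the shapes, the seeds, and names such as x_data / y_data are my own choices):

```python
import numpy as np
import tensorflow as tf

np.random.seed(0)
tf.set_random_seed(0)

# 4 samples of 10 timesteps with 8 features each, plus one-hot binary labels
x_data = np.random.randn(4, 10, 8).astype(np.float32)
y_data = np.eye(2)[np.random.randint(0, 2, size=4)].astype(np.float32)
```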
We start by creating the placeholders; they will hold the data at runtime. From them, we create the Dataset and then its Iterator. We get the iterator’s generated tensor, called input_tensor, which will serve as input to our model.
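In code, that could look like this (a sketch building on the data above; note that the Iterator is built from the dataset’s structure and its initializer is given an explicit name, for reasons explained further down):

```python
# Placeholders that will receive the data at runtime
x_ph = tf.placeholder(tf.float32, [None, 10, 8], name='x_ph')
y_ph = tf.placeholder(tf.float32, [None, 2], name='y_ph')

dataset = tf.data.Dataset.from_tensor_slices((x_ph, y_ph)).batch(2)
iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                           dataset.output_shapes)
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
input_tensor, labels_tensor = iterator.get_next()
```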
The model itself is built from input_tensor: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
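For instance (a sketch; the cell size is arbitrary, and the tf.identity gives the output a stable name so we can fetch it after restoring):

```python
cell_fw = tf.nn.rnn_cell.GRUCell(16)
cell_bw = tf.nn.rnn_cell.GRUCell(16)
_, states = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
                                            input_tensor, dtype=tf.float32)
# Concatenate the forward and backward final states, then classify
final_state = tf.concat(states, axis=-1)
logits = tf.layers.dense(final_state, 2)
logits = tf.identity(logits, name='logits')  # stable name for restoring
```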
The loss is a softmax_cross_entropy_with_logits, optimized with Adam. After 2 epochs (of 2 batches each), we save the “trained” model with tf.saved_model.simple_save. If you run the code as is, the model will be saved in a folder called simple/ in your current working directory.
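Continuing the sketch (in Tensorflow 1.8 the non-deprecated variant is softmax_cross_entropy_with_logits_v2; I also keep the model’s outputs around so we can compare them after restoring):

```python
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=labels_tensor, logits=logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(2):  # 2 epochs of 2 batches each
        sess.run(dataset_init_op, {x_ph: x_data, y_ph: y_data})
        for _ in range(2):
            sess.run(train_op)

    # Record the "trained" outputs to compare with the restored model later
    sess.run(dataset_init_op, {x_ph: x_data, y_ph: y_data})
    expected = [sess.run(logits) for _ in range(2)]

    tf.saved_model.simple_save(sess, 'simple',
                               inputs={'x': x_ph, 'y': y_ph},
                               outputs={'logits': logits})
```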
In a new graph, we then restore the saved model with tf.saved_model.loader.load. We grab the placeholders and logits with graph.get_tensor_by_name and the Iterator initializing operation with graph.get_operation_by_name.
Lastly, we run inference for both batches in the dataset and check that the saved and restored models yield the same values. They do!
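Put together, the restore-and-check part of the sketch might look like this:

```python
graph = tf.Graph()
with graph.as_default(), tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                               'simple')

    x_ph_r = graph.get_tensor_by_name('x_ph:0')
    y_ph_r = graph.get_tensor_by_name('y_ph:0')
    restored_logits = graph.get_tensor_by_name('logits:0')
    dataset_init_op = graph.get_operation_by_name('dataset_init')

    sess.run(dataset_init_op, {x_ph_r: x_data, y_ph_r: y_data})
    restored = [sess.run(restored_logits) for _ in range(2)]

# Same values before saving and after restoring
assert all(np.allclose(e, r) for e, r in zip(expected, restored))
```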
Before reaching this working piece of code, I kept running into this error:
FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized
It took me a long time to figure it out. I had to go through several StackOverflow questions and many blog posts to even phrase my problem. I even got to Google’s second search page!
Saving worked fine, restoring the graph too; tensors were fetched and initialized. What could go wrong?
Here is a skeleton of how my code worked (careful, it’s wrong):
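(A simplified reconstruction, reusing the toy names from the sketch above.)

```python
graph = tf.Graph()
with graph.as_default(), tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                               'simple')
    x_ph_r = graph.get_tensor_by_name('x_ph:0')
    y_ph_r = graph.get_tensor_by_name('y_ph:0')
    restored_logits = graph.get_tensor_by_name('logits:0')

    # Rebuild a Dataset + Iterator exactly like in the training script
    dataset = tf.data.Dataset.from_tensor_slices((x_ph_r, y_ph_r)).batch(2)
    iterator = dataset.make_initializable_iterator()

    # Initialize the iterator and run inference
    sess.run(iterator.initializer, {x_ph_r: x_data, y_ph_r: y_data})
    sess.run(restored_logits)  # FailedPreconditionError!
```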
Can you spot what’s wrong here?
Initializing the Iterator
tf.saved_model.simple_save freezes a graph’s variables from a session’s values. When tf.saved_model.loader.load is called, it restores those variables in the current default graph. However, when we call iterator.initializer, we don’t initialize the restored Iterator, we initialize the new one! But restored_logits still depends on the restored graph’s input_tensor, which was itself built from the restored Iterator.
So we need a handle on the right initializing operation. An easy way to get one is to build the Iterator differently:
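Namely, with tf.data.Iterator.from_structure and a named make_initializer (the name dataset_init is arbitrary; this is what the sketch at the top already uses):

```python
iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                           dataset.output_shapes)
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
```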
Now the operation is super easy to grab from the restored graph:
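With the name we picked:

```python
dataset_init_op = graph.get_operation_by_name('dataset_init')
sess.run(dataset_init_op, {x_ph_r: x_data, y_ph_r: y_data})
```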
And that solved it, leading to the piece of code at the beginning. Enjoy!