aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/util_test.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-16 18:35:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 18:39:16 -0700
commit8bca3e4ed80a212bdcd8dc1c8505c4e92d2eac15 (patch)
tree50e3da0d2911d07331b9f13e8fc6c443d2eccf5e /tensorflow/python/training/checkpointable/util_test.py
parent4d5f6fb8b296bfbd7f72eabd9b7a9a8d29eab633 (diff)
tf.contrib.checkpoint.NumpyState for saving/restoring NumPy arrays with TF checkpoints
A bit of extra infrastructure in checkpointable restore (save was already done) to support Python callbacks. The same strategy should work for any Python state, although it's confined to non-pickled NumPy arrays at the moment. PiperOrigin-RevId: 209085928
Diffstat (limited to 'tensorflow/python/training/checkpointable/util_test.py')
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py15
1 files changed, 5 insertions, 10 deletions
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index a0a87b6b79..cac293e916 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -1073,16 +1073,11 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(5, self.evaluate(checkpoint.var_5))
self.assertEqual(1, self.evaluate(checkpoint.var_1))
self.assertEqual(0, self.evaluate(checkpoint.var_0))
- if context.executing_eagerly():
- checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
- self.assertEqual(9, self.evaluate(checkpoint.var_9))
- self.assertEqual(8, self.evaluate(checkpoint.var_8))
- self.assertEqual(1, self.evaluate(checkpoint.var_1))
- self.assertEqual(0, self.evaluate(checkpoint.var_0))
- else:
- # Restoring into modified graphs is an error while graph building.
- with self.assertRaises(NotImplementedError):
- checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
+ checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
+ self.assertEqual(9, self.evaluate(checkpoint.var_9))
+ self.assertEqual(8, self.evaluate(checkpoint.var_8))
+ self.assertEqual(1, self.evaluate(checkpoint.var_1))
+ self.assertEqual(0, self.evaluate(checkpoint.var_0))
def testManyRestoresGraph(self):
"""Restores after the first should not modify the graph."""