diff options
author | 2018-08-16 18:35:03 -0700 | |
---|---|---|
committer | 2018-08-16 18:39:16 -0700 | |
commit | 8bca3e4ed80a212bdcd8dc1c8505c4e92d2eac15 (patch) | |
tree | 50e3da0d2911d07331b9f13e8fc6c443d2eccf5e /tensorflow/python/training/checkpointable/util_test.py | |
parent | 4d5f6fb8b296bfbd7f72eabd9b7a9a8d29eab633 (diff) |
tf.contrib.checkpoint.NumpyState for saving/restoring NumPy arrays with TF checkpoints
A bit of extra infrastructure in checkpointable restore (save was already done) to support Python callbacks.
The same strategy should work for any Python state, although it's confined to non-pickled NumPy arrays at the moment.
PiperOrigin-RevId: 209085928
Diffstat (limited to 'tensorflow/python/training/checkpointable/util_test.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/util_test.py | 15 |
1 files changed, 5 insertions, 10 deletions
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index a0a87b6b79..cac293e916 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -1073,16 +1073,11 @@ class CheckpointingTests(test.TestCase): self.assertEqual(5, self.evaluate(checkpoint.var_5)) self.assertEqual(1, self.evaluate(checkpoint.var_1)) self.assertEqual(0, self.evaluate(checkpoint.var_0)) - if context.executing_eagerly(): - checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops() - self.assertEqual(9, self.evaluate(checkpoint.var_9)) - self.assertEqual(8, self.evaluate(checkpoint.var_8)) - self.assertEqual(1, self.evaluate(checkpoint.var_1)) - self.assertEqual(0, self.evaluate(checkpoint.var_0)) - else: - # Restoring into modified graphs is an error while graph building. - with self.assertRaises(NotImplementedError): - checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops() + checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops() + self.assertEqual(9, self.evaluate(checkpoint.var_9)) + self.assertEqual(8, self.evaluate(checkpoint.var_8)) + self.assertEqual(1, self.evaluate(checkpoint.var_1)) + self.assertEqual(0, self.evaluate(checkpoint.var_0)) def testManyRestoresGraph(self): """Restores after the first should not modify the graph.""" |