diff options
Diffstat (limited to 'tensorflow/contrib/checkpoint/python/python_state_test.py')
-rw-r--r-- | tensorflow/contrib/checkpoint/python/python_state_test.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 0439a4755e..45494351ff 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase): save_state.a = numpy.ones([2, 2]) save_state.b = numpy.ones([2, 2]) save_state.b = numpy.zeros([2, 2]) + save_state.c = numpy.int64(3) self.assertAllEqual(numpy.ones([2, 2]), save_state.a) self.assertAllEqual(numpy.zeros([2, 2]), save_state.b) + self.assertEqual(3, save_state.c) first_save_path = saver.save(prefix) save_state.a[1, 1] = 2. + save_state.c = numpy.int64(4) second_save_path = saver.save(prefix) load_state = python_state.NumpyState() @@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase): loader.restore(first_save_path).initialize_or_restore() self.assertAllEqual(numpy.ones([2, 2]), load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(3, load_state.c) load_state.a[0, 0] = 42. self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a) loader.restore(first_save_path).run_restore_ops() @@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase): loader.restore(second_save_path).run_restore_ops() self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(4, load_state.c) def testNoGraphPollution(self): graph = ops.Graph() |