aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint/python/python_state_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/checkpoint/python/python_state_test.py')
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state_test.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py
index 0439a4755e..45494351ff 100644
--- a/tensorflow/contrib/checkpoint/python/python_state_test.py
+++ b/tensorflow/contrib/checkpoint/python/python_state_test.py
@@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase):
save_state.a = numpy.ones([2, 2])
save_state.b = numpy.ones([2, 2])
save_state.b = numpy.zeros([2, 2])
+ save_state.c = numpy.int64(3)
self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
+ self.assertEqual(3, save_state.c)
first_save_path = saver.save(prefix)
save_state.a[1, 1] = 2.
+ save_state.c = numpy.int64(4)
second_save_path = saver.save(prefix)
load_state = python_state.NumpyState()
@@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase):
loader.restore(first_save_path).initialize_or_restore()
self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ self.assertEqual(3, load_state.c)
load_state.a[0, 0] = 42.
self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
loader.restore(first_save_path).run_restore_ops()
@@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase):
loader.restore(second_save_path).run_restore_ops()
self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ self.assertEqual(4, load_state.c)
def testNoGraphPollution(self):
graph = ops.Graph()