diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/tracking_test.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/tracking_test.py | 37 |
1 files changed, 31 insertions, 6 deletions
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py index 96da0d6e47..f8d17cd417 100644 --- a/tensorflow/python/training/checkpointable/tracking_test.py +++ b/tensorflow/python/training/checkpointable/tracking_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import numpy +import six from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training @@ -144,6 +145,29 @@ class InterfaceTests(test.TestCase): checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) @test_util.run_in_graph_and_eager_modes + def testDictionariesBasic(self): + a = training.Model() + b = training.Model() + a.attribute = {"b": b} + c = training.Model() + a.attribute["c"] = [] + a.attribute["c"].append(c) + a_deps = util.list_objects(a) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + self.assertIs(b, a.attribute["b"]) + six.assertCountEqual( + self, + ["b", "c"], + [dep.name for dep in a.attribute._checkpoint_dependencies]) + self.assertEqual([b, c], a.layers) + self.assertEqual([b, c], a.attribute.layers) + self.assertEqual([c], a.attribute["c"].layers) + checkpoint = util.Checkpoint(a=a) + save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + checkpoint.restore(save_path).assert_consumed() + + @test_util.run_in_graph_and_eager_modes def testNoDepList(self): a = training.Model() a.l1 = data_structures.NoDependency([]) @@ -159,12 +183,13 @@ class InterfaceTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testAssertions(self): a = tracking.Checkpointable() - a.l = [numpy.zeros([2, 2])] - self.assertAllEqual([numpy.zeros([2, 2])], a.l) - self.assertAllClose([numpy.zeros([2, 2])], a.l) - nest.map_structure(self.assertAllClose, a.l, [numpy.zeros([2, 2])]) - a.tensors = [array_ops.ones([2, 2]), array_ops.zeros([3, 3])] - self.assertAllClose([numpy.ones([2, 2]), numpy.zeros([3, 3])], + a.l = {"k": [numpy.zeros([2, 2])]} + self.assertAllEqual(nest.flatten({"k": [numpy.zeros([2, 2])]}), + nest.flatten(a.l)) + self.assertAllClose({"k": [numpy.zeros([2, 2])]}, a.l) + nest.map_structure(self.assertAllClose, a.l, {"k": [numpy.zeros([2, 2])]}) + a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]} + self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]}, self.evaluate(a.tensors)) if __name__ == "__main__": |