aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/tracking_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/tracking_test.py')
-rw-r--r--tensorflow/python/training/checkpointable/tracking_test.py37
1 files changed, 31 insertions, 6 deletions
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py
index 96da0d6e47..f8d17cd417 100644
--- a/tensorflow/python/training/checkpointable/tracking_test.py
+++ b/tensorflow/python/training/checkpointable/tracking_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
import numpy
+import six
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
@@ -144,6 +145,29 @@ class InterfaceTests(test.TestCase):
checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
@test_util.run_in_graph_and_eager_modes
+ def testDictionariesBasic(self):
+ a = training.Model()
+ b = training.Model()
+ a.attribute = {"b": b}
+ c = training.Model()
+ a.attribute["c"] = []
+ a.attribute["c"].append(c)
+ a_deps = util.list_objects(a)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ self.assertIs(b, a.attribute["b"])
+ six.assertCountEqual(
+ self,
+ ["b", "c"],
+ [dep.name for dep in a.attribute._checkpoint_dependencies])
+ self.assertEqual([b, c], a.layers)
+ self.assertEqual([b, c], a.attribute.layers)
+ self.assertEqual([c], a.attribute["c"].layers)
+ checkpoint = util.Checkpoint(a=a)
+ save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ checkpoint.restore(save_path).assert_consumed()
+
+ @test_util.run_in_graph_and_eager_modes
def testNoDepList(self):
a = training.Model()
a.l1 = data_structures.NoDependency([])
@@ -159,12 +183,13 @@ class InterfaceTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testAssertions(self):
a = tracking.Checkpointable()
- a.l = [numpy.zeros([2, 2])]
- self.assertAllEqual([numpy.zeros([2, 2])], a.l)
- self.assertAllClose([numpy.zeros([2, 2])], a.l)
- nest.map_structure(self.assertAllClose, a.l, [numpy.zeros([2, 2])])
- a.tensors = [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]
- self.assertAllClose([numpy.ones([2, 2]), numpy.zeros([3, 3])],
+ a.l = {"k": [numpy.zeros([2, 2])]}
+ self.assertAllEqual(nest.flatten({"k": [numpy.zeros([2, 2])]}),
+ nest.flatten(a.l))
+ self.assertAllClose({"k": [numpy.zeros([2, 2])]}, a.l)
+ nest.map_structure(self.assertAllClose, a.l, {"k": [numpy.zeros([2, 2])]})
+ a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
+ self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]},
self.evaluate(a.tensors))
if __name__ == "__main__":