aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint/python/containers_test.py
blob: ac85c7be803cd4c2f8ba19d3ef887a3c65a15933 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import six

from tensorflow.contrib.checkpoint.python import containers
from tensorflow.python.framework import test_util
from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util


class UniqueNameTrackerTests(test.TestCase):
  """Tests for `containers.UniqueNameTracker`."""

  @test_util.run_in_graph_and_eager_modes
  def testNames(self):
    """Saves variables tracked under clashing names, then restores them.

    The names "x", "x", "x_1" and "y" are requested in that order; the test
    expects the saved values to be restorable through attributes named
    "x", "x_1", "x_1_1" and "y" respectively.
    """
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

    initial_values = (2., 3., 4., 5.)
    requested_names = ("x", "x", "x_1", "y")
    variables = [
        resource_variable_ops.ResourceVariable(value)
        for value in initial_values
    ]
    tracker = containers.UniqueNameTracker()
    for variable, requested_name in zip(variables, requested_names):
      tracker.track(variable, requested_name)
    self.evaluate(tuple(variable.initializer for variable in variables))
    saved_path = util.Checkpoint(slots=tracker).save(ckpt_prefix)

    # Restore into a plain Checkpointable. The restore is requested before
    # the attributes exist; they are attached afterwards and
    # `assert_consumed` then checks the checkpoint was matched.
    restored = tracking.Checkpointable()
    restore_root = util.Checkpoint(slots=restored)
    load_status = restore_root.restore(saved_path)
    restored_names = ("x", "x_1", "x_1_1", "y")
    for attribute in restored_names:
      setattr(restored, attribute, resource_variable_ops.ResourceVariable(0.))
    consumed_status = load_status.assert_consumed()
    consumed_status.run_restore_ops()
    for expected, attribute in zip(initial_values, restored_names):
      self.assertEqual(expected, self.evaluate(getattr(restored, attribute)))

  @test_util.run_in_graph_and_eager_modes
  def testExample(self):
    """Checks the dependency names recorded in a saved checkpoint's metadata."""

    class SlotManager(tracking.Checkpointable):
      """Holds variables in a tracker plus an untracked list of the same."""

      def __init__(self):
        self.slotdeps = containers.UniqueNameTracker()
        track = self.slotdeps.track
        tracked = [
            track(resource_variable_ops.ResourceVariable(value), name)
            for value, name in ((3., "x"), (4., "y"), (5., "x"))
        ]
        # Wrapped in NoDependency, so the only expected dependency edge to
        # the variables is through `slotdeps`.
        self.slots = data_structures.NoDependency(tracked)

    manager = SlotManager()
    self.evaluate([slot.initializer for slot in manager.slots])
    checkpoint = util.Checkpoint(slot_manager=manager)
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    saved_path = checkpoint.save(ckpt_prefix)

    # Flatten every child edge name found in the saved object graph.
    object_graph = util.object_metadata(saved_path)
    dependency_names = [
        child.local_name
        for node in object_graph.nodes
        for child in node.children
    ]
    expected_names = [
        "x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]
    six.assertCountEqual(self, dependency_names, expected_names)

  @test_util.run_in_graph_and_eager_modes
  def testLayers(self):
    """A tracked Dense layer is exposed via `layers` and `trainable_weights`."""
    tracker = containers.UniqueNameTracker()
    dense = layers.Dense(3)
    tracker.track(dense, "dense")
    # Call the layer through the tracker's `layers` list with a [1, 1] input.
    first_layer = tracker.layers[0]
    first_layer(array_ops.zeros([1, 1]))
    self.assertEqual(2, len(tracker.trainable_weights))

# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()