path: root/tensorflow/python/training/checkpointable/data_structures.py
"""Checkpointable data structures."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import six

from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import base as checkpointable_lib
from tensorflow.python.training.checkpointable import layer_utils


# TODO(allenl): We could track regular Python data structures which get assigned
# to Checkpointable objects. Making this work with restore-on-create would be
# tricky; we'd need to re-create nested structures with our own wrapped objects
# on assignment to an attribute, and track the user's original structure to make
# sure they don't modify it except through the wrappers (since we could save the
# user's updated structure, but would have no way to support restore-on-create
# for those modifications).
# TODO(allenl): A dictionary data structure would be good too.
class CheckpointableDataStructure(checkpointable_lib.CheckpointableBase):
  """Base class for data structures which contain checkpointable objects."""

  def __init__(self):
    self._layers = []
    self.trainable = True
    self._extra_variables = []

  def _track_value(self, value, name):
    """Add a dependency on `value`."""
    if isinstance(value, checkpointable_lib.CheckpointableBase):
      self._track_checkpointable(value, name=name)
      if isinstance(value, variables.Variable):
        self._extra_variables.append(value)
    else:
      raise ValueError(
          ("Only checkpointable objects (such as Layers or Optimizers) may be "
           "stored in a List or Mapping. Got %s, which does not inherit from "
           "CheckpointableBase.") % (value,))
    if (isinstance(value, CheckpointableDataStructure)
        or layer_utils.is_layer(value)):
      if value not in self._layers:
        self._layers.append(value)
        if hasattr(value, "_use_resource_variables"):
          # In subclassed models, legacy layers (tf.layers) must always use
          # resource variables.
          value._use_resource_variables = True  # pylint: disable=protected-access
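
  # Illustrative note (hedged, comments only): storing a plain Python value
  # fails the check above. For example, `List([1])` raises ValueError because
  # `int` does not inherit from CheckpointableBase; wrap such values in a
  # `tf.Variable` or another checkpointable type first.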

  @property
  def layers(self):
    return self._layers

  @property
  def trainable_weights(self):
    return layer_utils.gather_trainable_weights(
        trainable=self.trainable,
        sub_layers=self.layers,
        extra_variables=self._extra_variables)

  @property
  def non_trainable_weights(self):
    return layer_utils.gather_non_trainable_weights(
        trainable=self.trainable,
        sub_layers=self.layers,
        extra_variables=self._extra_variables)

  @property
  def weights(self):
    return self.trainable_weights + self.non_trainable_weights

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  @property
  def variables(self):
    return self.weights

  @property
  def updates(self):
    """Aggregate updates from any `Layer` instances."""
    # Updates and conditional losses are forwarded as-is rather than being
    # filtered based on inputs, since this is just a container and won't ever
    # have any inputs.
    aggregated = []
    for layer in self.layers:
      aggregated += layer.updates
    return aggregated

  @property
  def losses(self):
    """Aggregate losses from any `Layer` instances."""
    aggregated = []
    for layer in self.layers:
      aggregated += layer.losses
    return aggregated

  def __hash__(self):
    # Support object-identity hashing, so these structures can be used as keys
    # in sets/dicts.
    return id(self)

  def __eq__(self, other):
    # Similar to Tensors, checkpointable data structures use object-identity
    # equality to support set/dict membership.
    return self is other
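
  # A hedged illustration (comments only, not executed): since equality is
  # object identity, two structurally identical containers are distinct
  # set/dict members:
  #
  #   a = List([layer])
  #   b = List([layer])
  #   assert a != b
  #   assert len({a, b}) == 2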


class List(CheckpointableDataStructure, collections.Sequence):
  """An append-only sequence type which is checkpointable.

  Maintains checkpoint dependencies on its contents (which must also be
  checkpointable), and forwards any `Layer` metadata such as updates and losses.

  Note that `List` is purely a container. It lets a `tf.keras.Model` or
  other checkpointable object know about its contents, but does not call any
  `Layer` instances which are added to it. To indicate a sequence of `Layer`
  instances which should be called sequentially, use `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super(HasList, self).__init__()
      self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  This kind of wrapping is necessary because `Checkpointable` objects do not
  (yet) deeply inspect regular Python data structures, so for example assigning
  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
  checkpoint dependency and does not add the `Layer` instance's weights to its
  parent `Model`.
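
  Beyond `append`, `List` supports `extend`, `+=`, and `+`. A minimal sketch
  (`dense_a` and `dense_b` stand in for arbitrary checkpointable values):

  ```python
  layer_list = tf.contrib.checkpoint.List()
  layer_list.extend([dense_a])
  layer_list += [dense_b]             # delegates to extend()
  combined = layer_list + layer_list  # returns a new, tracked List
  ```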
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `list()`."""
    super(List, self).__init__()
    self._storage = list(*args, **kwargs)
    for index, element in enumerate(self._storage):
      self._track_value(element, name=self._name_element(index))

  def _name_element(self, index):
    return "%d" % (index,)

  def append(self, value):
    """Add a new checkpointable value."""
    self._track_value(value, self._name_element(len(self._storage)))
    self._storage.append(value)

  def extend(self, values):
    """Add a sequence of checkpointable values."""
    for index_offset, value in enumerate(values):
      self._track_value(
          value, name=self._name_element(len(self._storage) + index_offset))
    self._storage.extend(values)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    if isinstance(other, List):
      return List(self._storage + other._storage)  # pylint: disable=protected-access
    else:
      return List(self._storage + other)

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "List(%s)" % (repr(self._storage),)


class Mapping(CheckpointableDataStructure, collections.Mapping):
  """An append-only checkpointable mapping data structure with string keys.

  Maintains checkpoint dependencies on its contents (which must also be
  checkpointable), named based on its keys.

  Note that once a key has been added, it may not be deleted or replaced. If
  keys are not guaranteed to be unique, see
  `tf.contrib.checkpoint.UniqueNameTracker`.
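
  Example usage (a minimal sketch; `HasMapping` is a hypothetical model
  subclass, not part of this module):

  ```python
  class HasMapping(tf.keras.Model):

    def __init__(self):
      super(HasMapping, self).__init__()
      self.layer_dict = tf.contrib.checkpoint.Mapping(
          output=tf.keras.layers.Dense(7))

    def call(self, x):
      return self.layer_dict["output"](x)
  ```

  Re-assigning an existing key to a different object raises `ValueError`;
  re-assigning the same object is a no-op.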
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `dict()`."""
    super(Mapping, self).__init__()
    self._storage = dict(*args, **kwargs)
    for key, value in self._storage.items():
      self._track_value(value, name=self._name_element(key))

  def _name_element(self, key):
    if not isinstance(key, six.string_types):
      raise TypeError(
          "Mapping accepts only string keys, but got a key %s."
          % repr(key))
    return str(key)

  def __setitem__(self, key, value):
    current_value = self._storage.setdefault(key, value)
    if current_value is not value:
      raise ValueError(
          ("Mappings are an append-only data structure. Tried to overwrite the "
           "key '%s' with value %s, but it already contains %s")
          % (key, value, current_value))
    self._track_value(value, name=self._name_element(key))

  def update(self, *args, **kwargs):
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "Mapping(%s)" % (repr(self._storage),)

  def __iter__(self):
    return iter(self._storage)