aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint/python/python_state.py
blob: 9b11035b6d277851ea0a0071062bf5cf6b6b2185 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""Utilities for including Python state in TensorFlow checkpoints."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy

from tensorflow.python.training.checkpointable import base

# pylint: disable=g-import-not-at-top
try:
  # In Python 2.x, use the faster string buffering option.
  from cStringIO import StringIO as BytesIO
except ImportError:
  from io import BytesIO
# pylint: enable=g-import-not-at-top


class NumpyState(base.CheckpointableBase):
  """A checkpointable object whose NumPy array attributes are saved/restored.

  Example usage:

  ```python
  arrays = tf.contrib.checkpoint.NumpyState()
  checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
  arrays.x = numpy.zeros([3, 4])
  save_path = checkpoint.save("/tmp/ckpt")
  arrays.x[1, 1] = 4.
  checkpoint.restore(save_path)
  assert (arrays.x == numpy.zeros([3, 4])).all()

  second_checkpoint = tf.train.Checkpoint(
      numpy_arrays=tf.contrib.checkpoint.NumpyState())
  # Attributes of NumpyState objects are created automatically by restore()
  second_checkpoint.restore(save_path)
  assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
  ```

  Both `numpy.ndarray` values and NumPy scalars (instances of `numpy.generic`,
  e.g. `numpy.float32(1.)`) may be assigned to attributes. NumPy scalars are
  serialized as zero-dimensional arrays, so after `restore()` such an attribute
  holds a 0-d `numpy.ndarray` rather than the original scalar type.

  Note that `NumpyState` objects re-create the attributes of the previously
  saved object on `restore()`. This is in contrast to TensorFlow variables, for
  which a `Variable` object must be created and assigned to an attribute.

  This snippet works both when graph building and when executing eagerly. On
  save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
  a placeholder when graph building, or as a string constant when executing
  eagerly). When restoring they skip the TensorFlow graph entirely, and so no
  restore ops need be run. This means that restoration always happens eagerly,
  rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
  TensorFlow variables when graph building.
  """

  def _lookup_dependency(self, name):
    """Create placeholder NumPy arrays for to-be-restored attributes.

    Typically `_lookup_dependency` is used to check by name whether a dependency
    exists. We cheat slightly by creating a checkpointable object for `name` if
    we don't already have one, giving us attribute re-creation behavior when
    loading a checkpoint.

    Args:
      name: The name of the dependency being checked.
    Returns:
      An existing dependency if one exists, or a new `_NumpyWrapper` placeholder
      dependency (which will generally be restored immediately).
    """
    value = super(NumpyState, self)._lookup_dependency(name)
    if value is None:
      # An empty array is only a stand-in; the wrapper's `array` is replaced
      # when the checkpointed value is deserialized into it.
      value = _NumpyWrapper(numpy.array([]))
      new_reference = base.CheckpointableReference(name=name, ref=value)
      # Register the dependency directly in the base-class bookkeeping and set
      # the attribute via the parent `__setattr__`: this class's own
      # `__setattr__` would reject an already-wrapped (non-NumPy) value.
      self._unconditional_checkpoint_dependencies.append(new_reference)
      self._unconditional_dependency_names[name] = value
      super(NumpyState, self).__setattr__(name, value)
    return value

  def __getattribute__(self, name):
    """Un-wrap `_NumpyWrapper` objects when accessing attributes.

    Returns the wrapped value for tracked NumPy attributes, and the attribute
    itself for everything else.
    """
    value = super(NumpyState, self).__getattribute__(name)
    if isinstance(value, _NumpyWrapper):
      return value.array
    return value

  def __setattr__(self, name, value):
    """Automatically wrap NumPy arrays and scalars assigned to attributes.

    Args:
      name: Attribute name.
      value: Value being assigned. NumPy arrays/scalars are stored inside a
        tracked `_NumpyWrapper`; other values are only accepted while
        `_setattr_tracking` is falsy or for a few internal names.
    Raises:
      NotImplementedError: If a non-NumPy value is assigned while attribute
        tracking is enabled.
    """
    # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
    # ndarrays checkpointable natively and using standard checkpointable list
    # tracking.
    if isinstance(value, (numpy.ndarray, numpy.generic)):
      # Look up the raw (still wrapped) attribute, bypassing the un-wrapping
      # done by `__getattribute__`. Only the lookup belongs in the `try`, so an
      # error while updating an existing value is not mistaken for "missing".
      try:
        existing = super(NumpyState, self).__getattribute__(name)
      except AttributeError:
        existing = None
      if isinstance(existing, _NumpyWrapper):
        # Re-use the already tracked wrapper so the dependency is not
        # registered a second time.
        existing.array = value
        return
      value = _NumpyWrapper(value)
      self._track_checkpointable(value, name=name, overwrite=True)
    elif (name not in ("_setattr_tracking", "_update_uid")
          and getattr(self, "_setattr_tracking", True)):
      # `_setattr_tracking` / `_update_uid` are presumably internal bookkeeping
      # written by `base.CheckpointableBase` itself (confirm), so they must
      # remain assignable.
      #
      # Mixing restore()-created attributes with user-added checkpointable
      # objects is tricky, since we can't use the `_lookup_dependency` trick to
      # re-create attributes (we might accidentally steal the restoration for
      # another checkpointable object). For now `NumpyState` objects must be
      # leaf nodes. Theoretically we could add some extra arguments to
      # `_lookup_dependency` to figure out whether we should create a NumPy
      # array for the attribute or not.
      raise NotImplementedError(
          ("Assigned %s to the %s property of %s, which is not a NumPy array. "
           "Currently mixing NumPy arrays and other checkpointable objects is "
           "not supported. File a feature request if this limitation bothers "
           "you.")
          % (value, name, self))
    super(NumpyState, self).__setattr__(name, value)


class _NumpyWrapper(base.CheckpointableBase):
  """Checkpointable holder for a single NumPy array.

  On save the array is written with `numpy.save` into an in-memory byte buffer;
  on restore the bytes are parsed back with `numpy.load`. Pickling is disabled
  in both directions (`allow_pickle=False`).
  """

  def __init__(self, array):
    """Store the NumPy array that will be saved and restored.

    Args:
      array: The NumPy array to wrap. It is replaced wholesale when a
        checkpoint is restored into this object.
    """
    self.array = array

  def _serialize(self):
    """Save callback: return `self.array` encoded by `numpy.save` as bytes."""
    buf = BytesIO()
    # An explicit close is used instead of `with`, since the Python 2
    # `cStringIO` buffer does not support the context-manager protocol.
    try:
      numpy.save(buf, self.array, allow_pickle=False)
      return buf.getvalue()
    finally:
      buf.close()

  def _deserialize(self, string_value):
    """Restore callback: decode `string_value` and store it as `self.array`."""
    buf = BytesIO(string_value)
    try:
      loaded = numpy.load(buf, allow_pickle=False)
    finally:
      buf.close()
    self.array = loaded

  def _gather_saveables_for_checkpoint(self):
    """Expose `array` as a string-valued saveable driven by the callbacks."""
    make_saveable = functools.partial(
        base.PythonStringStateSaveable,
        state_callback=self._serialize,
        restore_callback=self._deserialize)
    return {"array": make_saveable}