"""Saver for eager mode TensorFlow."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as _saver


def _init_from_checkpoint(self, *args, **kwargs):
  """Overrides default init by loading value from checkpoint."""
  # pylint: disable=protected-access
  self._old_init(*args, **kwargs)
  ckpt_name = self._map_func(self._shared_name)
  if ckpt_name not in self._ckpt_var_cache:
    raise errors.NotFoundError(None, None,
                               "%s not found in checkpoint" % ckpt_name)

  val = self._ckpt_var_cache.get(ckpt_name, None)
  if val is not None:
    self.assign(val)
    # Avoid assigning for the second time.
    self._ckpt_var_cache[ckpt_name] = None
  # pylint: enable=protected-access
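
# Note: `_init_from_checkpoint` is never called directly. The context manager
# below temporarily installs it as `ResourceVariable._init_from_args`, so any
# variable created while the context is active is initialized from the cached
# checkpoint value under its (possibly mapped) name.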


@contextlib.contextmanager
def restore_variables_on_create(save_path, map_func=None):
  """ContextManager that restores variables on creation.

    When save_path is None (e.g. No checkpoint), does nothing.
    Otherwise, it preloads all values from checkpoint. When the
    corresponding variable is first created, it assigns the checkpoint
    value to the variable.

    ```python
    with restore_variables_on_create(
        tf.train.latest_checkpoint(checkpoint_dir)):
    ```

  Args:
    save_path: The checkpoint file prefix.
    map_func: A function that takes a variable name and returns the
        corresponding variable name in the checkpoint to restore from. If
        None, each variable is restored from the checkpoint entry with the
        same name. It is an error if the mapped name does not exist in
        the checkpoint.

  Yields:
    Nothing.

  Raises:
    NotFoundError: If the variable is not found in the checkpoint.
    ValueError: If eager execution is not enabled, or if `map_func` is not
      callable.
  """
  if not context.executing_eagerly():
    raise ValueError(
        "Currently, restore_variables_on_create can only be used with "
        "eager execution enabled.")
  if save_path:
    if map_func is None:
      map_func_wrapper = lambda self, x: x
    else:
      if not callable(map_func):
        raise ValueError("map_func must be callable.")
      map_func_wrapper = lambda self, x: map_func(x)

    ckpt_var_cache = dict()
    reader = checkpoint_utils.load_checkpoint(save_path)
    for k, _ in checkpoint_utils.list_variables(save_path):
      ckpt_var_cache[k] = reader.get_tensor(k)

    old_init = getattr(resource_variable_ops.ResourceVariable,
                       "_init_from_args", None)
    assert old_init, "ResourceVariable is missing the _init_from_args method."
    setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
            _init_from_checkpoint)
    setattr(resource_variable_ops.ResourceVariable, "_old_init", old_init)
    setattr(resource_variable_ops.ResourceVariable, "_map_func",
            map_func_wrapper)
    setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache",
            ckpt_var_cache)
  try:
    yield
  finally:
    if save_path:
      setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
              old_init)
      setattr(resource_variable_ops.ResourceVariable, "_old_init", None)
      setattr(resource_variable_ops.ResourceVariable, "_map_func", None)
      setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", None)
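
# A minimal usage sketch (illustrative, not part of the original API): restore
# newly created variables from a checkpoint whose entries are assumed to live
# under a "model/" prefix. `_saver.latest_checkpoint` is the same helper that
# is exposed publicly as `tf.train.latest_checkpoint`.
def _example_restore_on_create(checkpoint_dir):
  """Creates a variable inside the context so it picks up a checkpoint value."""
  with restore_variables_on_create(
      _saver.latest_checkpoint(checkpoint_dir),
      map_func=lambda name: "model/" + name):
    # Initialized from the checkpoint entry "model/bias", not from 0.0.
    bias = resource_variable_ops.ResourceVariable(0.0, name="bias")
  return bias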


class Saver(object):
  """A tf.train.Saver adapter for use when eager execution is enabled."""

  def __init__(self, var_list):
    """A  tf.train.Saver adapter for use when eager execution is enabled.

      The API, and on-disk format, mimic tf.train.Saver except that no
      Session is needed.

    Args:
      var_list: The list of variables that will be saved and restored. Either a
        list of `tfe.Variable` objects, or a dictionary mapping names to
        `tfe.Variable` objects.

    Raises:
      RuntimeError: if invoked when eager execution has not been enabled.
    """
    if not context.executing_eagerly():
      raise RuntimeError("tfe.Saver can only be used when eager "
                         "execution is enabled. Use tf.train.Saver when "
                         "building graphs.")
    self._saver = _saver.Saver(var_list=var_list)

  def save(self, file_prefix, global_step=None):
    """Saves variables.

    Args:
      file_prefix: Path prefix of files created for the checkpoint.
      global_step: If provided, the global step number is appended to
        `file_prefix` to create the checkpoint filename. This optional
        argument can be a Tensor, a Variable, or an integer.

    Returns:
      A string: prefix of filenames created for the checkpoint. This may be
        an extension of file_prefix that is suitable to pass as an argument
        to a subsequent call to `restore()`.
    """
    with ops.device("/device:CPU:0"):
      return self._saver.save(
          None, file_prefix, write_meta_graph=False, global_step=global_step)

  def restore(self, file_prefix):
    """Restores previously saved variables.

    Args:
      file_prefix: Path prefix where parameters were previously saved.
        Typically obtained from a previous `save()` call, or from
        `tf.train.latest_checkpoint`.
    """
    with ops.device("/device:CPU:0"):
      self._saver.restore(None, file_prefix)
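
# A save/restore round-trip sketch (illustrative, not part of the original
# module). Runs with eager execution enabled; `prefix` is an assumed path
# such as "/tmp/ckpt/model".
def _example_save_restore_round_trip(prefix):
  """Saves a variable, clobbers it, then restores the saved value."""
  v = resource_variable_ops.ResourceVariable(1.0, name="v")
  saver = Saver([v])
  path = saver.save(prefix)  # The prefix actually written; pass to restore().
  v.assign(2.0)              # Overwrite the in-memory value.
  saver.restore(path)        # `v` holds 1.0 again.
  return v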


def get_optimizer_variables(optimizer):
  """Returns a list of variables for the given `tf.train.Optimizer`.

  Equivalent to `optimizer.variables()`.

  Args:
    optimizer: An instance of `tf.train.Optimizer` which has created variables
      (typically after a call to `Optimizer.minimize`).

  Returns:
    A list of variables which have been created by the `Optimizer`.
  """
  return optimizer.variables()
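

# Sketch (illustrative): checkpoint model and optimizer state together.
# Assumes `optimizer` has already created its slot variables, e.g. via a
# prior call to `optimizer.minimize(...)`.
def _example_save_with_optimizer(model_variables, optimizer, file_prefix):
  """Saves model variables together with the optimizer's variables."""
  all_vars = list(model_variables) + get_optimizer_variables(optimizer)
  return Saver(all_vars).save(file_prefix)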