aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-07-17 23:08:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 23:12:29 -0700
commitf1de0ddd55dcae6237ea7d21ccddcc6467a6cf8b (patch)
tree4238193e76288333027c86aeac2dc86f165af641
parentaa15692e54390cf3967d51bc60acf5f783df9c08 (diff)
Add support for MirroredVariables in init_from_checkpoint and warm_start in estimator.
PiperOrigin-RevId: 205030626
-rw-r--r--tensorflow/contrib/distribute/python/BUILD37
-rw-r--r--tensorflow/contrib/distribute/python/checkpoint_utils_test.py72
-rw-r--r--tensorflow/contrib/distribute/python/values.py15
-rw-r--r--tensorflow/contrib/distribute/python/warm_starting_util_test.py97
-rw-r--r--tensorflow/python/training/checkpoint_utils.py52
-rw-r--r--tensorflow/python/training/warm_starting_util.py18
6 files changed, 266 insertions, 25 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 40dbfa3dd2..f5d7e24ae2 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -610,3 +610,40 @@ cuda_py_test(
"no_pip",
],
)
+
+cuda_py_test(
+ name = "warm_starting_util_test",
+ size = "medium",
+ srcs = ["warm_starting_util_test.py"],
+ additional_deps = [
+ ":combinations",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
+ ],
+)
+
+cuda_py_test(
+ name = "checkpoint_utils_test",
+ size = "medium",
+ srcs = ["checkpoint_utils_test.py"],
+ additional_deps = [
+ ":combinations",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:checkpoint_utils_test",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
+ ],
+)
diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
new file mode 100644
index 0000000000..fe3df9cbb9
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
@@ -0,0 +1,72 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for checkpoint_utils.init_from_checkpoint with Distribution Strategy.
+
+These tests are located here instead of as part of
+`python.training.CheckpointsTest` because they need access to distribution
+strategies which are only present in contrib right now.
+TODO(priyag): Move the tests to core `python.training.CheckpointsTest` when
+distribution strategy moves out of contrib.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import checkpoint_utils_test
+
+
+class CheckpointUtilsWithDistributionStrategyTest(
+ test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus],
+ in_tower_mode=[True, False],
+ mode=["graph"]))
+ def testInitFromCheckpoint(self, distribution, in_tower_mode):
+ checkpoint_dir = self.get_temp_dir()
+ with self.test_session() as session:
+ v1_value, _, _, _ = checkpoint_utils_test._create_checkpoints(
+ session, checkpoint_dir)
+
+ def init_and_verify(g):
+ v1 = variable_scope.get_variable("new_var1", [1, 10])
+ checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
+ "var1": "new_var1",
+ })
+ with self.test_session(graph=g) as session:
+ session.run(variables.global_variables_initializer())
+ self.assertAllEqual(v1_value, self.evaluate(v1))
+
+ with ops.Graph().as_default() as g, distribution.scope():
+ if in_tower_mode:
+ distribution.call_for_each_tower(init_and_verify, g)
+ else:
+ init_and_verify(g)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 1b5e00bc79..1761a43251 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -33,7 +33,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
@@ -336,23 +335,27 @@ class MirroredVariable(DistributedVariable, Mirrored,
raise ValueError("You must specify an aggregation method to update a "
"MirroredVariable in Tower Context.")
- def merge_fn(strategy, value):
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
return strategy.update(
self, f,
strategy.reduce(
- aggregation=self._aggregation, value=value, destinations=self))
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
**kwargs)
def assign_sub(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)
+ assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
def assign_add(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign_add, *args, **kwargs)
+ assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
def assign(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign, *args, **kwargs)
+ assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+ return self._assign_func(f=assign_fn, *args, **kwargs)
def is_initialized(self, name=None):
# We have to cast the self._index.values() to a `list` because when we
diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py
new file mode 100644
index 0000000000..d8bacdb338
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for warm_starting_util with Distribution Strategy.
+
+These tests are located here instead of as part of `WarmStartingUtilTest`
+because they need access to distribution strategies which are only present in
+contrib right now.
+TODO(priyag): Move the tests to core `WarmStartingUtilTest` when distribution
+strategy moves out of contrib.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import warm_starting_util as ws_util
+
+
+class WarmStartingUtilWithDistributionStrategyTest(
+ test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus],
+ save_with_distribution=[True, False],
+ restore_with_distribution=[True, False],
+ mode=["graph"]))
+ def testWarmStart(self, distribution, save_with_distribution,
+ restore_with_distribution):
+
+ var_name = "v"
+ original_value = [[1., 2.], [3., 4.]]
+
+ # Create variable and save checkpoint from which to warm-start.
+ def create_var(g):
+ with self.test_session(graph=g) as sess:
+ var = variable_scope.get_variable(var_name, initializer=original_value)
+ sess.run(variables.global_variables_initializer())
+ saver = saver_lib.Saver()
+ ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
+ saver.save(sess, ckpt_prefix, global_step=0)
+ return var, sess.run(var)
+
+ if save_with_distribution:
+ with ops.Graph().as_default() as g, distribution.scope():
+ _, prev_init_val = create_var(g)
+ else:
+ with ops.Graph().as_default() as g:
+ _, prev_init_val = create_var(g)
+
+ # Verify we initialized the values correctly.
+ self.assertAllEqual(original_value, prev_init_val)
+
+ def warm_start(g):
+ with self.test_session(graph=g) as sess:
+ # Initialize with zeros.
+ var = variable_scope.get_variable(
+ var_name, initializer=[[0., 0.], [0., 0.]])
+ ws_util.warm_start(self.get_temp_dir())
+ sess.run(variables.global_variables_initializer())
+ # Verify weights were correctly warm-started to previous values.
+ self.assertAllEqual(original_value, self.evaluate(var))
+
+ # Warm start in a new graph.
+ if restore_with_distribution:
+ with ops.Graph().as_default() as g, distribution.scope():
+ warm_start(g)
+ else:
+ with ops.Graph().as_default() as g:
+ warm_start(g)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 5b372e82b3..883f4fd910 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -179,6 +180,16 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
ValueError: If missing variables in current graph.
"""
+ if distribute_lib.get_cross_tower_context():
+ _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
+ else:
+ distribute_lib.get_tower_context().merge_call(
+ _init_from_checkpoint, ckpt_dir_or_file, assignment_map)
+
+
+def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
+ """See `init_from_checkpoint` for documentation."""
+
ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
reader = load_checkpoint(ckpt_dir_or_file)
variable_map = reader.get_variable_to_shape_map()
@@ -187,10 +198,9 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
var = None
# Check if this is Variable object or list of Variable objects (in case of
# partitioned variables).
- is_var = lambda x: isinstance(x, variables.Variable)
- if is_var(current_var_or_name) or (
+ if _is_variable(current_var_or_name) or (
isinstance(current_var_or_name, list)
- and all(is_var(v) for v in current_var_or_name)):
+ and all(_is_variable(v) for v in current_var_or_name)):
var = current_var_or_name
else:
store_vars = vs._get_default_variable_store()._vars # pylint:disable=protected-access
@@ -205,7 +215,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
))
- if is_var(var):
+ if _is_variable(var):
# Additional at-call-time checks.
if not var.get_shape().is_compatible_with(
variable_map[tensor_name_in_ckpt]):
@@ -297,13 +307,34 @@ def _set_checkpoint_initializer(variable,
with ops.device(variable.device), ops.device("/cpu:0"):
restore_op = io_ops.restore_v2(
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
- if isinstance(variable, resource_variable_ops.ResourceVariable):
+
+ # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
+ if resource_variable_ops.is_resource_variable(variable):
init_op = variable.assign(restore_op, read_value=False)
else:
init_op = state_ops.assign(variable, restore_op)
- variable._initializer_op = init_op # pylint:disable=protected-access
- restore_op.set_shape(variable.shape)
- variable._initial_value = restore_op # pylint:disable=protected-access
+
+ # pylint:disable=protected-access
+ # We need special handling for `DistributedVariable`s as they contain
+ # mutliple actual variables. `assign` on a `DistributedVariable` returns a
+ # combined `init_op` which contains initializers for all the contained
+ # variables. We then set each underlying variable's `_initializer_op` using
+ # the corresponding `init_op`.
+ # TODO(priyag): Use `isinstance` checks when `DistributedVariable` class
+ # moves out of contrib.
+ if any(base.__name__ == "DistributedVariable"
+ for base in variable.__class__.__bases__):
+ assert distribute_lib.get_cross_tower_context()
+ assert hasattr(variable, "_index")
+ for (d, v) in six.iteritems(variable._index):
+ v._initializer_op = init_op._index[d]
+ restore_op.set_shape(v.shape)
+ v._initial_value = restore_op
+ else:
+ variable._initializer_op = init_op
+ restore_op.set_shape(variable.shape)
+ variable._initial_value = restore_op
+ # pylint:enable=protected-access
def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
@@ -337,6 +368,11 @@ def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
_set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")
+def _is_variable(x):
+ return (isinstance(x, variables.Variable) or
+ resource_variable_ops.is_resource_variable(x))
+
+
def _collect_partitioned_variable(name, all_vars):
"""Returns list of `tf.Variable` that comprise the partitioned variable."""
if name + "/part_0" in all_vars:
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index ec740abdd1..b1a7cfab83 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -22,7 +22,6 @@ import collections
import six
from tensorflow.python.framework import ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@@ -83,11 +82,6 @@ class VocabInfo(
)
-def _is_variable(x):
- return (isinstance(x, variables_lib.Variable) or
- isinstance(x, resource_variable_ops.ResourceVariable))
-
-
def _infer_var_name(var):
"""Returns name of the `var`.
@@ -126,9 +120,10 @@ def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
None, we lookup tensor with same name as given `var`.
"""
- if _is_variable(var):
+ if checkpoint_utils._is_variable(var): # pylint: disable=protected-access
current_var_name = _infer_var_name([var])
- elif isinstance(var, list) and all(_is_variable(v) for v in var):
+ elif (isinstance(var, list) and
+ all(checkpoint_utils._is_variable(v) for v in var)): # pylint: disable=protected-access
current_var_name = _infer_var_name(var)
elif isinstance(var, variables_lib.PartitionedVariable):
current_var_name = _infer_var_name([var])
@@ -193,9 +188,10 @@ def _warm_start_var_with_vocab(var,
prev_vocab_path):
raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
"current_vocab_size, prev_ckpt, prev_vocab_path}.")
- if _is_variable(var):
+ if checkpoint_utils._is_variable(var):
var = [var]
- elif isinstance(var, list) and all(_is_variable(v) for v in var):
+ elif (isinstance(var, list) and
+ all(checkpoint_utils._is_variable(v) for v in var)):
var = var
elif isinstance(var, variables_lib.PartitionedVariable):
var = var._get_variable_list()
@@ -271,7 +267,7 @@ def _get_grouped_variables(vars_to_warm_start):
for v in vars_to_warm_start:
list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope=v)
- elif all([_is_variable(v) for v in vars_to_warm_start]):
+ elif all([checkpoint_utils._is_variable(v) for v in vars_to_warm_start]): # pylint: disable=protected-access
list_of_vars = vars_to_warm_start
else:
raise ValueError("If `vars_to_warm_start` is a list, it must be all "