aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-29 15:28:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-29 15:32:01 -0700
commit6f5d7a97cd2c0741ddfa756853ce5321377b5d53 (patch)
treee79afd91cd68bc9ed75bfe278511312da3918fe6 /tensorflow/contrib/distribute
parent40f8291db5c0b05b31d7bbe23b847cdbb2408718 (diff)
Add tf.contrib.distribute, which defines classes DistributionStrategy
and MirroredStrategy, and related functionality. Also add tf.contrib.optimizer_v2, an update to the Optimizer API. RELNOTES: Can now pass tf.contrib.distribute.MirroredStrategy() to tf.estimator.RunConfig() to run an Estimator model on multiple GPUs on one machine. PiperOrigin-RevId: 190996247
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/BUILD36
-rw-r--r--tensorflow/contrib/distribute/__init__.py52
-rw-r--r--tensorflow/contrib/distribute/python/BUILD431
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py293
-rw-r--r--tensorflow/contrib/distribute/python/combinations_test.py115
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py410
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py185
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils.py153
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py279
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py486
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py435
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_test.py91
-rw-r--r--tensorflow/contrib/distribute/python/monitor.py61
-rw-r--r--tensorflow/contrib/distribute/python/monitor_test.py84
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py148
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy_test.py54
-rw-r--r--tensorflow/contrib/distribute/python/optimizer_v2_test.py70
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py167
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py68
-rw-r--r--tensorflow/contrib/distribute/python/shared_variable_creator.py97
-rw-r--r--tensorflow/contrib/distribute/python/shared_variable_creator_test.py75
-rw-r--r--tensorflow/contrib/distribute/python/simple_estimator_example.py97
-rw-r--r--tensorflow/contrib/distribute/python/single_loss_example.py102
-rw-r--r--tensorflow/contrib/distribute/python/step_fn.py103
-rw-r--r--tensorflow/contrib/distribute/python/step_fn_test.py62
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py225
-rw-r--r--tensorflow/contrib/distribute/python/values.py575
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py807
28 files changed, 5761 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
new file mode 100644
index 0000000000..74b2cd90a1
--- /dev/null
+++ b/tensorflow/contrib/distribute/BUILD
@@ -0,0 +1,36 @@
+# Implementation of a prototype TF distributed computation library.
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "distribute",
+ srcs = ["__init__.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/contrib/distribute/python:cross_tower_ops",
+ "//tensorflow/contrib/distribute/python:mirrored_strategy",
+ "//tensorflow/contrib/distribute/python:monitor",
+ "//tensorflow/contrib/distribute/python:one_device_strategy",
+ "//tensorflow/contrib/distribute/python:step_fn",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ ],
+)
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
new file mode 100644
index 0000000000..76711baf3a
--- /dev/null
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -0,0 +1,52 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Prototype of a distributed computation library for TF."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.distribute.python.cross_tower_ops import *
+from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
+from tensorflow.contrib.distribute.python.monitor import Monitor
+from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
+from tensorflow.contrib.distribute.python.step_fn import *
+from tensorflow.python.training.distribute import *
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+
+_allowed_symbols = [
+ 'AllReduceCrossTowerOps',
+ 'CrossTowerOps',
+ 'DistributionStrategy',
+ 'MirroredStrategy',
+ 'Monitor',
+ 'OneDeviceStrategy',
+ 'ReductionToOneDeviceCrossTowerOps',
+ 'Step',
+ 'StandardInputStep',
+ 'StandardSingleLossStep',
+ 'TowerContext',
+ 'get_cross_tower_context',
+ 'get_distribution_strategy',
+ 'get_loss_reduction',
+ 'get_tower_context',
+ 'has_distribution_strategy',
+ 'require_tower_context',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
new file mode 100644
index 0000000000..4dfd3f7228
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -0,0 +1,431 @@
+# Implementation of a prototype TF distributed computation library.
+
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+# TODO(priyag): Figure out testonly issues that are preventing us from
+# including our tests in pip for now.
+
+py_library(
+ name = "values",
+ srcs = ["values.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":prefetching_ops_v2",
+ "//tensorflow/contrib/data/python/ops:transformation_ops",
+ "//tensorflow/contrib/eager/python:datasets",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:checkpointable",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python/eager:context",
+ "@six_archive//:six",
+ ],
+)
+
+cuda_py_test(
+ name = "values_test",
+ srcs = ["values_test.py"],
+ additional_deps = [
+ ":mirrored_strategy",
+ ":values",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+py_library(
+ name = "mirrored_strategy",
+ srcs = ["mirrored_strategy.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":cross_tower_ops",
+ ":shared_variable_creator",
+ ":values",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:device",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:tape",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "one_device_strategy",
+ srcs = ["one_device_strategy.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":values",
+ "//tensorflow/contrib/eager/python:datasets",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "strategy_test_lib",
+ testonly = 1,
+ srcs = ["strategy_test_lib.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:backprop",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ ],
+)
+
+py_library(
+ name = "combinations",
+ testonly = 1,
+ srcs = ["combinations.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ ":mirrored_strategy",
+ ":one_device_strategy",
+ "//tensorflow/contrib/optimizer_v2:training",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python/eager:context",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "combinations_test",
+ srcs = ["combinations_test.py"],
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ ":combinations",
+ "//tensorflow/python/eager:test",
+ ],
+)
+
+py_test(
+ name = "mirrored_strategy_test",
+ srcs = ["mirrored_strategy_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ ":mirrored_strategy",
+ ":strategy_test_lib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ ],
+)
+
+py_test(
+ name = "one_device_strategy_test",
+ srcs = ["one_device_strategy_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ ":one_device_strategy",
+ ":strategy_test_lib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/eager:test",
+ ],
+)
+
+cuda_py_test(
+ name = "mirrored_strategy_multigpu_test",
+ srcs = ["mirrored_strategy_multigpu_test.py"],
+ additional_deps = [
+ ":mirrored_strategy",
+ ":values",
+ ":strategy_test_lib",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ ],
+ tags = [
+ "guitar",
+ "no_pip",
+ "multi_and_single_gpu",
+ # Do not perform the extra analysis on this test, because it is already
+ # performed for the `:mirrored_strategy_test` target.
+ "no_oss",
+ "noasan",
+ "notap",
+ "notsan",
+ ],
+)
+
+py_library(
+ name = "step_fn",
+ srcs = ["step_fn.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:backprop",
+ ],
+)
+
+cuda_py_test(
+ name = "minimize_loss_test",
+ srcs = ["minimize_loss_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":single_loss_example",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/ops/losses",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
+ ],
+)
+
+cuda_py_test(
+ name = "optimizer_v2_test",
+ srcs = ["optimizer_v2_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":single_loss_example",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
+ ],
+)
+
+py_library(
+ name = "single_loss_example",
+ srcs = ["single_loss_example.py"],
+ deps = [
+ ":step_fn",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+cuda_py_test(
+ name = "step_fn_test",
+ srcs = ["step_fn_test.py"],
+ additional_deps = [
+ ":single_loss_example",
+ ":combinations",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
+ ],
+)
+
+py_library(
+ name = "monitor",
+ srcs = ["monitor.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "monitor_test",
+ srcs = ["monitor_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":monitor",
+ ":one_device_strategy",
+ ":single_loss_example",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
+ ],
+)
+
+py_library(
+ name = "shared_variable_creator",
+ srcs = ["shared_variable_creator.py"],
+ visibility = ["//tensorflow:internal"],
+)
+
+py_test(
+ name = "shared_variable_creator_test",
+ srcs = ["shared_variable_creator_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":shared_variable_creator",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/eager:test",
+ ],
+)
+
+py_binary(
+ name = "simple_estimator_example",
+ srcs = ["simple_estimator_example.py"],
+ deps = [
+ ":mirrored_strategy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+py_library(
+ name = "cross_tower_utils",
+ srcs = ["cross_tower_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/nccl:nccl_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+py_library(
+ name = "cross_tower_ops",
+ srcs = ["cross_tower_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cross_tower_utils",
+ ":values",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "cross_tower_ops_test",
+ srcs = ["cross_tower_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ ":combinations",
+ ":cross_tower_ops",
+ ":values",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_library(
+ name = "prefetching_ops_v2",
+ srcs = ["prefetching_ops_v2.py"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:contrib_op_loader",
+ "//tensorflow/contrib/data/python/ops:prefetching_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetching_ops_v2_test",
+ srcs = ["prefetching_ops_v2_test.py"],
+ additional_deps = [
+ ":prefetching_ops_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
new file mode 100644
index 0000000000..dd8e7c4376
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Facilities for creating multiple test combinations.
+
+Here is an example of testing various optimizers in Eager and Graph mode:
+
+class AdditionExample(test.TestCase, parameterized.TestCase):
+ @combinations.generate(
+ combinations.combine(mode=["graph", "eager"],
+ optimizer=[AdamOptimizer(),
+ GradientDescentOptimizer()]))
+ def testOptimizer(self, optimizer):
+ ... f(optimizer)...
+
+This will run `testOptimizer` 4 times with the specified optimizers: 2 in
+Eager and 2 in Graph mode.
+The test will be provided with arguments that match the arguments of combine
+by name. It is necessary to request all arguments, except for `mode`, which is
+optional.
+
+`combine()` function is available for creating a cross product of various
+options. `times()` function exists for creating a product of N `combine()`-ed
+results. See below.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import OrderedDict
+import sys
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import one_device_strategy
+from tensorflow.contrib.optimizer_v2 import adam as adam_v2
+from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.training import adam
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.util import tf_inspect
+
+
+GPU_TEST = "test_gpu" in sys.argv[0]
+
+
+def generate(combinations):
+ """A decorator for generating test cases of a test method or a test class.
+
+ Args:
+ combinations: a list of dictionaries created using combine() and times().
+
+ Restrictions:
+ -- there should always be a "mode" argument. Accepted values are "eager"
+ and "graph".
+ -- arguments of the test method must match by name to get the corresponding
+ value of the combination. Tests must accept all arguments (except "mode",
+ which is optional).
+ -- distribution argument is special. It is meant for passing instances of
+ DistributionStrategy. Each instance is to be passed as `(<int>,
+ <DistributionStrategy>)` tuple, where <int> is the number of required
+ GPUs. If the required number of GPUs for the DistributionStrategy isn't
+ available then the test case is going to be skipped.
+
+ Returns:
+ a decorator that will cause the test method to be run under the specified
+ conditions.
+
+ Raises:
+ ValueError - if "mode" argument wasn't either "eager" or "graph".
+ """
+
+ def decorator(test_function):
+ """The decorator to be returned."""
+
+ # Generate good test names that can be used with --test_filter.
+ for combination in combinations:
+ # We use OrderedDicts in `combine()` and `times()` to ensure stable
+ # order of keys in each dictionary.
+ assert isinstance(combination, OrderedDict)
+ name = "".join([
+ "_{}_{}".format(
+ "".join(filter(str.isalnum, key)),
+ "".join(filter(str.isalnum, str(value))))
+ for key, value in combination.items()
+ ])
+ combination.update({"testcase_name": "_test{}".format(name)})
+
+ @parameterized.named_parameters(*combinations)
+ def decorated(self, **kwargs):
+ """A wrapped test method that sets up `test_function`."""
+ assert "mode" in kwargs
+ mode = kwargs["mode"]
+
+ if "distribution" in kwargs:
+ distribution = kwargs["distribution"]
+ kwargs["distribution"] = distribution.strategy
+ if not distribution.required_gpus:
+ if GPU_TEST:
+ self.skipTest("Test that doesn't require GPUs.")
+ elif context.num_gpus() < distribution.required_gpus:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(distribution.required_gpus, context.num_gpus()))
+
+ requested_arguments = tf_inspect.getfullargspec(test_function).args
+ missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
+ set(requested_arguments + ["mode"]))
+ if missing_arguments:
+ raise ValueError("The test is missing arguments {} .".format(
+ missing_arguments))
+
+ kwargs_to_pass = {}
+ for arg in requested_arguments:
+ if arg == "self":
+ kwargs_to_pass[arg] = self
+ else:
+ kwargs_to_pass[arg] = kwargs[arg]
+
+ if mode == "eager":
+ with context.eager_mode(), ops.Graph().as_default():
+ test_function(**kwargs_to_pass)
+ elif mode == "graph":
+ with context.graph_mode(), ops.Graph().as_default():
+ test_function(**kwargs_to_pass)
+ else:
+ raise ValueError(
+ "'mode' has to be either 'eager' or 'graph' and not {}".format(
+ mode))
+
+ return decorated
+ return decorator
+
+
+def combine(**kwargs):
+ """Generate combinations based on its keyword arguments.
+
+ Two sets of returned combinations can be concatenated using +. Their product
+ can be computed using `times()`.
+
+ Args:
+ **kwargs: keyword arguments of form `option=[possibilities, ...]`.
+
+ Returns:
+ a list of dictionaries for each combination. Keys in the dictionaries are
+ the keyword argument names. Each key has one value - one of the
+ corresponding keyword argument values.
+ """
+ if not kwargs:
+ return [OrderedDict()]
+
+ sort_by_key = lambda k: k[0][0]
+ kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key))
+ first = list(kwargs.items())[0]
+
+ rest = dict(list(kwargs.items())[1:])
+ rest_combined = combine(**rest)
+
+ key = first[0]
+ values = first[1]
+
+ return [
+ OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key))
+ for v in values
+ for combined in rest_combined
+ ]
+
+
+def times(*combined):
+ """Generate a product of N sets of combinations.
+
+ times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4])
+
+ Args:
+ *combined: N lists of dictionaries that specify combinations.
+
+ Returns:
+ a list of dictionaries for each combination.
+
+ Raises:
+ ValueError: if some of the inputs have overlapping keys.
+ """
+ assert combined
+
+ if len(combined) == 1:
+ return combined[0]
+
+ first = combined[0]
+ rest_combined = times(*combined[1:])
+
+ combined_results = []
+ for a in first:
+ for b in rest_combined:
+ if set(a.keys()).intersection(set(b.keys())):
+ raise ValueError("Keys need to not overlap: {} vs {}".format(
+ a.keys(), b.keys()))
+
+ combined_results.append(OrderedDict(list(a.items()) + list(b.items())))
+ return combined_results
+
+
+class NamedObject(object):
+ """A class that translates an object into a good test name."""
+
+ def __init__(self, name, obj):
+ self._name = name
+ self._obj = obj
+
+ def __getattr__(self, name):
+ return getattr(self._obj, name)
+
+ def __call__(self, *args, **kwargs):
+ return self._obj(*args, **kwargs)
+
+ def __repr__(self):
+ return self._name
+
+
+class NamedDistribution(object):
+ """Translates DistributionStrategy and its data into a good name."""
+
+ def __init__(self, name, distribution, required_gpus):
+ self._distribution = distribution
+ self._name = name
+ self._required_gpus = required_gpus
+
+ def __repr__(self):
+ return self._name
+
+ @property
+ def strategy(self):
+ return self._distribution
+
+ @property
+ def required_gpus(self):
+ return self._required_gpus
+
+
+one_device_strategy = NamedDistribution(
+ "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
+ None)
+mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
+ "MirroredCPUAndGPU",
+ mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1)
+mirrored_strategy_with_two_gpus = NamedDistribution(
+ "Mirrored2GPUs",
+ mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2)
+
+adam_optimizer_v1_fn = NamedObject(
+ "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
+gradient_descent_optimizer_v1_fn = NamedObject(
+ "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
+
+adam_optimizer_v2_fn = NamedObject(
+ "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1))
+gradient_descent_optimizer_v2_fn = NamedObject(
+ "GradientDescentV2",
+ lambda: gradient_descent_v2.GradientDescentOptimizer(0.2))
+
+graph_and_eager_modes = ["graph", "eager"]
+
+
+def distributions_and_v1_optimizers():
+ """A common set of combination with DistributionStrategies and Optimizers."""
+ return combine(
+ distribution=[
+ one_device_strategy, mirrored_strategy_with_gpu_and_cpu,
+ mirrored_strategy_with_two_gpus
+ ],
+ optimizer_fn=[adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn])
+
+
+def distributions_and_v2_optimizers():
+ """DistributionStrategies and V2 Optimizers."""
+ return combine(
+ distribution=[
+ one_device_strategy, mirrored_strategy_with_gpu_and_cpu,
+ mirrored_strategy_with_two_gpus
+ ],
+ optimizer_fn=[adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn])
diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py
new file mode 100644
index 0000000000..219b24160f
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/combinations_test.py
@@ -0,0 +1,115 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the test-combination utilities in combinations.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import OrderedDict
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.eager import test
+
+
+class TestingCombinationsTest(test.TestCase):
+
+ def test_combine(self):
+ self.assertEqual([{
+ "a": 1,
+ "b": 2
+ }, {
+ "a": 1,
+ "b": 3
+ }, {
+ "a": 2,
+ "b": 2
+ }, {
+ "a": 2,
+ "b": 3
+ }], combinations.combine(a=[1, 2], b=[2, 3]))
+
+ def test_add(self):
+ self.assertEqual(
+ [{
+ "a": 1
+ }, {
+ "a": 2
+ }, {
+ "b": 2
+ }, {
+ "b": 3
+ }],
+ combinations.combine(a=[1, 2]) +
+ combinations.combine(b=[2, 3]))
+
+ def test_times(self):
+ c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
+ c2 = combinations.combine(mode=["eager"], loss=["callable"])
+ c3 = combinations.combine(distribution=["d1", "d2"])
+ c4 = combinations.times(c3, c1 + c2)
+ self.assertEqual([
+ OrderedDict([("distribution", "d1"), ("loss", "callable"),
+ ("mode", "graph")]),
+ OrderedDict([("distribution", "d1"), ("loss", "tensor"),
+ ("mode", "graph")]),
+ OrderedDict([("distribution", "d1"), ("loss", "callable"),
+ ("mode", "eager")]),
+ OrderedDict([("distribution", "d2"), ("loss", "callable"),
+ ("mode", "graph")]),
+ OrderedDict([("distribution", "d2"), ("loss", "tensor"),
+ ("mode", "graph")]),
+ OrderedDict([("distribution", "d2"), ("loss", "callable"),
+ ("mode", "eager")])
+ ], c4)
+
+ def test_times_variable_arguments(self):
+ c1 = combinations.combine(mode=["graph", "eager"])
+ c2 = combinations.combine(optimizer=["adam", "gd"])
+ c3 = combinations.combine(distribution=["d1", "d2"])
+ c4 = combinations.times(c3, c1, c2)
+ self.assertEqual([
+ OrderedDict([("distribution", "d1"), ("mode", "graph"),
+ ("optimizer", "adam")]),
+ OrderedDict([("distribution", "d1"), ("mode", "graph"),
+ ("optimizer", "gd")]),
+ OrderedDict([("distribution", "d1"), ("mode", "eager"),
+ ("optimizer", "adam")]),
+ OrderedDict([("distribution", "d1"), ("mode", "eager"),
+ ("optimizer", "gd")]),
+ OrderedDict([("distribution", "d2"), ("mode", "graph"),
+ ("optimizer", "adam")]),
+ OrderedDict([("distribution", "d2"), ("mode", "graph"),
+ ("optimizer", "gd")]),
+ OrderedDict([("distribution", "d2"), ("mode", "eager"),
+ ("optimizer", "adam")]),
+ OrderedDict([("distribution", "d2"), ("mode", "eager"),
+ ("optimizer", "gd")])
+ ], c4)
+ self.assertEqual(
+ combinations.combine(
+ mode=["graph", "eager"],
+ optimizer=["adam", "gd"],
+ distribution=["d1", "d2"]), c4)
+
+ def test_overlapping_keys(self):
+ c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
+ c2 = combinations.combine(mode=["eager"], loss=["callable"])
+ with self.assertRaisesRegexp(ValueError, ".*Keys.+overlap.+"):
+ _ = combinations.times(c1, c2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
new file mode 100644
index 0000000000..cb98351735
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -0,0 +1,410 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes for different algorithms of reduction and broadcasting."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import device_util
+
+
+def _validate_destinations(destinations):
+ if not isinstance(destinations,
+ (value_lib.DistributedValues, six.string_types, list)):
+ raise ValueError("destinations must be one of a `DistributedValues` object,"
+ " a device string, a list of device strings or None")
+
+ if not destinations:
+ raise ValueError("destinations can not be empty")
+
+
+def _validate_value_destination_pairs(value_destination_pairs):
+ # pylint: disable=g-missing-docstring
+ if not value_destination_pairs: return False
+ if not isinstance(value_destination_pairs, (list, tuple)): return False
+ if not all([isinstance(pair, tuple) for pair in value_destination_pairs]):
+ return False
+ if not all([isinstance(v[0], value_lib.PerDevice)
+ for v in value_destination_pairs]):
+ return False
+ return True
+
+
+def _get_devices_from(destinations):
+ if isinstance(destinations, value_lib.DistributedValues):
+ return list(destinations.devices)
+ elif isinstance(destinations, six.string_types):
+ return [device_util.canonicalize(destinations)]
+ else:
+ return [
+ device_util.canonicalize(destination) for destination in destinations
+ ]
+
+
+def _devices_match(left, right):
+ return set(_get_devices_from(left)) == set(_get_devices_from(right))
+
+
+def _all_devices_match(value_destination_pairs):
+ if not all([d is None or _devices_match(v, d)
+ for v, d in value_destination_pairs]):
+ return False
+ if not all([_devices_match(v, value_destination_pairs[0][0])
+ for v, _ in value_destination_pairs[1:]]):
+ return False
+ return True
+
+
+def _simple_broadcast(tensor, destinations):
+ index = {}
+ devices = _get_devices_from(destinations)
+ for d in devices:
+ with ops.device(d):
+ index[d] = array_ops.identity(tensor)
+ return value_lib.Mirrored(index)
+
+
+def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
+ method_string):
+ # pylint: disable=g-missing-docstring
+ all_values = []
+ count = 0
+ for v in per_device_value._index.values(): # pylint: disable=protected-access
+ if isinstance(v, value_lib.MapOutput):
+ v_list = v.get()
+ if not v_list:
+ continue
+ count += len(v_list)
+ # Sum within each device before aggregating across devices.
+ v = math_ops.add_n(v_list)
+ else:
+ count += 1
+ all_values.append(v)
+ if not all_values:
+ raise ValueError("`per_device_value` must be non-empty")
+
+ with ops.device(reduce_to_device):
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ if method_string == "sum":
+ reduced = accumulation_fn(all_values)
+ elif method_string == "mean":
+ reduced = accumulation_fn(all_values) / count
+ else:
+ raise ValueError("`method_string` must be 'sum' or 'mean'")
+ return reduced
+
+
+class CrossTowerOps(object):
+  """Base class for cross-tower reduction and broadcasting algorithms.
+
+  Subclasses implement `_reduce` and `_batch_reduce`; `_broadcast` has a
+  default implementation that copies the tensor to every destination device.
+  """
+
+  def __init__(self):
+    pass
+
+  def reduce(self, method_string, per_device_value, destinations=None):
+    """Reduce `per_device_value` to `destinations`.
+
+    It runs the reduction operation defined by `method_string` and put the
+    result on `destinations`.
+
+    Args:
+      method_string: either 'sum' or 'mean' specifying the reduction method.
+      per_device_value: a PerDevice object.
+      destinations: the reduction destinations. If None, the value's own
+        devices are used.
+
+    Returns:
+      a Mirrored object.
+
+    Raises:
+      ValueError: if per_device_value is not a PerDevice object.
+    """
+    if not isinstance(per_device_value, value_lib.PerDevice):
+      raise ValueError("`per_device_value` must be a `PerDevice` object.")
+    # None destinations are allowed; subclasses then reduce onto the value's
+    # own devices.
+    if destinations is not None:
+      _validate_destinations(destinations)
+    return self._reduce(method_string, per_device_value, destinations)
+
+  def batch_reduce(self, method_string, value_destination_pairs):
+    """Reduce PerDevice objects in a batch.
+
+    Reduce each first element in `value_destination_pairs` to each second
+    element which indicates the destinations.
+
+    Args:
+      method_string: either 'sum' or 'mean' specifying the reduction method.
+      value_destination_pairs: a list or a tuple of tuples of PerDevice objects
+        and destinations. If a destination is None, then the destinations
+        are set to match the devices of the input PerDevice object.
+
+    Returns:
+      a list of Mirrored objects.
+
+    Raises:
+      ValueError: if `value_destination_pairs` is not a list or a tuple of
+        tuples of PerDevice objects and destinations
+    """
+    if not _validate_value_destination_pairs(value_destination_pairs):
+      raise ValueError("`value_destination_pairs` must be a list or a tuple of "
+                       "tuples of PerDevice objects and destinations")
+    for _, d in value_destination_pairs:
+      if d is not None:
+        _validate_destinations(d)
+
+    return self._batch_reduce(method_string, value_destination_pairs)
+
+  def broadcast(self, tensor, destinations):
+    """Broadcast the `tensor` to destinations.
+
+    Args:
+      tensor: the tensor to broadcast.
+      destinations: the broadcast destinations.
+
+    Returns:
+      a Mirrored object.
+    """
+    _validate_destinations(destinations)
+    return self._broadcast(tensor, destinations)
+
+  def _reduce(self, method_string, per_device_value, destinations):
+    # Subclass hook: inputs are pre-validated by reduce().
+    raise NotImplementedError(
+        "_reduce method must be implemented in descendants.")
+
+  def _batch_reduce(self, method_string, value_destination_pairs):
+    # Subclass hook: inputs are pre-validated by batch_reduce().
+    raise NotImplementedError(
+        "_batch_reduce method must be implemented in descendants.")
+
+  def _broadcast(self, tensor, destinations):
+    # Default broadcast: one identity copy per destination device.
+    return _simple_broadcast(tensor, destinations)
+
+
+class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
+ """Always do reduction to one device first and then do broadcasting.
+
+ Batch reduction is done by reduction on each element one by one.
+ """
+
+ def __init__(self, reduce_to_device=None, accumulation_fn=math_ops.add_n):
+ """Constructor.
+
+ Args:
+ reduce_to_device: the intermediate device to reduce to. If None, reduce
+ to the first device in `destinations` of the reduce() method.
+ accumulation_fn: a function that does accumulation.
+ """
+ self.reduce_to_device = reduce_to_device
+ self.accumulation_fn = accumulation_fn
+ super(ReductionToOneDeviceCrossTowerOps, self).__init__()
+
+ def _reduce(self, method_string, per_device_value, destinations):
+ devices = _get_devices_from(destinations or per_device_value)
+ reduce_to_device = self.reduce_to_device or devices[0]
+ reduced = _simple_reduce(per_device_value, reduce_to_device,
+ self.accumulation_fn, method_string)
+ return self.broadcast(reduced, devices)
+
+ def _batch_reduce(self, method_string, value_destination_pairs):
+ return [self._reduce(method_string, t, destinations=v)
+ for t, v in value_destination_pairs]
+
+
+def _group_value_by_device(per_device_values):
+ """Group values into sublists by their devices.
+
+ This grouping is needed to call the allreduce library.
+
+ Args:
+ per_device_values: a list of PerDevice obejcts.
+
+ Returns:
+ a list of lists, each sublist has components for its corresponding device of
+ PerDevice objects, paired with a None.
+ """
+ destinations = per_device_values[0].devices
+ grouped = [[] for _ in range(len(destinations))]
+ for per_device_value in per_device_values:
+ # pylint: disable=protected-access
+ for i, v in enumerate(per_device_value._index.values()):
+ assert per_device_value.devices == destinations
+ grouped[i].append((v, None))
+ return grouped
+
+
+def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
+ """Ungroup results from allreduce and make Mirrored objects.
+
+ Each allreduce result would be divided by the number of destinations before
+ Mirrored objects are created if method_string is "mean".
+ """
+ index = [{} for _ in range(len(grouped_reduced[0]))]
+ for d, per_device_reduced in enumerate(grouped_reduced):
+ for i, (v, _) in enumerate(per_device_reduced):
+ if method_string == "mean":
+ index[i][destinations[d]] = v / len(destinations)
+ else:
+ index[i][destinations[d]] = v
+ return [value_lib.Mirrored(v) for v in index]
+
+
+class AllReduceCrossTowerOps(CrossTowerOps):
+  """Reduction using all reduce.
+
+  Batched all-reduce is used when the destinations match the input devices
+  and the code runs in graph mode; otherwise falls back to a simple
+  reduce-then-broadcast.
+  """
+
+  def __init__(self, all_reduce_alg="nccl", gradient_repacking=1):
+    """Initialize this subclass of CrossTowerOps with allreduce.
+
+    Gradients would be repacked for more efficient cross-device transportation.
+
+    Args:
+      all_reduce_alg: the allreduce algorithm to use, currently only "nccl" or
+        "hierarchical_copy" are supported.
+      gradient_repacking: If zero, no gradient repacking would be done. If
+        non-zero value it specifies the number of split packs that will be
+        formed.
+    """
+    self.all_reduce_alg = all_reduce_alg
+    self.gradient_repacking = gradient_repacking
+    super(AllReduceCrossTowerOps, self).__init__()
+
+  def _reduce(self, method_string, per_device_value, destinations):
+    # Batched all-reduce only applies in graph mode with matching devices;
+    # otherwise reduce onto the first destination device and broadcast.
+    if ((destinations is None or _devices_match(per_device_value, destinations))
+        and not context.executing_eagerly()):
+      return self._batch_all_reduce(method_string, [per_device_value])[0]
+    else:
+      devices = _get_devices_from(destinations or per_device_value)
+      reduce_to_device = devices[0]
+      reduced = _simple_reduce(per_device_value, reduce_to_device,
+                               math_ops.add_n, method_string)
+      return self.broadcast(reduced, devices)
+
+  def _batch_reduce(self, method_string, value_destination_pairs):
+    if (_all_devices_match(value_destination_pairs) and
+        not context.executing_eagerly()):
+      return self._batch_all_reduce(method_string,
+                                    [v[0] for v in value_destination_pairs])
+    else:
+      if not context.executing_eagerly():
+        logging.warning("Efficient batch_reduce is not supported if "
+                        "destinations are different.")
+      # Fall back to per-pair reduction.
+      return [
+          self._reduce(method_string, t, destinations=v)
+          for t, v in value_destination_pairs
+      ]
+
+  def _batch_all_reduce(self, method_string, per_device_values):
+    """All reduce algorithm in a batch.
+
+    When `gradient_repacking` is non-zero, each tower's gradients are
+    flattened, concatenated, and split into that many packs before the
+    all-reduce, then unpacked back into their original shapes afterwards.
+    """
+    logging.info("batch_all_reduce invoked for batches size = %d with algorithm"
+                 " = %s and gradient repacking = %d", len(per_device_values),
+                 self.all_reduce_alg, self.gradient_repacking)
+    destinations = per_device_values[0].devices
+    grouped = _group_value_by_device(per_device_values)
+    if self.gradient_repacking == 0:
+      # No repacking: all-reduce each (grad, None) group directly.
+      if self.all_reduce_alg == "nccl":
+        reduced = cross_tower_utils.aggregate_gradients_using_nccl(grouped)
+      else:
+        # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
+        # order.
+        reduced = (
+            cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
+                destinations, grouped))
+    else:
+      device_grad_packs = []
+      all_tower_shapes = []
+      all_tower_sizes = []
+      for tower_grads_and_vars in grouped:
+        with ops.colocate_with(tower_grads_and_vars[0][0]):
+          # Flatten all the grads.
+          flat_grads = [
+              array_ops.reshape(g, [-1]) for g, _ in tower_grads_and_vars
+          ]
+          # Remember the original shape of all the grads.
+          tower_shapes = [array_ops.shape(g) for g, _ in tower_grads_and_vars]
+          # Remember the original sizes of all the grads.
+          tower_sizes = [array_ops.size(g) for g, _ in tower_grads_and_vars]
+          # Concat all the flat grads into a big flat tensor.
+          concat_grads = array_ops.concat(flat_grads, 0)
+
+          # Split the big tensor into num_splits packs. In cases where the
+          # total size is not divisible num_splits, the last pack gets
+          # more elements.
+          # TODO(zhengxq): it is possible to optimize away the additional
+          # data movement by copying along the original variable boundary.
+          # TODO(zhengxq): it is also possible to optimize away all the concat
+          # as well.
+          num_splits = self.gradient_repacking
+          total_grad_size = array_ops.size(concat_grads)
+          split_size = total_grad_size // num_splits
+          split_size_last = total_grad_size - split_size * (num_splits - 1)
+          split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
+          grad_packs = array_ops.split(concat_grads, split_sizes)
+
+          # Ready to aggregate the repacked gradients, with fake variables.
+          # TODO(zhengxq): It is hacky to have to use fake variables.
+          # We should remove the need for variables in
+          # aggregate_gradients_using*.
+          # NOTE(review): zip() returns a one-shot iterator on Python 3; this
+          # relies on each pack list being consumed exactly once downstream —
+          # confirm if this file must support Python 3.
+          device_grad_packs.append(zip(grad_packs, [None] * num_splits))
+          all_tower_shapes.append(tower_shapes)
+          all_tower_sizes.append(tower_sizes)
+
+      # The actual aggregation of the repacked gradients. Note that they are
+      # sharded among different aggregation trees. So it is important to
+      # strike the balance on num_splits.
+      if self.all_reduce_alg == "nccl":
+        summed_device_grad_packs = (
+            cross_tower_utils.aggregate_gradients_using_nccl(device_grad_packs))
+      else:
+        summed_device_grad_packs = (
+            cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
+                destinations, device_grad_packs))
+
+      aggregated_device_grads = []
+      for (summed_tower_grad_packs, tower_grads_and_vars, tower_shapes,
+           tower_sizes) in zip(summed_device_grad_packs, grouped,
+                               all_tower_shapes, all_tower_sizes):
+        # pylint: enable=line-too-long
+        # Reverse the packing operations in the previous steps. Form the
+        # summed gradients back into their original shapes.
+        with ops.colocate_with(summed_tower_grad_packs[0][0]):
+          # Form a list of the summed grad packs.
+          device_grad_packs = [g for g, _ in summed_tower_grad_packs]
+
+          # Concat them back into a big flat tensor.
+          device_grads_concat = array_ops.concat(device_grad_packs, 0)
+
+          # Split the tensors back into their original sizes.
+          grads_with_sizes = array_ops.split(device_grads_concat, tower_sizes)
+
+          # Reshape the tensors back into their original shapes.
+          grads_with_shapes = [
+              array_ops.reshape(grad, shape)
+              for shape, grad in zip(tower_shapes, grads_with_sizes)
+          ]
+
+          # Form the list with the original list of variables.
+          summed_tower_grads = [
+              (g, v)
+              for g, (_, v) in zip(grads_with_shapes, tower_grads_and_vars)
+          ]
+          aggregated_device_grads.append(summed_tower_grads)
+      reduced = aggregated_device_grads
+    return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
+                                      method_string)
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
new file mode 100644
index 0000000000..bb43147f5e
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -0,0 +1,185 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for CrossTowerOps."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def _make_per_device(values, devices):
+ devices = cross_tower_ops_lib._get_devices_from(devices)
+ assert len(values) == len(devices)
+ index = {}
+ for d, v in zip(devices, values):
+ with ops.device(d):
+ placed_v = array_ops.identity(v)
+ index[d] = placed_v
+ return value_lib.PerDevice(index)
+
+
+# pylint: disable=g-doc-args,g-doc-return-or-yield
+def _fake_mirrored(value, devices):
+ """Create a faked Mirrored object for testing.
+
+ All components of the returned Mirrored have the same objects, which is not
+ true in reality.
+ """
+ devices = cross_tower_ops_lib._get_devices_from(devices)
+ return value_lib.Mirrored(
+ {d: v for d, v in zip(devices, [value] * len(devices))})
+
+
+# Device string used as an explicit "reduce to CPU" destination in the tests.
+_cpu_device = "/device:CPU:0"
+
+
+class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
+  """Parameterized tests for all CrossTowerOps implementations."""
+
+  def _assert_value_equal(self, left, right):
+    # Recursively compare lists, then DistributedValues component-by-component.
+    if isinstance(left, list):
+      for l, r in zip(left, right):
+        self._assert_value_equal(l, r)
+    else:
+      self.assertEqual(type(left), type(right))
+      self.assertEqual(left.devices, right.devices)
+      if context.executing_eagerly():
+        # Eager tensors can be read directly via numpy().
+        self.assertEqual([v.numpy() for v in left._index.values()],
+                         list(right._index.values()))
+      else:
+        # In graph mode the components must be evaluated in a session first.
+        with self.test_session() as sess:
+          self.assertEqual(
+              sess.run(list(left._index.values())), list(right._index.values()))
+
+  # TODO(yuefengz): decouple the num_gpus check from distribution in
+  # combinations module so that we can pass in devices instead of a distribution
+  # strategy.
+  reduction_to_one_combinations = combinations.combine(
+      cross_tower_ops=[
+          combinations.NamedObject(
+              "DefaultReductionToOneDeviceCrossTowerOps",
+              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+          combinations.NamedObject(
+              "ReductionToCPUDeviceCrossTowerOps",
+              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
+                  reduce_to_device=_cpu_device)),
+          combinations.NamedObject(
+              "AccumulateNCrossTowerOp",
+              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
+                  accumulation_fn=math_ops.accumulate_n)),
+      ],
+      distribution=[
+          combinations.one_device_strategy,
+          combinations.mirrored_strategy_with_gpu_and_cpu,
+          combinations.mirrored_strategy_with_two_gpus
+      ],
+      mode=["graph", "eager"])
+  # All-reduce needs multiple identical devices, so only the two-GPU
+  # mirrored strategy is exercised here.
+  allreduce_combinations = combinations.combine(
+      cross_tower_ops=[
+          combinations.NamedObject("AllReduce",
+                                   cross_tower_ops_lib.AllReduceCrossTowerOps(
+                                       "nccl", 1)),
+          combinations.NamedObject("HierarchicalCopy",
+                                   cross_tower_ops_lib.AllReduceCrossTowerOps(
+                                       "hierarchical_copy", 8)),
+          combinations.NamedObject("AllReduceNoGradientRepacking",
+                                   cross_tower_ops_lib.AllReduceCrossTowerOps(
+                                       "nccl", 0)),
+          combinations.NamedObject("HierarchicalCopyNoGradientRepacking",
+                                   cross_tower_ops_lib.AllReduceCrossTowerOps(
+                                       "hierarchical_copy", 0))
+      ],
+      distribution=[
+          combinations.mirrored_strategy_with_two_gpus
+      ],
+      mode=["graph", "eager"])
+
+  @combinations.generate(reduction_to_one_combinations + allreduce_combinations)
+  def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+    devices = distribution.worker_devices
+
+    # Values 0..n-1 across devices; their mean is (n-1)/2.
+    values = [constant_op.constant(float(d)) for d in range(len(devices))]
+    per_device = _make_per_device(values, devices)
+    mean = (len(devices) - 1.) / 2.
+
+    values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))]
+    per_device_2 = _make_per_device(values_2, devices)
+    mean_2 = mean + 1.
+
+    # Exercise every accepted destination type, including None.
+    destination_mirrored = _fake_mirrored(1., devices)
+    destination_different = _fake_mirrored(1., _cpu_device)
+    destination_str = _cpu_device
+    destination_list = devices
+
+    all_destinations = [
+        None, destination_mirrored, destination_different, destination_str,
+        destination_list
+    ]
+
+    # test reduce()
+    for destinations in all_destinations:
+      self._assert_value_equal(
+          cross_tower_ops.reduce("mean", per_device, destinations=destinations),
+          _fake_mirrored(mean, destinations or per_device))
+      self._assert_value_equal(
+          cross_tower_ops.reduce(
+              "mean", per_device_2, destinations=destinations),
+          _fake_mirrored(mean_2, destinations or per_device))
+      self._assert_value_equal(
+          cross_tower_ops.reduce("sum", per_device, destinations=destinations),
+          _fake_mirrored(mean * len(devices), destinations or per_device))
+      self._assert_value_equal(
+          cross_tower_ops.reduce(
+              "sum", per_device_2, destinations=destinations),
+          _fake_mirrored(mean_2 * len(devices), destinations or per_device))
+
+    # test batch_reduce()
+    for d1, d2 in itertools.product(all_destinations, all_destinations):
+      self._assert_value_equal(
+          cross_tower_ops.batch_reduce(
+              "mean", [(per_device, d1), (per_device_2, d2)]),
+          [_fake_mirrored(mean, d1 or per_device),
+           _fake_mirrored(mean_2, d2 or per_device_2)])
+      self._assert_value_equal(
+          cross_tower_ops.batch_reduce(
+              "sum", [(per_device, d1), (per_device_2, d2)]),
+          [_fake_mirrored(mean * len(devices), d1 or per_device),
+           _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)])
+
+    # test broadcast()
+    for destinations in all_destinations:
+      if destinations is None:
+        # broadcast() requires explicit destinations.
+        continue
+      else:
+        self._assert_value_equal(
+            cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
+            _fake_mirrored(1., destinations))
+
+
+# Standard TF test entry point.
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
new file mode 100644
index 0000000000..93acd835d7
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -0,0 +1,153 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for cross_tower_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import nccl
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def aggregate_gradients_using_nccl(tower_grads):
+ """Aggregate gradients using nccl allreduce."""
+ agg_all_g_and_v = []
+ for single_g_and_v in zip(*tower_grads):
+ single_grads = [g for g, _ in single_g_and_v]
+ agg_grads = nccl.all_sum(single_grads)
+ agg_all_g_and_v.append(
+ [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])
+
+ agg_all_g_and_v = list(zip(*agg_all_g_and_v))
+
+ return agg_all_g_and_v
+
+
+def aggregate_gradients_using_hierarchical_copy(avail_devices, tower_grads):
+  """Aggregate gradients using hierarchical copies.
+
+  Args:
+    avail_devices: available GPU devices.
+    tower_grads: List of lists of (gradient, variable) tuples. The outer list
+      is over towers. The inner list is over individual gradients.
+
+  Returns:
+    The list of (aggregated_gradient, variable), where the gradient has been
+    summed across all towers and the variable is chosen from the first tower.
+  """
+  # This only works for DGX-1 type of machine topology
+  # Device peer to peer matrix
+  # DMA: 0 1 2 3 4 5 6 7
+  # 0:   Y Y Y Y Y N N N
+  # 1:   Y Y Y Y N Y N N
+  # 2:   Y Y Y Y N N Y N
+  # 3:   Y Y Y Y N N N Y
+  # 4:   Y N N N Y Y Y Y
+  # 5:   N Y N N Y Y Y Y
+  # 6:   N N Y N Y Y Y Y
+  # 7:   N N N Y Y Y Y Y
+  agg_grads = []
+  num_devices = len(avail_devices)
+  # In the special case of DGX-1 machine topology, the two groups have equal
+  # size.
+  group_size = num_devices // 2
+  for i, single_grads in enumerate(zip(*tower_grads)):
+    # Rotate the main device per gradient index to spread the aggregation
+    # load across devices.
+    group_0_main_device = i % num_devices
+    group_1_main_device = (group_0_main_device + group_size) % num_devices
+    if group_0_main_device < group_size:
+      group_0_begin = 0
+      group_1_begin = group_size
+    else:
+      group_0_begin = group_size
+      group_1_begin = 0
+
+    # Aggregate the first group.
+    group_0_device_grads = single_grads[group_0_begin:
+                                        group_0_begin + group_size]
+    with ops.device(avail_devices[group_0_main_device]):
+      group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
+          group_0_device_grads, False, False)
+
+    # Aggregate the second group.
+    group_1_device_grads = single_grads[group_1_begin:
+                                        group_1_begin + group_size]
+    with ops.device(avail_devices[group_1_main_device]):
+      group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
+          group_1_device_grads, False, False)
+
+    # Aggregate between the groups.
+    with ops.device(avail_devices[group_0_main_device]):
+      (agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
+          [group_0_agg_grads, group_1_agg_grads], False, False)
+
+    # Broadcast the result back into the root of each group.
+    with ops.device(avail_devices[group_0_main_device]):
+      group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
+    with ops.device(avail_devices[group_1_main_device]):
+      group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)
+
+    agg_grads_bcast = []
+    for j in range(len(single_grads)):
+      with ops.device(avail_devices[j]):
+        # Broadcast the result back to each member in the group from the root.
+        if (group_0_main_device < group_size) == (j < group_size):
+          src_device_grad = group_0_agg_grads_bcast
+        else:
+          src_device_grad = group_1_agg_grads_bcast
+        agg_grads_bcast.append(array_ops.identity(src_device_grad))
+
+    agg_grads.append(
+        [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])
+
+  # Transpose back to per-tower lists of (grad, var) pairs.
+  agg_grads = list(zip(*agg_grads))
+
+  return agg_grads
+
+
+def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
+ check_inf_nan):
+ """Calculate the average gradient for a shared variable across all towers.
+
+ Note that this function provides a synchronization point across all towers.
+
+ Args:
+ grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
+ (gradient, variable) pair within the outer list represents the gradient
+ of the variable calculated for a single tower, and the number of pairs
+ equals the number of towers.
+ use_mean: if True, mean is taken, else sum of gradients is taken.
+ check_inf_nan: check grads for nans and infs.
+
+ Returns:
+ The tuple ([(average_gradient, variable),], has_nan_or_inf) where the
+ gradient has been averaged across all towers. The variable is chosen from
+ the first tower. The has_nan_or_inf indicates the grads has nan or inf.
+ """
+ grads = [g for g, _ in grad_and_vars]
+ grad = math_ops.add_n(grads)
+
+ if use_mean and len(grads) > 1:
+ grad = array_ops.multiply(grad, 1.0 / len(grads))
+
+ v = grad_and_vars[0][1]
+ if check_inf_nan:
+ has_nan_or_inf = array_ops.logical_not(
+ array_ops.reduce_all(array_ops.is_finite(grads)))
+ return (grad, v), has_nan_or_inf
+ else:
+ return (grad, v), None
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
new file mode 100644
index 0000000000..0fa90df79b
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -0,0 +1,279 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for running legacy optimizer code with DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example
+from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.ops.losses import losses_impl
+
+
+class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
+
+  @combinations.generate(
+      combinations.times(
+          combinations.distributions_and_v1_optimizers(),
+          combinations.combine(mode=["graph"], use_callable_loss=[True, False])
+          + combinations.combine(mode=["eager"], use_callable_loss=[True])))
+  def testTrainNetwork(self, distribution, optimizer_fn,
+                       use_callable_loss=True):
+    with distribution.scope():
+      model_fn, dataset, layer = minimize_loss_example(
+          optimizer_fn,
+          use_bias=True,
+          use_callable_loss=use_callable_loss)
+
+      iterator = distribution.distribute_dataset(dataset)
+
+      def run_step():
+        return distribution.group(
+            distribution.call_for_each_tower(
+                model_fn, iterator.get_next(), run_concurrently=layer.built))
+
+      if not context.executing_eagerly():
+        with self.test_session() as sess:
+          run_step = sess.make_callable(run_step())
+        self.evaluate(variables_lib.global_variables_initializer())
+
+      weights, biases = [], []
+      for _ in range(10):
+        run_step()
+
+        weights.append(self.evaluate(distribution.fetch(layer.kernel)))
+        biases.append(self.evaluate(distribution.fetch(layer.bias)))
+
+      error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
+      is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
+      self.assertTrue(is_not_increasing)
+
+  @combinations.generate(
+      combinations.times(
+          combinations.distributions_and_v1_optimizers() +
+          combinations.distributions_and_v2_optimizers(),
+          combinations.combine(mode=["graph", "eager"])))
+  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
+    created_variables = []
+    trainable_variables = []
+
+    def appending_creator(next_creator, *args, **kwargs):
+      v = next_creator(*args, **kwargs)
+      created_variables.append(v.name)
+      if "trainable" in kwargs and kwargs["trainable"]:
+        trainable_variables.append(v.name)
+      return v
+
+    # Creator scope needs to be set before it's used inside
+    # `distribution.scope`.
+    with variable_scope.variable_creator_scope(
+        appending_creator), distribution.scope():
+      model_fn, dataset, layer = minimize_loss_example(
+          optimizer_fn,
+          use_bias=True,
+          use_callable_loss=True,
+          create_optimizer_inside_model_fn=True)
+
+      iterator = distribution.distribute_dataset(dataset)
+
+      def run_step():
+        return distribution.group(
+            distribution.call_for_each_tower(
+                model_fn, iterator.get_next(), run_concurrently=layer.built))
+
+      if not context.executing_eagerly():
+        with self.test_session() as sess:
+          run_step = sess.make_callable(run_step())
+        self.evaluate(variables_lib.global_variables_initializer())
+
+      run_step()
+
+      def get_expected_variables(optimizer_fn, num_parameter_devices):
+        variables_map = {
+            "GradientDescent": ["dense/kernel", "dense/bias"],
+            "Adam": [
+                "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
+                "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
+                "dense/bias/Adam_1"
+            ]
+        }
+        variables = variables_map[optimizer_fn().get_name()]
+        variables.extend([
+            v + "/replica_{}".format(replica)
+            for v in variables
+            for replica in range(1, num_parameter_devices)
+        ])
+        return set([v + ":0" for v in variables])
+
+      self.assertEqual(
+          get_expected_variables(optimizer_fn,
+                                 len(distribution.parameter_devices)),
+          set(created_variables))
+
+  @combinations.generate(
+      combinations.times(combinations.distributions_and_v1_optimizers(),
+                         combinations.combine(
+                             mode=["graph", "eager"],
+                             momentum=[0.8, 0.9, 0.99],
+                             renorm=[False, True])))
+  def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
+                                    renorm):
+    """Verifies that moving mean updates are reduced across towers."""
+    with distribution.scope():
+      num_towers = len(distribution.worker_devices)
+      model_fn, dataset, batchnorm = batchnorm_example(
+          optimizer_fn,
+          batch_per_epoch=num_towers,
+          momentum=momentum,
+          renorm=renorm)
+
+      # Disable prefetching: with prefetching the assignment of input batches
+      # to devices is non-deterministic, while this test depends on exactly
+      # which batch lands on which device.
+      if isinstance(distribution, mirrored_strategy.MirroredStrategy):
+        distribution._prefetch_on_device = False
+      iterator = distribution.distribute_dataset(dataset)
+
+      def run_step():
+        return control_flow_ops.group(
+            distribution.unwrap(
+                distribution.call_for_each_tower(
+                    model_fn,
+                    iterator.get_next(),
+                    run_concurrently=batchnorm.built)) +
+            ops.get_collection(ops.GraphKeys.UPDATE_OPS))
+
+      if not context.executing_eagerly():
+        with self.test_session() as sess:
+          run_step = sess.make_callable(run_step())
+        self.evaluate(variables_lib.global_variables_initializer())
+
+      expected_moving_means = [0.] * 8
+
+      def averaged_batch_mean(i):
+        # Each batch has shape [16, 8] where the ith element in the jth row is
+        # (8 * j + i + tower_id * 100), so the batch mean of column i in each
+        # tower is (60 + i + tower_id * 100). Averaging that mean over all
+        # towers gives:
+        return 60. + i + (num_towers - 1.) / 2. * 100.
+
+      for _ in range(10):
+        run_step()
+        moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean))
+
+        # Check that moving_mean is updated as if the sample mean had been
+        # computed over the combined batch from all towers.
+        for i, expected_moving_mean in enumerate(expected_moving_means):
+          expected_moving_means[i] -= ((
+              expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
+          self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)
+
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              distribution=[combinations.one_device_strategy,
+                            combinations.mirrored_strategy_with_gpu_and_cpu,
+                            combinations.mirrored_strategy_with_two_gpus],
+              optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn,
+                            combinations.gradient_descent_optimizer_v2_fn],
+              loss_reduction=[losses_impl.Reduction.SUM,
+                              losses_impl.Reduction.MEAN,
+                              losses_impl.Reduction.SUM_OVER_BATCH_SIZE,
+                              losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]),
+          combinations.combine(mode=["graph"], use_callable_loss=[True, False])
+          + combinations.combine(mode=["eager"], use_callable_loss=[True])))
+  def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction,
+                    use_callable_loss):
+    with distribution.scope():
+      all_vars = []
+
+      def model_fn(x, y):
+
+        def loss_fn():
+          # Use fixed initialization to make the steps deterministic.
+          w = variable_scope.get_variable("w", initializer=[[2.]])
+          all_vars.append(w)
+          predict = math_ops.matmul(x, w)
+          return losses_impl.mean_squared_error(
+              y, predict, reduction=loss_reduction)
+
+        optimizer = optimizer_fn()  # GradientDescent with 0.2 learning rate
+
+        if use_callable_loss:
+          return optimizer.minimize(loss_fn)
+        else:
+          return optimizer.minimize(loss_fn())
+
+      features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
+      labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
+      dataset = dataset_ops.Dataset.zip((features, labels)).repeat()
+      iterator = distribution.distribute_dataset(dataset)
+
+      def run_step():
+        return distribution.group(
+            distribution.call_for_each_tower(
+                model_fn, *iterator.get_next(), run_concurrently=False))
+
+      if not context.executing_eagerly():
+        with self.test_session() as sess:
+          run_step = sess.make_callable(run_step())
+        self.evaluate(variables_lib.global_variables_initializer())
+
+      run_step()
+
+      self.assertEqual(distribution.num_towers, len(all_vars))
+      v = all_vars[0]
+      self.assertTrue(all([v is vi for vi in all_vars[1:]]))
+      weight = numpy.squeeze(self.evaluate(distribution.fetch(v)))
+      # Our model is:
+      #   predict = x * w
+      #   loss = (predict - y)^2
+      #   dloss/dpredict = 2*(predict - y)
+      #   dloss/dw = 2 * x^T @ (predict - y)
+      # For our batch size of 2, assuming sum loss reduction:
+      #   x = [2, 7]
+      #   y = [6, 21]
+      #   w_initial = 2
+      #   predict = [4, 14]
+      #   predict - y = [-2, -7]
+      #   dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106
+      # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2
+      # with sum loss reduction, or 10.6 with mean.
+      if loss_reduction == losses_impl.Reduction.SUM:
+        # Note that the "distribution.num_towers" factor will go away once
+        # we split the input across towers, instead of pulling a complete
+        # batch of input per tower.
+        self.assertNear(weight, 2 + 21.2 * distribution.num_towers, 0.0001)
+      else:
+        # One of the mean loss reductions.
+        self.assertNear(weight, 2 + 10.6, 0.0001)
+
+
+if __name__ == "__main__":
+  test.main()  # Runs all parameterized test combinations defined above.
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
new file mode 100644
index 0000000000..8cf83c52d8
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -0,0 +1,486 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class MirroredStrategy implementing DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+import six
+
+from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import shared_variable_creator
+from tensorflow.contrib.distribute.python import values
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import device as tf_device
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import coordinator
+from tensorflow.python.training import device_util
+from tensorflow.python.training import distribute as distribute_lib
+
+
+# TODO(josh11b): Replace asserts in this file with if ...: raise ...
+
+
+def _cpu_device(device):
+ cpu_device = tf_device.DeviceSpec.from_string(device)
+ cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
+ return cpu_device.to_string()
+
+
+class _RequestedStop(Exception):
+  pass  # Raised in tower threads when the coordinator requests a stop.
+
+
+class MirroredStrategy(distribute_lib.DistributionStrategy):
+ """Mirrors vars to distribute across multiple devices on a single machine.
+
+ This strategy uses one tower per device and sync replication.
+ """
+
+ def __init__(self,
+ devices=None,
+ num_gpus=None,
+ cross_tower_ops=None,
+ prefetch_on_device=None):
+ super(MirroredStrategy, self).__init__()
+ # Convert `num_gpus` into `devices`, shouldn't specify both.
+ if devices is None:
+ if num_gpus is None:
+ num_gpus = context.num_gpus()
+ devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
+ elif num_gpus is not None:
+ raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+
+ assert devices, "Must specify at least one device."
+ assert len(set(devices)) == len(devices), (
+ "No duplicates allowed in `devices` argument.")
+ # TODO(josh11b): Require at least 2 devices?
+ self._devices = devices
+ self._canonical_device_set = set(
+ [device_util.canonicalize(d) for d in devices])
+ self._device_index = values.PerDevice(
+ dict((d, i) for i, d in enumerate(devices)))
+ self.cross_tower_ops = (
+ cross_tower_ops or
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
+ self._prefetch_on_device = prefetch_on_device
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ """Create a mirrored variable. See `DistributionStrategy.scope`."""
+ # Figure out what collections this variable should be added to.
+ # We'll add the MirroredVariable to those collections instead.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ colocate_with = kwargs.pop("colocate_with", None)
+ devices = self._get_devices_from(colocate_with)
+
+ tower_local = kwargs.pop("tower_local_reduce_method", None)
+ if tower_local is not None:
+ kwargs["trainable"] = False
+
+ # TODO(josh11b,apassos): It would be better if variable initialization
+ # was never recorded on the tape instead of having to do this manually
+ # here.
+ with tape.stop_recording():
+ index = {}
+ for i, d in enumerate(devices):
+ with ops.device(d):
+ if i > 0:
+ # Give replicas meaningful distinct names:
+ var0name = index[devices[0]].name.split(":")[0]
+ kwargs["name"] = "%s/replica_%d" % (var0name, i)
+ # Initialize replicas with the same value:
+ if context.executing_eagerly():
+ initial_value = index[devices[0]].value()
+ else:
+ initial_value = index[devices[0]].initial_value
+ kwargs["initial_value"] = array_ops.identity(initial_value)
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ v = next_creator(*args, **kwargs)
+ assert not isinstance(v, values.DistributedVariable)
+ index[d] = v
+
+ if tower_local is None:
+ result = values.MirroredVariable(index, index[devices[0]])
+ else:
+ result = values.TowerLocalVariable(
+ index, index[devices[0]], tower_local)
+
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ l.remove(v)
+ g.add_to_collections(collections, result)
+ return result
+
+ def distribute_dataset(self, dataset):
+ per_device_dataset = values.PerDeviceDataset(
+ dataset, self._devices, self._prefetch_on_device)
+ return per_device_dataset.make_one_shot_iterator()
+
+ def _broadcast(self, tensor, destinations):
+ # TODO(josh11b): In eager mode, use one thread per device, or async mode.
+ return self.cross_tower_ops.broadcast(tensor, destinations or self._devices)
+
+ def _call_for_each_tower(self, fn, *args, **kwargs):
+ """Run `fn` in separate threads, once per tower/worker device.
+
+ Args:
+ fn: function to run (will be run once per device, each in its own thread).
+ *args: positional arguments for `fn`
+ **kwargs: keyword arguments for `fn`.
+ `"run_concurrently"`: Boolean indicating whether executions of `fn`
+ can be run concurrently (under eager execution only), defaults to
+ `True`.
+
+ Returns:
+ Merged return value of `fn` across all towers.
+
+ Raises:
+ RuntimeError: If fn() calls get_tower_context().merge_call() a different
+ number of times for when called for different devices.
+ """
+ run_concurrently = kwargs.pop("run_concurrently", True)
+ if not context.executing_eagerly():
+ # Lots of TF library code isn't thread-safe in graph mode, and
+ # there is little to be gained by turning on multithreading when
+ # constructing a graph.
+ run_concurrently = False
+ # Needed for per-thread device, etc. contexts in graph mode.
+ ops.get_default_graph().switch_to_thread_local()
+ elif run_concurrently is None:
+ run_concurrently = True
+
+ coord = coordinator.Coordinator(
+ clean_stop_exception_types=(_RequestedStop,))
+
+ shared_variable_store = {}
+
+ # TODO(isaprykin): Create these threads once instead of during every run()
+ # call.
+ threads = []
+ for index, d in enumerate(self._devices):
+ variable_creator_fn = shared_variable_creator.make_fn(
+ shared_variable_store, index)
+ t = MirroredStrategy._MirroredTowerThread(
+ self, coord, d, variable_creator_fn, fn,
+ *values.select_device(d, args), **values.select_device(d, kwargs))
+ threads.append(t)
+
+ for t in threads:
+ t.start()
+
+ # When `fn` starts `should_run` event is set on _MirroredTowerThread
+ # (`MTT`) threads. The execution waits until
+ # `MTT.has_paused` is set, which indicates that either `fn` is
+ # complete or a `get_tower_context().merge_call()` is called. If `fn` is
+ # complete, then `MTT.done` is set to True. Otherwise, arguments
+ # of `get_tower_context().merge_call` from all paused threads are grouped
+ # and the `merge_fn` is performed. Results of the
+ # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
+ # Each such `get_tower_context().merge_call` call returns the
+ # `MTT.merge_result` for that thread when `MTT.should_run` event
+ # is reset again. Execution of `fn` resumes.
+
+ try:
+ with coord.stop_on_exception():
+ all_done = False
+ while not all_done and not coord.should_stop():
+ done = []
+ if run_concurrently:
+ for t in threads:
+ t.should_run.set()
+ for t in threads:
+ t.has_paused.wait()
+ t.has_paused.clear()
+ if coord.should_stop():
+ return None
+ done.append(t.done)
+ else:
+ for t in threads:
+ t.should_run.set()
+ t.has_paused.wait()
+ t.has_paused.clear()
+ if coord.should_stop():
+ return None
+ done.append(t.done)
+ if coord.should_stop():
+ return None
+ all_done = all(done)
+ if not all_done:
+ if any(done):
+ raise RuntimeError("Some towers made a different number of "
+ "tower_context().merge_call() calls.")
+ # get_tower_context().merge_call() case
+ merge_args = values.regroup(
+ {t.device: t.merge_args for t in threads})
+ merge_kwargs = values.regroup(
+ {t.device: t.merge_kwargs for t in threads})
+ merge_result = threads[0].merge_fn(
+ self, *merge_args, **merge_kwargs)
+ for t in threads:
+ t.merge_result = values.select_device(t.device, merge_result)
+ finally:
+ for t in threads:
+ t.should_run.set()
+ coord.join(threads)
+
+ return values.regroup({t.device: t.main_result for t in threads})
+
+ def map(self, map_over, fn, *args, **kwargs):
+ # TODO(josh11b): In eager mode, use one thread per device.
+ index = {}
+ i = 0
+ for m in map_over:
+ d = self._devices[i % len(self._devices)]
+ with ops.device(d):
+ l = index.get(d, [])
+ l.append(fn(m,
+ *values.select_device_mirrored(d, args),
+ **values.select_device_mirrored(d, kwargs)))
+ index[d] = l
+ # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
+ # in addition to PerDevice data.
+ return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()})
+
+ def _reduce(self, method_string, value, destinations):
+ if len(self._devices) == 1 and not isinstance(value, values.PerDevice):
+ value = values.PerDevice({self._devices[0]: value})
+ assert isinstance(value, values.PerDevice)
+ return self.cross_tower_ops.reduce(
+ method_string, value, destinations=destinations)
+
+ def _batch_reduce(self, method_string, value_destination_pairs):
+ return self.cross_tower_ops.batch_reduce(method_string,
+ value_destination_pairs)
+
+ def _update(self, var, fn, *args, **kwargs):
+ # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
+ # kwargs don't need to be mirrored.
+ assert isinstance(var, values.MirroredVariable)
+ # TODO(josh11b): In eager mode, use one thread per device.
+ updates = {}
+ for d, v in var._index.items(): # pylint: disable=protected-access
+ name = "update_%d" % self._device_index.get(d)
+ with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ updates[d] = fn(v,
+ *values.select_device_mirrored(d, args),
+ **values.select_device_mirrored(d, kwargs))
+ return values.regroup(updates, values.Mirrored)
+
+ def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ assert isinstance(colocate_with, list)
+ # TODO(josh11b): In eager mode, use one thread per device.
+ updates = {}
+ for d in colocate_with:
+ name = "update_%d" % self._device_index.get(d)
+ with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ updates[d] = fn(*values.select_device_mirrored(d, args),
+ **values.select_device_mirrored(d, kwargs))
+ return values.regroup(updates, values.Mirrored)
+
+ def _fetch(self, val, destination, fn):
+ """Return a copy of `val` or `fn(val)` on `destination`."""
+ assert isinstance(destination, six.string_types)
+ if isinstance(val, values.TowerLocalVariable):
+ val = self.reduce(val.reduce_method, val, destinations=destination)
+ with ops.device(destination):
+ return fn(self.unwrap(val)[0])
+
+ assert isinstance(val, values.Mirrored), (
+ "val = %s (type %s)" % (val, val.__class__.__name__))
+ if val.on_device(destination):
+ with ops.device(destination):
+ # Use an identity here to make sure we are returning a tensor
+ # instead of e.g. a variable object.
+ return array_ops.identity(fn(val.get(destination)))
+ device = None
+ for d in self._devices:
+ if val.on_device(d):
+ device = d
+ break
+ assert device is not None, (
+ "Could not find destination %s in list of devices %s." %
+ (destination, val.devices))
+ with ops.device(device):
+ v = fn(val.get(device))
+ with ops.device(destination):
+ return array_ops.identity(v)
+
+ def _unwrap(self, val):
+ if isinstance(val, values.DistributedValues):
+ # Return in a deterministic order.
+ if set(val.devices) == self._canonical_device_set:
+ return [val.get(device=d) for d in self._devices]
+ return [val.get(device=d) for d in sorted(val.devices)]
+ return [val]
+
+ @property
+ def is_single_tower(self):
+ return len(self._devices) == 1
+
+ @property
+ def num_towers(self):
+ return len(self._devices)
+
+ def _worker_device_index(self):
+ return self._device_index
+
+ @property
+ def worker_devices(self):
+ # Make a copy to prevent users from accidentally mutating our copy.
+ return list(self._devices)
+
+ @property
+ def parameter_devices(self):
+ return list(self._devices)
+
+ def non_slot_devices(self, var_list):
+ del var_list
+ return list(self._devices)
+
+ def _get_devices_from(self, colocate_with=None):
+ if colocate_with is None:
+ return self._devices
+ elif isinstance(colocate_with, values.DistributedValues):
+ # pylint: disable=protected-access
+ return list(colocate_with._index.keys())
+ elif isinstance(colocate_with, six.string_types):
+ return [colocate_with]
+ else:
+ return colocate_with
+
+ class _MirroredTowerThread(threading.Thread):
+ """A thread that runs() a function on a device."""
+
+ def __init__(self, dist, coord, device, variable_creator_fn, fn, *args,
+ **kwargs):
+ super(MirroredStrategy._MirroredTowerThread, self).__init__() # pylint: disable=protected-access
+ self.coord = coord
+ self.distribution = dist
+ self.device = device
+ self.tower_id = dist.worker_devices.index(device)
+ self.variable_creator_fn = variable_creator_fn
+ # State needed to run and return the results of `fn`.
+ self.main_fn = fn
+ self.main_args = args
+ self.main_kwargs = kwargs
+ self.main_result = None
+ self.done = False
+ # State needed to run the next merge_call() (if any) requested via
+ # TowerContext.
+ self.merge_fn = None
+ self.merge_args = None
+ self.merge_kwargs = None
+ self.merge_result = None
+ # We use a thread.Event for the main thread to signal when this
+ # thread should start running (`should_run`), and another for
+ # this thread to transfer control back to the main thread
+ # (`has_paused`, either when it gets to a
+ # `get_tower_context().merge_call` or when `fn` returns). In
+ # either case the event starts cleared, is signaled by calling
+ # set(). The receiving thread waits for the signal by calling
+ # wait() and then immediately clearing the event using clear().
+ self.should_run = threading.Event()
+ self.has_paused = threading.Event()
+ # These fields have to do with inheriting various contexts from the
+ # parent thread:
+ # pylint: disable=protected-access
+ self.context_mode = context.context()._eager_context.mode
+ if not context.context()._context_handle:
+ context.context()._initialize_handle_and_devices()
+ self.context_device_policy = (
+ pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
+ context.context()._context_handle))
+ self.graph = ops.get_default_graph()
+ self._variable_creator_stack = self.graph._variable_creator_stack[:]
+ self._captured_var_scope = variable_scope.get_variable_scope()
+ # Adding a "/" at end lets us re-enter this scope later.
+ self._captured_name_scope = self.graph.get_name_scope()
+ if self._captured_name_scope:
+ self._captured_name_scope += "/"
+ if self.tower_id > 0:
+ if not self._captured_name_scope:
+ self._captured_name_scope = ""
+ self._captured_name_scope += "tower_%d/" % self.tower_id
+
+ def run(self):
+ # pylint: disable=protected-access
+ self.graph._variable_creator_stack = self._variable_creator_stack
+ self.should_run.wait()
+ self.should_run.clear()
+ try:
+ if self.coord.should_stop():
+ return
+ with self.coord.stop_on_exception(), \
+ context.context()._mode(self.context_mode), \
+ context.context().device_policy(self.context_device_policy), \
+ self.graph.as_default(), \
+ MirroredTowerContext(self.distribution, self.tower_id), \
+ ops.device(self.device), \
+ ops.name_scope(self._captured_name_scope), \
+ variable_scope.variable_scope(
+ self._captured_var_scope, reuse=self.tower_id > 0), \
+ variable_scope.variable_creator_scope(self.variable_creator_fn):
+ self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
+ self.done = True
+ finally:
+ self.has_paused.set()
+
+
+class MirroredTowerContext(distribute_lib.TowerContext):
+  """TowerContext used in MirroredStrategy.call_for_each_tower().
+
+  Opened in `_MirroredTowerThread`, to allow the user to invoke
+  `MirroredStrategy`'s specific implementation of `merge_call()`,
+  which works by delegating the function and its arguments to
+  the main thread (the one that invoked
+  `MirroredStrategy.call_for_each_tower()`).
+  """
+
+  def _merge_call(self, fn, *args, **kwargs):
+    """Delegate to the main thread to actually perform merge_call()."""
+    t = threading.current_thread()  # a _MirroredTowerThread
+    t.merge_fn = fn
+    t.merge_args = args
+    t.merge_kwargs = kwargs
+    t.has_paused.set()  # Hand control back to the main thread.
+    t.should_run.wait()  # Block until the main thread has run merge_fn.
+    t.should_run.clear()
+    if t.coord.should_stop():
+      raise _RequestedStop()  # Unwind: the coordinator asked us to stop.
+    return t.merge_result
+
+  @property
+  def device(self):
+    distribute_lib.require_tower_context(self)
+    return self._distribution_strategy.worker_devices[self._tower_id]
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
new file mode 100644
index 0000000000..9e9f06da8e
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -0,0 +1,435 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Multi-GPU tests for MirroredStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import strategy_test_lib
+from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.layers import core
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import distribute as distribute_lib
+
+# True when this file runs as the "_gpu" variant of the test target
+# (the build rule encodes the variant in the binary name).
+GPU_TEST = "test_gpu" in sys.argv[0]
+
+
+class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
+  """Mirrors across two devices; most cases only run in the GPU variant."""
+
+  def _get_distribution_strategy(self):
+    # Prefer two GPUs when available; otherwise mirror across CPU and GPU:0.
+    devices = ["/device:CPU:0", "/device:GPU:0"]
+    if GPU_TEST:
+      self.assertGreater(context.num_gpus(), 0)
+      if context.num_gpus() > 1:
+        devices = ["/device:GPU:0", "/device:GPU:1"]
+    print(self.id().split(".")[-1], "devices:", ", ".join(devices))
+    return mirrored_strategy.MirroredStrategy(devices)
+
+  def testMinimizeLossEager(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+    self._test_minimize_loss_eager(self._get_distribution_strategy())
+
+  def testMinimizeLossGraph(self):
+    # Runs in both variants; soft placement covers the CPU-only case.
+    soft_placement = not GPU_TEST
+    print("testMinimizeLossGraph soft_placement:", soft_placement)
+    self._test_minimize_loss_graph(
+        self._get_distribution_strategy(), soft_placement=soft_placement)
+
+  def testMapReduce(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+    self._test_map_reduce(self._get_distribution_strategy())
+
+  def testDeviceIndex(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+    self._test_device_index(self._get_distribution_strategy())
+
+  def testTowerId(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+    self._test_tower_id(self._get_distribution_strategy())
+
+  def testNumTowers(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+    self.assertEqual(2, self._get_distribution_strategy().num_towers)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testCallAndMergeExceptions(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+    self._test_call_and_merge_exceptions(self._get_distribution_strategy())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testRunRegroupError(self):
+
+    def run_fn(device_id):
+      # Generates a list with different lengths on different devices.
+      # Will fail in _regroup() (if more than one device).
+      return list(range(device_id))
+
+    dist = self._get_distribution_strategy()
+    with dist.scope(), self.assertRaises(AssertionError):
+      dist.call_for_each_tower(run_fn, dist.worker_device_index)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testReduceToCpu(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+
+    def run_fn(device_id):
+      return device_id
+
+    dist = self._get_distribution_strategy()
+    with dist.scope():
+      result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
+      # Each tower contributes its device index, so the sum reduced onto
+      # the CPU should be 0 + 1 + ... + (num_devices - 1).
+      reduced = dist.reduce("sum", result, destinations="/device:CPU:0")
+      unwrapped = dist.unwrap(reduced)
+      self.assertEqual(1, len(unwrapped))
+      expected = sum(range(len(dist.worker_devices)))
+      self.assertEqual(expected, self.evaluate(unwrapped[0]))
+
+
+@test_util.with_c_api
+class MirroredStrategyVariableCreationTest(test.TestCase):
+
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = True
+
+ def _skip_eager_if_gpus_less_than(self, num_gpus):
+ if context.num_gpus() < num_gpus and context.executing_eagerly():
+ self.skipTest("Enough GPUs not available for this test in eager mode.")
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testSingleVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ # This variable should be created only once across the threads because of
+ # special variable_creator functions used by `dist.call_for_each_tower`.
+ v = variable_scope.variable(1.0, name="foo")
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertIsInstance(result, values.MirroredVariable)
+ self.assertEquals("foo:0", result.name)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testUnnamedVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ v = variable_scope.variable(1.0)
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertIsInstance(result, values.MirroredVariable)
+ # Default name of "Variable" will be used.
+ self.assertEquals("Variable:0", result.name)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testMultipleVariables(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ vs = []
+ for i in range(5):
+ vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ return vs
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ for i, v in enumerate(result):
+ self.assertIsInstance(v, values.MirroredVariable)
+ self.assertEquals("foo" + str(i) + ":0", v.name)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testMultipleVariablesWithSameCanonicalName(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ vs = []
+ vs.append(variable_scope.variable(1.0, name="foo/bar"))
+ vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
+ vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
+ vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ return vs
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ for v in result:
+ self.assertIsInstance(v, values.MirroredVariable)
+ self.assertEquals(4, len(result))
+ self.assertEquals("foo/bar:0", result[0].name)
+ self.assertEquals("foo_1/bar:0", result[1].name)
+ self.assertEquals("foo_1/bar_1:0", result[2].name)
+ self.assertEquals("foo/bar_1:0", result[3].name)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testVariableWithSameCanonicalNameAcrossThreads(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn(device_id):
+ v = variable_scope.variable(1.0, name="foo_" + str(device_id))
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ result = dist.call_for_each_tower(
+ model_fn, dist.worker_device_index, run_concurrently=False)
+ self.assertIsInstance(result, values.MirroredVariable)
+ # The resulting mirrored variable will use the name from the first device.
+ self.assertEquals("foo_0:0", result.name)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testWithLayers(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def model_fn(features):
+ with variable_scope.variable_scope("common"):
+ layer1 = core.Dense(1)
+ layer1(features)
+ layer2 = core.Dense(1)
+ layer2(features)
+ # This will pause the current thread, and execute the other thread.
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ layer3 = core.Dense(1)
+ layer3(features)
+ return [(layer1.kernel, layer1.bias),
+ (layer2.kernel, layer2.bias),
+ (layer3.kernel, layer3.bias)]
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+ features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+ features = dist.distribute_dataset(features).get_next()
+
+ with dist.scope():
+ result = dist.call_for_each_tower(
+ model_fn, features, run_concurrently=False)
+ suffixes = ["", "_1", "_2"]
+ for (kernel, bias), suffix in zip(result, suffixes):
+ self.assertIsInstance(kernel, values.MirroredVariable)
+ self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name)
+ self.assertIsInstance(bias, values.MirroredVariable)
+ self.assertEquals("common/dense" + suffix + "/bias:0", bias.name)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testWithGetVariableAndVariableScope(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ v0 = variable_scope.get_variable("var-thread0", [1])
+ with variable_scope.variable_scope("common"):
+ v1 = variable_scope.get_variable("var-thread1", [1])
+ # This will pause the current thread, and execute the other thread.
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ v2 = variable_scope.get_variable("var-thread2", [1])
+
+ return v0, v1, v2
+
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with variable_scope.variable_scope("main"):
+ v = variable_scope.get_variable("var-main0", [1])
+ self.assertEquals("main/var-main0:0", v.name)
+
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertEquals(3, len(result))
+ v0, v1, v2 = result
+ self.assertIsInstance(v0, values.MirroredVariable)
+ self.assertEquals("main/var-thread0:0", v0.name)
+ self.assertIsInstance(v1, values.MirroredVariable)
+ self.assertEquals("main/common/var-thread1:0", v1.name)
+ self.assertIsInstance(v2, values.MirroredVariable)
+ self.assertEquals("main/common/var-thread2:0", v2.name)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testThreeDevices(self):
+ self._skip_eager_if_gpus_less_than(2)
+
+ def model_fn():
+ v = variable_scope.variable(1.0, name="foo")
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"])
+
+ with dist.scope():
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertIsInstance(result, values.MirroredVariable)
+ self.assertEquals("foo:0", result.name)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testNonMatchingVariableCreation(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn(name):
+ v = variable_scope.variable(1.0, name=name)
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ names = values.DistributedValues({
+ "/device:CPU:0": "foo",
+ "/device:GPU:0": "bar"
+ })
+ with self.assertRaises(RuntimeError):
+ _ = dist.call_for_each_tower(model_fn, names, run_concurrently=False)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testTowerLocalVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ all_v_sum = {}
+ all_v_mean = {}
+
+ def model_fn(device_id):
+ tower_context = distribute_lib.get_tower_context()
+ with tower_context.tower_local_var_scope("sum"):
+ v_sum = variable_scope.variable(1.0)
+ with tower_context.tower_local_var_scope("mean"):
+ v_mean = variable_scope.variable(4.0)
+ self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
+ self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
+ updates = [v_sum.assign_add(2.0 + device_id),
+ v_mean.assign(6.0 * device_id)]
+ all_v_sum[device_id] = v_sum
+ all_v_mean[device_id] = v_mean
+ return updates, v_sum, v_mean
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ # Create "sum" and "mean" versions of TowerLocalVariables.
+ ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower(
+ model_fn, dist.worker_device_index, run_concurrently=False)
+ # Should see the same wrapping instance in all towers.
+ self.assertIs(all_v_sum[0], ret_v_sum)
+ self.assertIs(all_v_mean[0], ret_v_mean)
+ for i in range(1, dist.num_towers):
+ self.assertIs(all_v_sum[0], all_v_sum[1])
+ self.assertIs(all_v_mean[0], all_v_mean[1])
+
+ # Apply updates
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate([y for x in ret_ops for y in dist.unwrap(x)])
+ expected_sum = 0.0
+ expected_mean = 0.0
+ for i, d in enumerate(dist.worker_devices):
+ # Test access within a device scope, should see different values.
+ with ops.device(d):
+ v_sum_value = self.evaluate(ret_v_sum.read_value())
+ v_mean_value = self.evaluate(ret_v_mean.read_value())
+ expected = i + 3.0
+ self.assertEqual(expected, v_sum_value)
+ expected_sum += expected
+ expected = i * 6.0
+ self.assertEqual(expected, v_mean_value)
+ expected_mean += expected
+
+ # fetch() should return the value you get by applying the
+ # reduction across all towers.
+ self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
+ expected_mean /= len(dist.worker_devices)
+ self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean)))
+
+ # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
+ # testing this in eager mode.
+
+ def testNameScope(self):
+ def model_fn():
+ with ops.name_scope("foo"):
+ a = constant_op.constant(1.0, name="a")
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ b = constant_op.constant(1.0, name="b")
+ return a, b
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with context.graph_mode(), dist.scope():
+ with ops.name_scope("main"):
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertEquals(2, len(result))
+ for v, name in zip(result, ["a", "b"]):
+ self.assertIsInstance(v, values.DistributedValues)
+ v0, v1 = dist.unwrap(v)
+ self.assertEquals("main/foo/" + name + ":0", v0.name)
+ self.assertEquals("main/tower_1/foo/" + name + ":0", v1.name)
+
+ def testWithDefaultName(self):
+ def model_fn():
+ with ops.name_scope(None, "foo"):
+ a = constant_op.constant(1.0, name="a")
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ b = constant_op.constant(2.0, name="b")
+ return a, b
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with context.graph_mode(), dist.scope():
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertEquals(2, len(result))
+ for v, name in zip(result, ["a", "b"]):
+ self.assertIsInstance(v, values.DistributedValues)
+ v0, v1 = dist.unwrap(v)
+ self.assertEquals("foo/" + name + ":0", v0.name)
+ self.assertEquals("tower_1/foo/" + name + ":0", v1.name)
+
+
+# Standard TensorFlow test entry point.
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
new file mode 100644
index 0000000000..a1ef0ecc77
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -0,0 +1,91 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for class MirroredStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import strategy_test_lib
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import distribute as distribute_lib
+
+
+@test_util.with_c_api
+class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
+  """Runs the shared strategy test suite with a single-CPU mirror."""
+
+  def _get_distribution_strategy(self):
+    # Degenerate MirroredStrategy: one device, so no real mirroring occurs.
+    return mirrored_strategy.MirroredStrategy(["/device:CPU:0"])
+
+  def testMinimizeLossEager(self):
+    self._test_minimize_loss_eager(self._get_distribution_strategy())
+
+  def testMinimizeLossGraph(self):
+    self._test_minimize_loss_graph(self._get_distribution_strategy())
+
+  def testMapReduce(self):
+    self._test_map_reduce(self._get_distribution_strategy())
+
+  def testDeviceIndex(self):
+    self._test_device_index(self._get_distribution_strategy())
+
+  def testTowerId(self):
+    self._test_tower_id(self._get_distribution_strategy())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testCallAndMergeExceptions(self):
+    self._test_call_and_merge_exceptions(self._get_distribution_strategy())
+
+
+@test_util.with_c_api
+class VariableCreatorStackTest(test.TestCase):
+  """Verifies variable-creator stacks are per-thread under MirroredStrategy."""
+
+  def testCreatorStacksAreThreadLocal(self):
+    devices = ["/device:CPU:0", "/device:GPU:0"]
+    dist = mirrored_strategy.MirroredStrategy(devices)
+
+    def model_fn(device_id):
+      assert isinstance(device_id, int)
+      def thread_creator_fn(next_creator, *args, **kwargs):
+        # Tag the created value with this tower thread's id so the test
+        # can tell which thread's creator stack was consulted.
+        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)
+
+      with variable_scope.variable_creator_scope(thread_creator_fn):
+        # Create a variable in this scope.
+        v = variable_scope.variable(1.0)
+
+        # This will pause the current thread, and execute the other thread.
+        distribute_lib.get_tower_context().merge_call(lambda _: _)
+      return v
+
+    def main_thread_creator(next_creator, *args, **kwargs):
+      # We are not using the underlying next_creator for test purposes.
+      del next_creator, args, kwargs
+      return "main_thread"
+
+    with context.graph_mode(), \
+        dist.scope(), \
+        variable_scope.variable_creator_scope(main_thread_creator):
+      result = dist.call_for_each_tower(model_fn, dist.worker_device_index)
+      result = dist.unwrap(result)
+      # Each tower sees the main thread's creator plus its own thread-local
+      # creator on top.
+      expected = ["main_thread:thread_0", "main_thread:thread_1"]
+      self.assertEquals(expected, result)
+
+
+# Standard TensorFlow test entry point.
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
new file mode 100644
index 0000000000..fe80bb4df5
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -0,0 +1,61 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Monitor is responsible for training, checkpointing and recovery."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.ops import variables
+
+
+class Monitor(object):
+  """Executes training steps, recovers and checkpoints.
+
+  Note that this class is particularly preliminary, experimental, and
+  expected to change.
+  """
+  # TODO(isaprykin): Support step functions that need multiple session calls.
+  # TODO(isaprykin): Support extra arguments to the step function.
+  # TODO(isaprykin): Support recovery, checkpointing and summaries.
+
+  def __init__(self, step_callable, session=None):
+    """Initialize the Monitor with components for executing training steps.
+
+    Args:
+      step_callable: a training `Step` that's capable of signaling when done.
+      session: a `Session` instance that's needed for graph mode.
+
+    Raises:
+      ValueError: if `session` was provided for eager mode or not provided for
+        graph mode.
+    """
+    if context.executing_eagerly():
+      if session is not None:
+        raise ValueError("Should not provide a `session` in Eager mode.")
+      self._run_step = step_callable
+    else:
+      if session is None:
+        raise ValueError("Should provide a `session` in Graph mode.")
+      # In graph mode, build the step ops once and wrap them in a session
+      # callable; also initialize variables before any step runs.
+      self._run_step = session.make_callable(step_callable())
+      session.run(variables.global_variables_initializer())
+
+  def run_steps(self, num_steps=None):
+    """Run up to `num_steps` steps; `None` means run until the step signals done."""
+    step = 0
+    done = False
+    # NOTE(review): the loop stops once the step callable returns None
+    # ("signals when done").  `done` starts as False so the first iteration
+    # always runs -- assumes step callables return non-None while training;
+    # confirm against the `Step` contract.
+    while done is not None and (num_steps is None or step < num_steps):
+      done = self._run_step()
+      step += 1
diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py
new file mode 100644
index 0000000000..8277e1e791
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/monitor_test.py
@@ -0,0 +1,84 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for class Monitor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import monitor as monitor_lib
+from tensorflow.contrib.distribute.python import one_device_strategy
+from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.training import gradient_descent
+
+
+class MonitorTest(test.TestCase, parameterized.TestCase):
+  """Tests `monitor_lib.Monitor` across strategies, optimizers and modes."""
+
+  @combinations.generate(
+      combinations.times(
+          combinations.distributions_and_v1_optimizers(),
+          combinations.combine(mode=combinations.graph_and_eager_modes)))
+  def testTrainNetwork(self, distribution, optimizer_fn):
+    with distribution.scope():
+      single_loss_step, layer = single_loss_example(optimizer_fn, distribution)
+
+      if context.executing_eagerly():
+        monitor = monitor_lib.Monitor(single_loss_step, None)
+      else:
+        with self.test_session() as sess:
+          monitor = monitor_lib.Monitor(single_loss_step, sess)
+
+      monitor.run_steps(1)
+
+      self.assertEqual(1, len(layer.trainable_variables))
+      mirrored_weight_variable = layer.trainable_variables[0]
+      start_error = self.evaluate(distribution.fetch(mirrored_weight_variable))
+      start_error = abs(numpy.array(start_error) - 1)
+
+      monitor.run_steps(9)
+      end_error = self.evaluate(distribution.fetch(mirrored_weight_variable))
+      end_error = abs(numpy.array(end_error) - 1)
+      # After more training the weight should be no further from its target.
+      self.assertGreaterEqual(start_error, end_error)
+
+  def testPassingASessionInEager(self):
+    distribution = one_device_strategy.OneDeviceStrategy(
+        "/device:CPU:0")
+    step_function, _ = single_loss_example(
+        lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution)
+
+    with self.test_session() as sess:
+      with self.assertRaisesRegexp(ValueError, "Should not provide"):
+        _ = monitor_lib.Monitor(step_function, sess)
+
+  def testNotPassingASessionInGraph(self):
+    distribution = one_device_strategy.OneDeviceStrategy(
+        "/device:CPU:0")
+    step_function, _ = single_loss_example(
+        lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution)
+
+    with context.graph_mode(), ops.Graph().as_default():
+      with self.assertRaisesRegexp(ValueError, "Should provide"):
+        _ = monitor_lib.Monitor(step_function, session=None)
+
+
+# Standard TensorFlow test entry point.
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
new file mode 100644
index 0000000000..39c49442b9
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -0,0 +1,148 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class OneDeviceStrategy implementing DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.distribute.python import values
+from tensorflow.contrib.eager.python import datasets
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import distribute as distribute_lib
+
+
+# TODO(josh11b): Replace asserts in this file with if ...: raise ...
+
+
+class OneDeviceStrategy(distribute_lib.DistributionStrategy):
+ """A distribution strategy for running on a single device."""
+ # TODO(josh11b): Do we wrap values in types to generate errors if you are
+ # doing something that won't work with other DistributionStrategy
+ # implementations?
+
+ def __init__(self, device):
+ super(OneDeviceStrategy, self).__init__()
+ self._device = device
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ # No need to distinguish tower-local variables when not mirroring,
+ # we just enforce that they are not trainable.
+ if kwargs.pop("tower_local_reduce_method", None) is not None:
+ kwargs["trainable"] = False
+
+ colocate_with = kwargs.pop("colocate_with", None)
+ if colocate_with is None:
+ with ops.device(self._device):
+ return next_creator(*args, **kwargs)
+ if isinstance(colocate_with, six.string_types):
+ with ops.device(colocate_with):
+ return next_creator(*args, **kwargs)
+ if (isinstance(colocate_with, list) and len(colocate_with) == 1 and
+ isinstance(colocate_with[0], six.string_types)):
+ with ops.device(colocate_with[0]):
+ return next_creator(*args, **kwargs)
+ with ops.colocate_with(colocate_with):
+ return next_creator(*args, **kwargs)
+
+ def distribute_dataset(self, dataset):
+ if context.executing_eagerly():
+ return datasets.Iterator(dataset)
+ else:
+ return dataset.make_one_shot_iterator()
+
+ def _broadcast(self, tensor, destinations):
+ return tensor
+
+ def _call_for_each_tower(self, fn, *args, **kwargs):
+ # We don't run `fn` in multiple threads in OneDeviceStrategy.
+ kwargs.pop("run_concurrently", None)
+ with ops.device(self._device), _OneDeviceTowerContext(self):
+ return fn(*args, **kwargs)
+
+ def map(self, map_over, fn, *args, **kwargs):
+ with ops.device(self._device):
+ return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
+
+ def _reduce(self, method_string, value, destinations):
+ if not isinstance(value, values.MapOutput):
+ return value
+ l = value.get()
+ assert l
+ with ops.device(self._device):
+ if method_string == "sum":
+ return math_ops.add_n(l)
+ elif method_string == "mean":
+ return math_ops.add_n(l) / len(l)
+ else:
+ assert False
+
+ def _update(self, var, fn, *args, **kwargs):
+ with ops.device(self._device), distribute_lib.UpdateContext(self._device):
+ return fn(var, *args, **kwargs)
+
+ def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ del colocate_with
+ with ops.device(self._device), distribute_lib.UpdateContext(self._device):
+ return fn(*args, **kwargs)
+
+ def _fetch(self, val, destination, fn):
+ """Return a copy of `val` or `fn(val)` on `destination`."""
+ with ops.device(self._device):
+ v = fn(val)
+ with ops.device(destination):
+ return array_ops.identity(v)
+
+ def _unwrap(self, value):
+ return [value]
+
+ @property
+ def is_single_tower(self):
+ return True
+
+ @property
+ def num_towers(self):
+ return 1
+
+ @property
+ def worker_devices(self):
+ return [self._device]
+
+ @property
+ def parameter_devices(self):
+ return [self._device]
+
+ def non_slot_devices(self, var_list):
+ del var_list
+ return [self._device]
+
+ def _worker_device_index(self):
+ return 0
+
+
+class _OneDeviceTowerContext(distribute_lib.TowerContext):
+  """TowerContext for OneDeviceStrategy: there is only ever tower 0."""
+
+  def __init__(self, distribution_strategy):
+    distribute_lib.TowerContext.__init__(
+        self, distribution_strategy, tower_id=0)
+
+  @property
+  def device(self):
+    # The strategy's single worker device.
+    return self._distribution_strategy.worker_devices[0]
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py
new file mode 100644
index 0000000000..7101ed0756
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py
@@ -0,0 +1,54 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for class OneDeviceStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import one_device_strategy
+from tensorflow.contrib.distribute.python import strategy_test_lib
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+
+
+@test_util.with_c_api
+class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
+  """Runs the shared strategy test suite against OneDeviceStrategy."""
+
+  def _get_distribution_strategy(self):
+    return one_device_strategy.OneDeviceStrategy("/device:CPU:0")
+
+  def testMinimizeLossEager(self):
+    self._test_minimize_loss_eager(self._get_distribution_strategy())
+
+  def testMinimizeLossGraph(self):
+    self._test_minimize_loss_graph(self._get_distribution_strategy())
+
+  def testMapReduce(self):
+    self._test_map_reduce(self._get_distribution_strategy())
+
+  def testDeviceIndex(self):
+    self._test_device_index(self._get_distribution_strategy())
+
+  def testTowerId(self):
+    self._test_tower_id(self._get_distribution_strategy())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testCallAndMergeExceptions(self):
+    self._test_call_and_merge_exceptions(self._get_distribution_strategy())
+
+
+# Standard TensorFlow test entry point.
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
new file mode 100644
index 0000000000..a0912b625f
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for running legacy optimizer code with DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
+
+
+class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
+  """Checks V2 optimizers minimize a loss under each DistributionStrategy."""
+
+  @combinations.generate(
+      combinations.times(
+          combinations.distributions_and_v2_optimizers(),
+          combinations.combine(mode=["graph"], use_callable_loss=[True, False])
+          + combinations.combine(mode=["eager"], use_callable_loss=[True])))
+  def testTrainNetwork(self, distribution, optimizer_fn,
+                       use_callable_loss=True):
+    """Trains a one-layer model for 10 steps; error must not increase."""
+    with distribution.scope():
+      model_fn, dataset, layer = minimize_loss_example(
+          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+
+      iterator = distribution.distribute_dataset(dataset)
+
+      def run_step():
+        # One training step: run model_fn on every tower and group the
+        # resulting per-tower ops into a single op.
+        return control_flow_ops.group(distribution.unwrap(
+            distribution.call_for_each_tower(
+                model_fn, iterator.get_next(), run_concurrently=layer.built)))
+
+      if not context.executing_eagerly():
+        with self.test_session() as sess:
+          run_step = sess.make_callable(run_step())
+        self.evaluate(variables.global_variables_initializer())
+
+      weights, biases = [], []
+      for _ in range(10):
+        run_step()
+
+        weights.append(self.evaluate(distribution.fetch(layer.kernel)))
+        biases.append(self.evaluate(distribution.fetch(layer.bias)))
+
+      # The model fits w*x + b to the target 1.0 with x == 1 (see
+      # minimize_loss_example), so |w + b - 1| should be non-increasing
+      # across training steps.
+      error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
+      is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
+      self.assertTrue(is_not_increasing)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
new file mode 100644
index 0000000000..b9ffd2f266
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -0,0 +1,167 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extension of prefetching_ops to support more than one device."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
+from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest as data_nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.util import nest
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+  """A replacement for @{tf.data.Iterator} that prefetches to another device.
+
+  Wraps a one-shot iterator over `input_dataset` with one function-buffering
+  resource per destination device, so each device prefetches elements from
+  the shared source iterator independently.
+  """
+
+  def __init__(self, input_dataset, devices, buffer_size):
+    self._input_dataset = input_dataset
+    self._get_next_call_count = 0
+    self._devices = devices
+    input_iterator = input_dataset.make_one_shot_iterator()
+    input_iterator_handle = input_iterator.string_handle()
+
+    @function.Defun(dtypes.string)
+    def _prefetch_fn(handle):
+      # Executed on the source device: pulls the next element from the
+      # shared input iterator identified by its string handle.
+      remote_iterator = iterator_ops.Iterator.from_string_handle(
+          handle, input_iterator.output_types, input_iterator.output_shapes,
+          input_iterator.output_classes)
+      return remote_iterator.get_next()
+
+    target_device = gen_dataset_ops.iterator_get_device(
+        input_iterator._iterator_resource)
+    # One buffering resource per destination device, all fed from the same
+    # source iterator.
+    self._buffering_resources = []
+    for device in nest.flatten(self._devices):
+      with ops.device(device):
+        buffer_resource_handle = prefetching_ops.function_buffering_resource(
+            f=_prefetch_fn,
+            target_device=target_device,
+            string_arg=input_iterator_handle,
+            buffer_size=buffer_size,
+            thread_pool_size=0)
+        self._buffering_resources.append(buffer_resource_handle)
+
+  def get_next(self, name=None):
+    """See @{tf.data.Iterator.get_next}.
+
+    Returns one element per device, packed in the same nested structure as
+    the `devices` argument passed to the constructor.
+    """
+    self._get_next_call_count += 1
+    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+    flat_result = []
+    # TODO(priyag): This will fail if the input size (typically number of
+    # batches) is not divisible by number of devices.
+    # How do we handle that more gracefully / let the user know?
+    for buffer_resource in self._buffering_resources:
+      flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
+          buffer_resource,
+          output_types=data_nest.flatten(sparse.as_dense_types(
+              self.output_types, self.output_classes)), name=name)
+
+      ret = sparse.deserialize_sparse_tensors(
+          data_nest.pack_sequence_as(self.output_types, flat_ret),
+          self.output_types, self.output_shapes, self.output_classes)
+
+      # Re-attach static shape information lost by the buffering resource op.
+      for tensor, shape in zip(
+          data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
+        if isinstance(tensor, ops.Tensor):
+          tensor.set_shape(shape)
+      flat_result.append(ret)
+
+    return nest.pack_sequence_as(self._devices, flat_result)
+
+  @property
+  def output_classes(self):
+    return self._input_dataset.output_classes
+
+  @property
+  def output_shapes(self):
+    return self._input_dataset.output_shapes
+
+  @property
+  def output_types(self):
+    return self._input_dataset.output_types
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.Dataset):
+  """A `Dataset` whose iterator prefetches elements to other device(s)."""
+
+  def __init__(self, input_dataset, devices, buffer_size):
+    self._input_dataset = input_dataset
+    self._devices = devices
+    # Default to buffering a single element per device when unspecified.
+    self._buffer_size = buffer_size if buffer_size is not None else 1
+
+  def make_one_shot_iterator(self):
+    return _PrefetchToDeviceIterator(self._input_dataset, self._devices,
+                                     self._buffer_size)
+
+  def make_initializable_iterator(self, shared_name=None):
+    raise NotImplementedError("`prefetch_to_devices()` is not currently "
+                              "compatible with initializable iterators. Use "
+                              "`make_one_shot_iterator()` instead.")
+
+  def _as_variant_tensor(self):
+    # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+    # transformation methods is called).
+    # TODO(mrry): Investigate support for chaining further transformations after
+    # the prefetch, including GPU support.
+    raise NotImplementedError("`prefetch_to_devices()` must be the last "
+                              "transformation in a dataset pipeline.")
+
+  # TODO(priyag): Fix the output types, shapes and classes to match the result
+  # of get_next (which has the additional nesting layer of devices now).
+  @property
+  def output_types(self):
+    return self._input_dataset.output_types
+
+  @property
+  def output_shapes(self):
+    return self._input_dataset.output_shapes
+
+  @property
+  def output_classes(self):
+    return self._input_dataset.output_classes
+
+
+def prefetch_to_devices(devices, buffer_size=None):
+  """A transformation that prefetches dataset values to the given `devices`.
+
+  NOTE: Although the transformation creates a @{tf.data.Dataset}, the
+  transformation must be the final `Dataset` in the input pipeline.
+
+  Args:
+    devices: A nested structure of devices on which to prefetch the data. It can
+      be a single device name, or a tuple or list of device names.
+    buffer_size: (Optional.) The number of elements to buffer on each device.
+      Defaults to an automatically chosen value.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}.
+  """
+  def _apply_fn(dataset):
+    # The wrapper dataset only supports make_one_shot_iterator(); see
+    # _PrefetchToDeviceDataset for the restrictions.
+    return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
+
+  return _apply_fn
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
new file mode 100644
index 0000000000..8ed16f4607
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -0,0 +1,68 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for prefetching_ops_v2."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class PrefetchingOpsV2Test(test.TestCase):
+
+ def testPrefetchToOneDevice(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToTwoDevicesInAList(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ output = []
+ with self.test_session() as sess:
+ for _ in range(5):
+ result = sess.run(next_element)
+ self.assertEqual(2, len(result))
+ output.extend(result)
+ self.assertEquals(set(range(10)), set(output))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator.py b/tensorflow/contrib/distribute/python/shared_variable_creator.py
new file mode 100644
index 0000000000..aca9c7af05
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/shared_variable_creator.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to re-use variables created on first device on subsequent devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+_VARIABLE_UNIQUIFYING_REGEX = re.compile(r"_\d/")
+_VARIABLE_UNIQUIFYING_REGEX_AT_END = re.compile(r"_\d$")
+
+
+def _canonicalize_variable_name(name):
+  """Returns `name` with uniquifying "_<num>" suffixes stripped.
+
+  NOTE(review): the regexes match a single digit only (r"_\d"), so a suffix
+  like "_12" is not stripped — confirm that this is intended.
+  """
+  # If no name is specified, uses default name "Variable".
+  if name is None:
+    return "Variable"
+  # Replace all instances of "_<num>/" with "/"
+  name = _VARIABLE_UNIQUIFYING_REGEX.sub("/", name)
+  # Replace any instances of "_<num>" at the end of the string with ""
+  name = _VARIABLE_UNIQUIFYING_REGEX_AT_END.sub("", name)
+  return name
+
+
+def make_fn(shared_variable_store, device_id):
+  """Construct the variable creator function for device `device_id`.
+
+  Constructs custom variable creator functions for the given device.
+  On first device (device_id == 0), it creates the variable using the
+  `next_creator`, and stores it in the provided `shared_variable_store`.
+  On all other devices (device_id > 0), it tries to re-use the variable
+  already created with the same name. If no such variable exists, it throws an
+  error.
+  Additionally, we de-uniquify variable names before checking for matches. This
+  helps re-use variables which are intended to be the same but have different
+  names due to variable uniquification happening upstream. Since this might
+  mean we may have multiple variables with the same canonical name, we store
+  them in a list per canonical name and return them in the same order as well.
+
+  Args:
+    shared_variable_store: A dictionary that we will use to store variables
+      created on the first device, and re-used by creators for other devices.
+    device_id: Integer index of the device whose creator should be
+      constructed.
+
+  Returns:
+    An appropriate creator function based on device_id.
+  """
+  # Per-device counter of how many variables with each canonical name have
+  # already been handed out, so repeated creations re-use the stored list
+  # entries in creation order.
+  variable_scope_access_index = {}
+  assert isinstance(device_id, int)
+
+  def create_new_variable(next_creator, *args, **kwargs):
+    """Create the variable using `next_creator` and store it."""
+    canonical_name = _canonicalize_variable_name(kwargs.get("name"))
+    v = next_creator(*args, **kwargs)
+
+    if canonical_name not in shared_variable_store:
+      shared_variable_store[canonical_name] = []
+    shared_variable_store[canonical_name].append(v)
+    return v
+
+  def reuse_variable(next_creator, *args, **kwargs):
+    """Re-use existing variable from store with same name (in order)."""
+    del next_creator, args
+    name = kwargs.get("name")
+    canonical_name = _canonicalize_variable_name(name)
+
+    try:
+      variable_index = variable_scope_access_index.get(canonical_name, 0)
+      v = shared_variable_store[canonical_name][variable_index]
+      # TODO(priyag): Make this variable re-use more robust by adding checks
+      # that the requested shape and dtype match the existing variable.
+      variable_scope_access_index[canonical_name] = variable_index + 1
+      return v
+    except (KeyError, IndexError):
+      # Either no variable with this canonical name was stored by device 0,
+      # or this device requested more variables than device 0 created.
+      raise RuntimeError(
+          "Tried to create variable {} with mismatching name on device {}".
+          format(name, device_id))
+
+  if device_id == 0:
+    return create_new_variable
+  else:
+    return reuse_variable
diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py
new file mode 100644
index 0000000000..713494d603
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py
@@ -0,0 +1,75 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SharedVariableCreator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import shared_variable_creator
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variable_scope
+
+
+class CanonicalizeVariableNameTest(test.TestCase):
+
+ def _canonicalize(self, name):
+ return shared_variable_creator._canonicalize_variable_name(name)
+
+ def testNoName(self):
+ self.assertEquals("Variable", self._canonicalize(None))
+
+ def testPatternInMiddle(self):
+ self.assertEquals("foo/bar/baz", self._canonicalize("foo_1/bar_1/baz"))
+
+ def testPatternAtEnd(self):
+ self.assertEquals("foo", self._canonicalize("foo_1"))
+
+ def testWrongPatterns(self):
+ self.assertEquals("foo_1:0", self._canonicalize("foo_1:0"))
+ self.assertEquals("foo1", self._canonicalize("foo1"))
+ self.assertEquals("foo_a", self._canonicalize("foo_a"))
+
+
+@test_util.with_c_api
+class SharedVariableCreatorTest(test.TestCase):
+  """Tests that make_fn re-uses variables across per-device creator scopes."""
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testSharedVariable(self):
+
+    shared_variable_store = {}
+    num_devices = 3
+    creator_fns = []
+    for i in range(num_devices):
+      creator_fn = shared_variable_creator.make_fn(shared_variable_store, i)
+      creator_fns.append(creator_fn)
+
+    # Device 0's creator makes the variable; devices 1 and 2 must re-use it.
+    with variable_scope.variable_creator_scope(creator_fns[0]):
+      v0 = variable_scope.variable(1.0, name="foo")
+
+    with variable_scope.variable_creator_scope(creator_fns[1]):
+      v1 = variable_scope.variable(1.0, name="foo")
+
+    with variable_scope.variable_creator_scope(creator_fns[2]):
+      v2 = variable_scope.variable(1.0, name="foo")
+
+    # v1 and v2 should be same as v0
+    self.assertIs(v1, v0)
+    self.assertIs(v2, v0)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/simple_estimator_example.py b/tensorflow/contrib/distribute/python/simple_estimator_example.py
new file mode 100644
index 0000000000..7095d801ad
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/simple_estimator_example.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A simple example to test a DistributionStrategy with Estimators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import run_config
+from tensorflow.python.framework import constant_op
+from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import app
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import training_util
+
+
+def build_model_fn_optimizer():
+  """Simple model_fn with optimizer.
+
+  Returns:
+    An Estimator-compatible `model_fn` closure over a GradientDescentOptimizer.
+  """
+  # TODO(anjalisridhar): Move this inside the model_fn once OptimizerV2 is
+  # done?
+  optimizer = gradient_descent.GradientDescentOptimizer(0.2)
+
+  def model_fn(features, labels, mode):  # pylint: disable=unused-argument
+    """model_fn which uses a single unit Dense layer."""
+    # You can also use the Flatten layer if you want to test a model without any
+    # weights.
+    layer = core.Dense(1, use_bias=True)
+    logits = layer(features)
+
+    if mode == model_fn_lib.ModeKeys.PREDICT:
+      predictions = {"logits": logits}
+      return model_fn_lib.EstimatorSpec(mode, predictions=predictions)
+
+    def loss_fn():
+      # Squared error of the scalar logit against the constant target 1.
+      y = array_ops.reshape(logits, []) - constant_op.constant(1.)
+      return y * y
+
+    if mode == model_fn_lib.ModeKeys.EVAL:
+      return model_fn_lib.EstimatorSpec(mode, loss=loss_fn())
+
+    assert mode == model_fn_lib.ModeKeys.TRAIN
+
+    global_step = training_util.get_global_step()
+    train_op = optimizer.minimize(loss_fn(), global_step=global_step)
+    return model_fn_lib.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op)
+
+  return model_fn
+
+
+def main(_):
+  """Trains, evaluates and predicts with a trivial Estimator on two GPUs."""
+  distribution = mirrored_strategy.MirroredStrategy(
+      ["/device:GPU:0", "/device:GPU:1"])
+  config = run_config.RunConfig(distribute=distribution)
+
+  def input_fn():
+    # Ten identical examples: feature [[1.]] with label [1.].
+    features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+    labels = dataset_ops.Dataset.from_tensors([1.]).repeat(10)
+    return dataset_ops.Dataset.zip((features, labels))
+
+  estimator = estimator_lib.Estimator(
+      model_fn=build_model_fn_optimizer(), config=config)
+  estimator.train(input_fn=input_fn, steps=10)
+
+  eval_result = estimator.evaluate(input_fn=input_fn)
+  print("Eval result: {}".format(eval_result))
+
+  def predict_input_fn():
+    predict_features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+    return predict_features
+
+  predictions = estimator.predict(input_fn=predict_input_fn)
+  # TODO(anjalisridhar): This returns a generator object, figure out how to get
+  # meaningful results here.
+  print("Prediction results: {}".format(predictions))
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py
new file mode 100644
index 0000000000..cef5fd2f89
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/single_loss_example.py
@@ -0,0 +1,102 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A simple network to use in tests and examples."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import step_fn
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.layers import core
+from tensorflow.python.layers import normalization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def single_loss_example(optimizer_fn, distribution, use_bias=False):
+  """Build a very simple network to use in tests and examples.
+
+  Args:
+    optimizer_fn: zero-argument callable returning an optimizer instance.
+    distribution: a `DistributionStrategy` the step runs under.
+    use_bias: whether the single Dense layer has a bias term.
+
+  Returns:
+    A `(step, layer)` pair: a `StandardSingleLossStep` and the Dense layer
+    (so tests can inspect its kernel and bias).
+  """
+  dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+  optimizer = optimizer_fn()
+  layer = core.Dense(1, use_bias=use_bias)
+
+  def loss_fn(x):
+    # Squared error of the layer output against the constant target 1.
+    y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
+    return y * y
+
+  single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer,
+                                                    distribution)
+
+  # Layer is returned for inspecting the kernels in tests.
+  return single_loss_step, layer
+
+
+def minimize_loss_example(optimizer_fn,
+                          use_bias=False,
+                          use_callable_loss=True,
+                          create_optimizer_inside_model_fn=False):
+  """Example of non-distribution-aware legacy code.
+
+  Args:
+    optimizer_fn: zero-argument callable returning an optimizer instance.
+    use_bias: whether the single Dense layer has a bias term.
+    use_callable_loss: if True, pass the loss callable to `minimize()`;
+      otherwise pass the already-computed loss tensor.
+    create_optimizer_inside_model_fn: if True, a fresh optimizer is created
+      on every model_fn invocation instead of once up front.
+
+  Returns:
+    A `(model_fn, dataset, layer)` tuple.
+  """
+  dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+  # An Optimizer instance is created either outside or inside model_fn.
+  outer_optimizer = None
+  if not create_optimizer_inside_model_fn:
+    outer_optimizer = optimizer_fn()
+
+  layer = core.Dense(1, use_bias=use_bias)
+
+  def model_fn(x):
+    """A very simple model written by the user."""
+
+    def loss_fn():
+      # Squared error of the layer output against the constant target 1.
+      y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
+      return y * y
+
+    optimizer = outer_optimizer or optimizer_fn()
+
+    if use_callable_loss:
+      return optimizer.minimize(loss_fn)
+    else:
+      return optimizer.minimize(loss_fn())
+
+  return model_fn, dataset, layer
+
+
+def batchnorm_example(optimizer_fn,
+                      batch_per_epoch=1,
+                      momentum=0.9,
+                      renorm=False):
+  """Example of non-distribution-aware legacy code with batch normalization.
+
+  Returns:
+    A `(model_fn, dataset, batchnorm)` tuple; the BatchNormalization layer is
+    returned so tests can inspect its moving statistics.
+  """
+  # input shape is [16, 8], input values are increasing in both dimensions.
+  dataset = dataset_ops.Dataset.from_tensor_slices(
+      [[[float(x * 8 + y + z * 100)
+         for y in range(8)]
+        for x in range(16)]
+       for z in range(batch_per_epoch)]).repeat()
+  optimizer = optimizer_fn()
+  # NOTE(review): fused=False presumably to exercise the non-fused batchnorm
+  # implementation — confirm intent.
+  batchnorm = normalization.BatchNormalization(
+      renorm=renorm, momentum=momentum, fused=False)
+
+  def model_fn(x):
+
+    def loss_fn():
+      # Mean deviation of the row-sums of the normalized activations from 1.
+      y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1)
+      loss = math_ops.reduce_mean(y - constant_op.constant(1.))
+      return loss
+
+    # Callable loss.
+    return optimizer.minimize(loss_fn)
+
+  return model_fn, dataset, batchnorm
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
new file mode 100644
index 0000000000..82514c64be
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -0,0 +1,103 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The step function abstraction represents a single training step."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.training import optimizer as optimizer_lib
+
+
+class Step(object):
+  """Interface for performing each step of a training algorithm."""
+
+  def __init__(self, distribution):
+    self._distribution = distribution
+
+  @property
+  def distribution(self):
+    # The DistributionStrategy this step runs under.
+    return self._distribution
+
+  def __call__(self):
+    """Perform one step of this training algorithm."""
+    return self.step(self.inputs())
+
+  def inputs(self):
+    """Generates the input to be passed to `step()`."""
+    raise NotImplementedError("must be implemented in descendants")
+
+  def step(self, inputs):
+    """Perform the main computation of this training algorithm."""
+    raise NotImplementedError("must be implemented in descendants")
+
+
+class StandardInputStep(Step):
+  """Step with a standard implementation of input handling.
+
+  Args:
+    input_dataset: a tf.data Dataset that provides input.
+    distribution: a `DistributionStrategy` used to distribute the dataset.
+  """
+
+  def __init__(self, input_dataset, distribution):
+    Step.__init__(self, distribution)
+    self._distributed_input = distribution.distribute_dataset(input_dataset)
+
+  def inputs(self):
+    # Next element of the distributed dataset — presumably one element per
+    # tower; confirm against DistributionStrategy.distribute_dataset.
+    return self._distributed_input.get_next()
+
+
+class StandardSingleLossStep(StandardInputStep):
+  """A step function that implements a training step for a feed forward network.
+
+  An instance of this class is intended to be used as a callable:
+
+  ```python
+  ...
+  step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer,
+                                        distribution)
+
+  # Run a single training step:
+  step()
+  ...
+  ```
+
+  Args:
+    input_dataset: a tf.data Dataset that provides input.
+    loss_fn: a function that returns loss.
+    optimizer: an optimizer that implements an update rule.
+    distribution: a `DistributionStrategy` object.
+  """
+
+  def __init__(self, input_dataset, loss_fn, optimizer, distribution):
+    StandardInputStep.__init__(self, input_dataset, distribution)
+    self._loss_fn = loss_fn
+    self._optimizer = optimizer
+    # First step must run towers sequentially; see comment in step().
+    self._is_run_concurrently = False
+
+  def step(self, inputs):
+    with self._distribution.scope():
+      # Gradient fn over implicitly-watched variables; get_filtered_grad_fn
+      # presumably drops variables without gradients — confirm in
+      # optimizer_lib.
+      gradients_fn = backprop.implicit_grad(self._loss_fn)
+      gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
+
+      grads_and_vars = self.distribution.call_for_each_tower(
+          gradients_fn, inputs, run_concurrently=self._is_run_concurrently)
+      # If threads use layers, then we need to run the first step sequentially,
+      # so that layers.build() is not executed in parallel. Otherwise, multiple
+      # sets of mirrored variables are going to be created.
+      self._is_run_concurrently = True
+      return self._optimizer._distributed_apply(  # pylint: disable=protected-access
+          self.distribution, grads_and_vars)
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
new file mode 100644
index 0000000000..75c5ec9659
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -0,0 +1,62 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for class Step."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.ops import variables
+
+
class SingleLossStepTest(test.TestCase, parameterized.TestCase):
  """Runs StandardSingleLossStep across strategy/optimizer/mode combos."""

  @combinations.generate(
      combinations.times(
          combinations.distributions_and_v1_optimizers(),
          combinations.combine(mode=combinations.graph_and_eager_modes)))
  def testTrainNetwork(self, distribution, optimizer_fn):
    """Training error should be non-increasing over 10 steps."""
    with distribution.scope():
      single_loss_step, layer = single_loss_example(
          optimizer_fn, distribution, use_bias=True)

      if context.executing_eagerly():
        run_step = single_loss_step
      else:
        # Graph mode: build the step graph once, then reuse it as a callable.
        with self.test_session() as sess:
          run_step = sess.make_callable(single_loss_step())
        self.evaluate(variables.global_variables_initializer())

      weights, biases = [], []
      for _ in range(10):
        run_step()

        weights.append(self.evaluate(distribution.fetch(layer.kernel)))
        biases.append(self.evaluate(distribution.fetch(layer.bias)))

      # Presumably the example's optimum has kernel + bias == 1 (see
      # single_loss_example); check the distance to it never increases.
      error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
      is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
      self.assertTrue(is_not_increasing)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
new file mode 100644
index 0000000000..2b4ad9f146
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -0,0 +1,225 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library for testing DistributionStrategy descendants."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import optimizer
+
+
class _TestException(Exception):
  """Raised deliberately by the helpers below to test error propagation."""
  pass
+
+
# May be the argument to either distribution.call_for_each_tower() or
# get_tower_context().merge_call()
def _raise_exception_fn(_=None):
  """Unconditionally raise _TestException; ignores its optional argument."""
  raise _TestException()
+
+
# Must be the argument to a distribution.call_for_each_tower() call, calls a
# get_tower_context().merge_call() that raises an exception.
def _merge_raises_fn():
  """Raise _TestException from inside a merge_call."""
  distribute_lib.get_tower_context().merge_call(_raise_exception_fn)
+
+
# Must be the argument to a get_tower_context().merge_call() call, calls
# dist.call_for_each_tower() with a function that raises an exception.
def _call_raises_fn(dist):
  """Raise _TestException from a nested call_for_each_tower."""
  dist.call_for_each_tower(_raise_exception_fn)
+
+
# Must be the argument to a distribution.call_for_each_tower() call,
# calls a get_tower_context().merge_call() that calls a
# call_for_each_tower() that raises an exception.
def _merge_call_raises_fn():
  """Raise _TestException two levels deep: merge_call -> call_for_each_tower."""
  distribute_lib.get_tower_context().merge_call(_call_raises_fn)
+
+
# Must be the argument to a get_tower_context().merge_call() call, calls
# dist.call_for_each_tower() with a function that calls a
# get_tower_context().merge_call() that raises an exception.
def _call_merge_raises_fn(dist):
  """Raise _TestException via call_for_each_tower -> merge_call."""
  dist.call_for_each_tower(_merge_raises_fn)
+
+
# Must be the argument to a distribution.call_for_each_tower() call, calls a
# get_tower_context().merge_call() that calls a call_for_each_tower() that
# calls a get_tower_context().merge_call() that raises an exception.
def _merge_call_merge_raises_fn():
  """Raise _TestException three nesting levels deep."""
  distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn)
+
+
class DistributionTestBase(test.TestCase):
  """Some tests that should work with any DistributionStrategy."""

  def _test_minimize_loss_eager(self, d):
    """Eager: 10 manual SGD steps under `d` must reduce |prediction - 1|."""
    with d.scope():
      l = core.Dense(1, use_bias=False)

      def loss(x):
        # TODO(josh11b): What if this constant was instead a captured
        # value? Would it need to be a value that has been passed
        # through d.broadcast()?
        y = array_ops.reshape(l(x), []) - constant_op.constant(1.)
        return y * y
      # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
      # common `implicit_grad` function and put it in DistributionStrategy.
      grad_fn = backprop.implicit_grad(loss)
      grad_fn = optimizer.get_filtered_grad_fn(grad_fn)

      def update(v, g):
        # Plain SGD with a fixed 0.2 learning rate.
        return v.assign_sub(0.2 * g)

      one = d.broadcast(constant_op.constant([[1.]]))

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        # Run sequentially until the layer is built (l.built False) so only
        # one set of variables is created.
        g_v = d.call_for_each_tower(grad_fn, one, run_concurrently=l.built)

        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.fetch(v)
          before_list.append(fetched)
          # control_dependencies irrelevant but harmless in eager execution
          with ops.control_dependencies([fetched]):
            g = d.reduce("sum", g, destinations=v)
            with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
              after_list.append(d.fetch(v))
        return before_list, after_list

      for i in range(10):
        b, a = step()
        if i == 0:
          before, = b  # pylint: disable=unbalanced-tuple-unpacking
        after, = a  # pylint: disable=unbalanced-tuple-unpacking

      error_before = abs(before.numpy() - 1)
      error_after = abs(after.numpy() - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_minimize_loss_graph(self, d, soft_placement=False):
    """Graph-mode variant of _test_minimize_loss_eager."""
    config = config_pb2.ConfigProto()
    config.allow_soft_placement = soft_placement
    # Keep GPU memory use low so several such tests can share a machine.
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    with context.graph_mode(), \
         ops.Graph().as_default(), \
         self.test_session(config=config) as sess, \
         d.scope():
      l = core.Dense(1, use_bias=False)

      def loss(x):
        # TODO(josh11b): What if this constant was instead a captured
        # value? Would it need to be a value that has been passed
        # through d.broadcast()?
        y = array_ops.reshape(l(x), []) - constant_op.constant(1.)
        return y * y

      grad_fn = backprop.implicit_grad(loss)

      def update(v, g):
        return v.assign_sub(0.2 * g)

      one = d.broadcast(constant_op.constant([[1.]]))

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.call_for_each_tower(grad_fn, one)

        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.fetch(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            g = d.reduce("sum", g, destinations=v)
            with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
              after_list.append(d.fetch(v))
        return before_list, after_list

      # Build the graph once; re-run it via sess.run below.
      before_out, after_out = step()
      variables.global_variables_initializer().run()
      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
        after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_map_reduce(self, d, in_graph=None):
    """Map 0..9 through x*2, reduce with sum; expect 90.

    NOTE(review): `in_graph` is unused in this body; presumably kept for
    call-site compatibility -- confirm against subclasses.
    """
    with d.scope():
      map_in = [constant_op.constant(i) for i in range(10)]
      map_out = d.map(map_in, lambda x, y: x * y, 2)
      observed = d.fetch(d.reduce("sum", map_out))
      expected = 90  # 2 * (0 + 1 + ... + 9)
      self.assertEqual(expected, observed.numpy())

  def _test_device_index(self, d):
    """Every worker device index should be visited exactly once."""
    with d.scope():
      expected_devices = [False] * len(d.worker_devices)

      def mark_devices_fn(device_id):
        self.assertLess(device_id, len(d.worker_devices))
        self.assertFalse(expected_devices[device_id])
        expected_devices[device_id] = True

      d.call_for_each_tower(mark_devices_fn, d.worker_device_index)
      self.assertAllEqual(expected_devices, [True] * len(d.worker_devices))

  def _test_tower_id(self, d):
    """Every tower_id should be visited exactly once."""
    with d.scope():
      expected_devices = [False] * len(d.worker_devices)

      def mark_devices_fn():
        tower_id = distribute_lib.get_tower_context().tower_id
        self.assertLess(tower_id, len(d.worker_devices))
        self.assertFalse(expected_devices[tower_id])
        expected_devices[tower_id] = True

      d.call_for_each_tower(mark_devices_fn)
      self.assertAllEqual(expected_devices, [True] * len(d.worker_devices))

  def _test_call_and_merge_exceptions(self, dist):
    """Exceptions raised at any tower/merge nesting depth must propagate."""
    with dist.scope():
      with self.assertRaises(_TestException):
        dist.call_for_each_tower(_raise_exception_fn)
      with self.assertRaises(_TestException):
        dist.call_for_each_tower(_merge_raises_fn)
      with self.assertRaises(_TestException):
        dist.call_for_each_tower(_merge_call_raises_fn)
      with self.assertRaises(_TestException):
        dist.call_for_each_tower(_merge_call_merge_raises_fn)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
new file mode 100644
index 0000000000..c1ba22ed5a
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -0,0 +1,575 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various classes representing distributed values.
+
+See go/tf-distribution-strategy.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import weakref
+
+import six
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.contrib.eager.python import datasets
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.training import checkpointable
+from tensorflow.python.training import device_util
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import saver
+from tensorflow.python.util import nest
+
+
# pylint: disable=line-too-long
# TODO(josh11b): Should device values be strings or DeviceSpec objects?
# Not sure DeviceSpec objects are usable as a dict key.
class DistributedValues(object):
  """Holds a map from device to values. Either PerDevice or Mirrored."""

  def __init__(self, index):
    # Canonicalize device keys so differently-spelled but equivalent device
    # strings map to the same entry.
    self._index = {device_util.canonicalize(key): value
                   for key, value in six.iteritems(index)}

  def get(self, device=None):
    """Returns the value for the current device or raises a ValueError."""
    if device is None:
      tower_context = distribute_lib.get_tower_context()
      if tower_context:
        device = tower_context.device
      else:
        # Not inside a tower: fall back to the update device, then to the
        # device of the current scope.
        device = distribute_lib.get_update_device()
        if device is None:
          device = device_util.current()
    device = device_util.canonicalize(device)
    try:
      return self._index[device]
    except KeyError:
      raise ValueError("Device %s not found in %s (current device %s)" %
                       (device, self._index.keys(), device_util.current()))

  def on_device(self, device):
    """Returns True iff there is a component value on `device`."""
    device = device_util.canonicalize(device)
    return device in self._index

  @property
  def devices(self):
    # The canonicalized device strings this value is spread over.
    return self._index.keys()

  def __str__(self):
    return "%s:%s" % (self.__class__.__name__, self._index)

  def __repr__(self):
    return "%s(%r)" % (self.__class__.__name__, self._index)

  # TODO(josh11b): Possibly make an accessor for _index for use by
  # DistributionStrategy implementations.
+
+
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __init__(self, index):
    super(DistributedDelegate, self).__init__(index)

  def __getattr__(self, name):
    # Delegate any unknown attribute to the current device's value.
    return getattr(self.get(), name)

  # Forward operators to the current device's value so arithmetic on the
  # container behaves like arithmetic on the underlying value.
  # pylint: disable=multiple-statements
  def __add__(self, o): return self.get() + o
  def __radd__(self, o): return o + self.get()
  def __sub__(self, o): return self.get() - o
  def __rsub__(self, o): return o - self.get()
  def __mul__(self, o): return self.get() * o
  def __rmul__(self, o): return o * self.get()
  def __truediv__(self, o): return self.get() / o
  def __rtruediv__(self, o): return o / self.get()
  def __floordiv__(self, o): return self.get() // o
  def __rfloordiv__(self, o): return o // self.get()
  def __mod__(self, o): return self.get() % o
  def __rmod__(self, o): return o % self.get()
  def __lt__(self, o): return self.get() < o
  def __le__(self, o): return self.get() <= o
  def __gt__(self, o): return self.get() > o
  def __ge__(self, o): return self.get() >= o
  def __and__(self, o): return self.get() & o
  def __rand__(self, o): return o & self.get()
  def __or__(self, o): return self.get() | o
  def __ror__(self, o): return o | self.get()
  def __xor__(self, o): return self.get() ^ o
  def __rxor__(self, o): return o ^ self.get()
  def __getitem__(self, o): return self.get()[o]
  def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo)
  def __rpow__(self, o): return pow(o, self.get())
  def __invert__(self): return ~self.get()
  def __neg__(self): return -self.get()
  def __abs__(self): return abs(self.get())

  # The following return NotImplemented when the underlying value does not
  # define the dunder, letting Python try the other operand's method.
  def __div__(self, o):
    try:
      return self.get().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.get().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.get().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.get().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  # TODO(josh11b): Even more operator overloads.
+
+
class PerDevice(DistributedValues):
  """Holds a map from device to unsynchronized values."""
  # Marker type: component values may differ across devices.
  pass
+
+
class Mirrored(DistributedValues):
  """Holds a map from device to values which are kept in sync."""
  # Marker type: every device holds the same value.
  pass
+
+
def _assign_on_device(device, variable, tensor):
  """Assigns `tensor` to `variable`, placing the assign op on `device`."""
  with ops.device(device):
    return variable.assign(array_ops.identity(tensor))
+
+
# Lightweight stand-in for a variable's `op` when accessed from a
# cross-tower context (returned by DistributedVariable.op).
DistributedVarOp = collections.namedtuple(
    "DistributedVarOp", ["name", "graph", "type"])
+
+
class DistributedVariable(DistributedDelegate):
  """Holds a map from device to variables."""
  # TODO(josh11b): Support changing the set of variables if e.g. if new
  # devices are joining or a device is to leave.

  def __init__(self, index):
    # Child class must set self._primary_var before calling
    # super(...).__init__(index).
    self._common_name = self._primary_var.name.split(":")[0]
    super(DistributedVariable, self).__init__(index)

  @property
  def initializer(self):
    # Initialize every component variable in one grouped op.
    return control_flow_ops.group([v.initializer for v in self._index.values()])

  # Variable metadata (graph, names, dtype, shape) is delegated to the
  # primary component since it is shared by all components.
  @property
  def graph(self):
    return self._primary_var.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self._primary_var._unique_id  # pylint: disable=protected-access

  @property
  def name(self):
    return self._primary_var.name

  @property
  def dtype(self):
    return self._primary_var.dtype

  @property
  def shape(self):
    return self._primary_var.shape

  def get_shape(self):
    return self._primary_var.get_shape()

  @property
  def op(self):
    # We want cross-tower code that does some var.op.X calls
    # to work (even if the current device isn't in self.devices), but
    # other uses of var.op in a cross-tower context to fail.
    if distribute_lib.get_cross_tower_context():
      return DistributedVarOp(self._primary_var.op.name,
                              self._primary_var.op.graph,
                              self._primary_var.op.type)
    return self.get().op

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass
+
+
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  """Converts a DistributedVariable to a tensor by reading its value."""
  # Try to avoid assignments to and other mutations of MirroredVariable
  # state except through a DistributionStrategy.update() call.
  assert not as_ref
  return ops.internal_convert_to_tensor(
      var.get(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
# TODO(josh11b): ops.register_dense_tensor_like_type(DistributedVariable)?
+
+
class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    # Save only the primary copy; mirrored components hold the same value.
    super(_MirroredSaveable, self).__init__(primary_variable, "", name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return control_flow_ops.group([
        _assign_on_device(d, v, tensor)
        for d, v in six.iteritems(self._mirrored_variable._index)])  # pylint: disable=protected-access
+
+
def _get_update_device():
  """Validate we are in update/update_non_slot() and return current device.

  This is used in MirroredVariable.assign* members, to make sure they
  are only called via an update method, to make sure all components of the
  variable are being updated in a consistent way.

  Returns:
    A string device.

  Raises:
    RuntimeError: If not in distribution.update()/.update_non_slot().
  """
  device = distribute_lib.get_update_device()
  if device is not None:
    return device
  raise RuntimeError(
      "Use DistributionStrategy.update() to modify a MirroredVariable.")
+
+
class MirroredVariable(DistributedVariable, Mirrored,
                       checkpointable.CheckpointableBase):
  """Holds a map from device to variables whose values are kept in sync."""

  def __init__(self, index, primary_var):
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in six.itervalues(index):
      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
    self._primary_var = primary_var
    super(MirroredVariable, self).__init__(index)

  # We use _get_update_device() for the assign* methods to enforce
  # that we are in an update() function. The arguments to update() are
  # automatically unwrapped so the update() function would normally
  # see regular variables, not MirroredVariables. However, the update
  # function can still operate on wrapped MirroredVariables through
  # object members, captured arguments, etc. This is more likely in an
  # update_non_slot() function (like OptimizerV2._finish), which can
  # update several non-slot variables in one call.
  def assign_sub(self, *args, **kwargs):
    """Subtracts from the component on the current update device."""
    return self.get(device=_get_update_device()).assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    """Adds to the component on the current update device."""
    return self.get(device=_get_update_device()).assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    """Assigns to the component on the current update device."""
    return self.get(device=_get_update_device()).assign(*args, **kwargs)

  def _gather_saveables_for_checkpoint(self):
    """Overrides CheckpointableBase method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """
    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary_var, name)
    return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+
+
class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
  """Class for defining how to restore a TowerLocalVariable."""

  def __init__(self, tower_local_variable, name):
    self._tower_local_variable = tower_local_variable
    # We use a callable so that we don't have to evaluate this expression
    # in the case where we are trying to restore instead of save.
    def tensor():
      # Saving fetches (reduces) the value across all components.
      return distribute_lib.get_distribution_strategy().fetch(
          tower_local_variable)
    spec = saver.BaseSaverBuilder.SaveSpec(
        tensor=tensor,
        slice_spec="",
        name=name,
        dtype=tower_local_variable.dtype)
    super(_TowerLocalSaveable, self).__init__(tensor, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    # To preserve the sum across save and restore, we have to divide the
    # total across all devices when restoring a variable that was summed
    # when saving.
    if self._tower_local_variable.reduce_method == "sum":
      tensor *= 1. / len(self._tower_local_variable.devices)
    return control_flow_ops.group([
        _assign_on_device(d, v, tensor)
        for d, v in six.iteritems(self._tower_local_variable._index)])  # pylint: disable=protected-access
+
+
class TowerLocalVariable(DistributedVariable, PerDevice,
                         checkpointable.CheckpointableBase):
  """Holds a map from device to variables whose values are reduced on save."""

  def __init__(self, index, primary_var, reduce_method):
    self._primary_var = primary_var
    # How component values are combined when saved (e.g. "sum"; see
    # _TowerLocalSaveable.restore).
    self._reduce_method = reduce_method
    super(TowerLocalVariable, self).__init__(index)

  # Unlike MirroredVariable, each tower may update its component
  # independently, so the assign* methods act on the current device's copy
  # without requiring an update() context.
  def assign_sub(self, *args, **kwargs):
    return self.get().assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    return self.get().assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    return self.get().assign(*args, **kwargs)

  @property
  def reduce_method(self):
    return self._reduce_method

  def _gather_saveables_for_checkpoint(self):
    """Overrides CheckpointableBase method.

    This allows both name-based and object-based save and restore of
    TowerLocalVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """
    def _saveable_factory(name=self._common_name):
      return _TowerLocalSaveable(self, name)
    return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+
+
def _devices_match(d1, d2):
  """True iff the two device strings canonicalize to the same device."""
  canonical = device_util.canonicalize
  return canonical(d1) == canonical(d2)
+
+
def regroup(per_device, wrap_class=PerDevice):
  """Makes device->nest map into a nest of PerDevice/Mirrored values.

  Args:
    per_device: a non-empty dict mapping device -> structure (list, tuple,
      namedtuple, dict, or leaf value); all structures must have the same
      shape across devices.
    wrap_class: class used to wrap each leaf's device->value map
      (`PerDevice` by default; pass `Mirrored` for synchronized values).

  Returns:
    A single structure whose leaves are `wrap_class` instances, except
    where every device shares the same object or the values are the
    components of one MirroredVariable (see comments below).
  """
  items = list(per_device.items())
  assert items
  v0 = items[0][1]  # First value

  # Recurse structurally: lists, tuples/namedtuples and dicts are regrouped
  # element-wise, after checking the structures match across devices.
  if isinstance(v0, list):
    for _, v in items[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [regroup({k: v[i] for k, v in items}, wrap_class)
            for i in range(len(v0))]

  if isinstance(v0, tuple):
    for _, v in items[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0)
    regrouped_tuple = tuple(regroup({k: v[i] for k, v in items}, wrap_class)
                            for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(type(v0), "_make")
      return type(v0)._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, dict):
    v0keys = set(v0.keys())
    for _, v in items[1:]:
      assert isinstance(v, dict)
      assert set(v.keys()) == v0keys
    return {key: regroup({k: v[key] for k, v in items}, wrap_class)
            for key in v0keys}

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for _, v in items[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a MirroredVariable (and same_id means it is the same
  #   across all devices), we want to return it. We check
  #   MirroredVariable specifically since it can look like it
  #   has a _mirrored_container member since its members do.
  # * If v0 is a member of a mirrored variable, in which case
  #   hasattr(v0, "_mirrored_container") is true, we want to
  #   return the MirroredVariable that contains it using the
  #   _mirrored_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0.
  if same_id and (isinstance(v0, MirroredVariable) or
                  not hasattr(v0, "_mirrored_container")):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable. In this case we want to return the
  # containing MirroredVariable, after a bunch of sanity checking.
  # In particular, each component should have the same container,
  # and the devices of the variables should match the keys of the
  # per-device dictionary.
  # TODO(josh11b): Do we need similar logic for TowerLocalVariables?
  if hasattr(v0, "_mirrored_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, MirroredVariable), (
        "ids = %s, items = %s" % ([id(v[1]) for v in items], items))
    assert _devices_match(v0.device, items[0][0]), (
        "v0.device = %s, items = %s" % (v0.device, items))
    mirrored_container = v0._mirrored_container()
    assert mirrored_container is not None
    for d, v in items[1:]:
      assert _devices_match(v.device, d), (
          "v.device = %s, d = %s, items = %s" % (v.device, d, items))
      assert mirrored_container is v._mirrored_container()
    return mirrored_container
  # pylint: enable=protected-access

  return wrap_class(per_device)
+
+
def select_device(device, structured):
  """Specialize a nest of regular & per-device values for one device."""
  def _get(x):
    # Non-distributed leaves pass through unchanged.
    return x.get(device) if isinstance(x, DistributedValues) else x

  return nest.map_structure(_get, structured)
+
+
def select_device_mirrored(device, structured):
  """Specialize a nest of regular & mirrored values for one device.

  Like select_device(), but raises TypeError on any distributed leaf that
  is not Mirrored (i.e. not guaranteed identical across towers).
  """
  def _get_mirrored(x):
    if isinstance(x, DistributedValues):
      if not isinstance(x, Mirrored):
        raise TypeError(
            "Expected value to be mirrored across towers: %s in %s." %
            (x, structured))
      return x.get(device)
    else:
      return x

  return nest.map_structure(_get_mirrored, structured)
+
+
class PerDeviceDataIterator(object):
  """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`."""

  def __init__(self, iterator, devices, prefetch_on_device=None):
    self._iterator = iterator
    self._devices = devices
    self._prefetch_on_device = prefetch_on_device

  def get_next(self, name=None):
    """Scatter the input across devices."""
    if self._prefetch_on_device:
      # Prefetching path: the iterator already yields one element per device.
      data_list = self._iterator.get_next(name=name)
      index = dict(zip(self._devices, data_list))
    else:
      # Non-prefetching path: one batched element holds len(devices) items;
      # pick the i-th item of each nested tensor for the i-th device.
      batch = self._iterator.get_next(name=name)
      index = {}
      def get_ith(i):
        return lambda x: x[i]

      for i, d in enumerate(self._devices):
        index[d] = nest.map_structure(get_ith(i), batch)
        if context.executing_eagerly():
          # In eager mode, copy the slice onto its target device.
          with ops.device(d):
            index[d] = nest.map_structure(array_ops.identity, index[d])

    return regroup(index)
+
+
class PerDeviceDataset(object):
  """Like `tf.data.Dataset` split across devices, producing `PerDevice` data."""

  def __init__(self, dataset, devices, prefetch_on_device=None):
    self._devices = devices

    # Default to using prefetching in graph mode, unless specified.
    # TODO(priyag): Enable prefetching in eager mode.
    self._prefetch_on_device = prefetch_on_device
    if self._prefetch_on_device is None:
      self._prefetch_on_device = not context.executing_eagerly()
    assert not (self._prefetch_on_device and context.executing_eagerly()), (
        "Prefetching is only supported in graph mode currently")

    if self._prefetch_on_device:
      self._dataset = dataset
    else:
      # Batch len(devices) consecutive elements together so get_next() can
      # hand one to each device (see PerDeviceDataIterator).
      # TODO(priyag): If dropping remainder is not appropriate, find another
      # approach to distributing the dataset when not possible to divide evenly.
      # Possibly not an issue when we start using PartitionedDataset.
      self._dataset = dataset.apply(
          batching.batch_and_drop_remainder(len(devices)))

  def make_one_shot_iterator(self):
    """Get a one time use iterator for the distributed PerDeviceDataset."""
    if self._prefetch_on_device:
      on_device_dataset = self._dataset.apply(
          prefetching_ops_v2.prefetch_to_devices(self._devices))
      dataset_iterator = on_device_dataset.make_one_shot_iterator()
    elif context.executing_eagerly():
      dataset_iterator = datasets.Iterator(self._dataset)
    else:
      dataset_iterator = self._dataset.make_one_shot_iterator()

    return PerDeviceDataIterator(
        dataset_iterator, self._devices, self._prefetch_on_device)
+
+
class MapOutput(object):
  """Map can result in multiple outputs per device."""

  def __init__(self, l):
    # Hold the outputs exactly as given; get() returns the same object.
    self._values = l

  def get(self):
    return self._values
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
new file mode 100644
index 0000000000..5c0d4b7d6c
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -0,0 +1,807 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the distributed values library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import device_util
+from tensorflow.python.training import saver as saver_lib
+
+
+@test_util.with_c_api
+class DistributedValuesTest(test.TestCase):
+  """Tests `values.DistributedValues` get() and device canonicalization."""
+
+  def testGetEager(self):
+    with ops.device("/device:CPU:0"):
+      one = constant_op.constant(1)
+      two = constant_op.constant(2)
+      v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two})
+      self.assertEqual(two, v.get("/device:GPU:0"))
+      # get() with no argument returns the value for the current device.
+      self.assertEqual(one, v.get())
+      # Asking for a device that has no entry raises ValueError.
+      with self.assertRaises(ValueError):
+        self.assertIsNone(v.get("/device:GPU:2"))
+
+  def testGetGraph(self):
+    with context.graph_mode(), \
+        ops.Graph().as_default(), \
+        ops.device("/device:CPU:0"):
+      one = constant_op.constant(1)
+      two = constant_op.constant(2)
+      v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two})
+      self.assertEqual(two, v.get("/device:GPU:0"))
+      self.assertEqual(one, v.get())
+      with self.assertRaises(ValueError):
+        self.assertIsNone(v.get("/device:GPU:2"))
+
+  def testCanonicalization(self):
+    # All spellings of the CPU device should be canonicalized to the same
+    # fully-qualified device string key.
+    canonical_cpu = ["/job:localhost/replica:0/task:0/device:CPU:0"]
+    v = values.DistributedValues({"": 42})
+    self.assertEqual(canonical_cpu, list(v._index.keys()))
+    v = values.DistributedValues({"/device:CPU:0": 42})
+    self.assertEqual(canonical_cpu, list(v._index.keys()))
+    v = values.DistributedValues({"/cpu:0": 42})
+    self.assertEqual(canonical_cpu, list(v._index.keys()))
+    v = values.DistributedValues({"/CPU:0": 42})
+    self.assertEqual(canonical_cpu, list(v._index.keys()))
+    # Lowercase "cpu" in the long device form is rejected.
+    with self.assertRaises(AssertionError):
+      v = values.DistributedValues({"/device:cpu:0": 42})
+
+
+@test_util.with_c_api
+class DistributedDelegateTest(test.TestCase):
+  """Tests that `values.DistributedDelegate` forwards attribute access and
+  operators to the value for the current device (CPU:0 in these tests)."""
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testGetAttr(self):
+    with ops.device("/device:CPU:0"):
+
+      class Foo(object):
+
+        def __init__(self, x):
+          self.x = x
+
+      v = values.DistributedDelegate(
+          {"/device:CPU:0": Foo(7), "/device:GPU:0": Foo(8)})
+      # Attribute access is delegated to the CPU:0 value, Foo(7).
+      self.assertEqual(7, v.x)
+      with self.assertRaises(AttributeError):
+        _ = v.y
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testOperatorOverride(self):
+    with ops.device("/device:CPU:0"):
+      v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8})
+      # v should act like int(7).
+      self.assertEqual(8, v + 1)
+      self.assertEqual(10, 3 + v)
+      self.assertEqual(14, v + v)
+      self.assertEqual(5, v - 2)
+      self.assertEqual(6, 13 - v)
+      self.assertEqual(0, v - v)
+      self.assertEqual(14, v * 2)
+      self.assertEqual(21, 3 * v)
+      self.assertEqual(49, v * v)
+      self.assertEqual(3.5, v / 2)
+      self.assertEqual(1.5, 10.5 / v)
+      self.assertEqual(3, v // 2)
+      self.assertEqual(2, 15 // v)
+      self.assertEqual(1, v % 2)
+      self.assertEqual(2, 16 % v)
+      self.assertTrue(v < 12)
+      self.assertTrue(v <= 12)
+      self.assertFalse(v > 12)
+      self.assertFalse(v >= 12)
+      self.assertFalse(12 < v)
+      self.assertFalse(12 <= v)
+      self.assertTrue(12 > v)
+      self.assertTrue(12 >= v)
+      self.assertEqual(3, v & 3)
+      self.assertEqual(3, 11 & v)
+      self.assertEqual(15, v | 8)
+      self.assertEqual(23, 16 | v)
+      self.assertEqual(4, v ^ 3)
+      self.assertEqual(12, 11 ^ v)
+      self.assertEqual(343, pow(v, 3))
+      self.assertEqual(3, pow(v, 3, 10))
+      self.assertEqual(128, pow(2, v))
+      self.assertEqual(-7, -v)
+      self.assertEqual(~7, ~v)
+      self.assertEqual(7, abs(v))
+      # Indexing is not delegated for int-like values.
+      with self.assertRaises(TypeError):
+        _ = v[2]
+
+
+def _device_str(d):
+  """Returns a GPU device string for integer index `d`."""
+  return "/device:GPU:" + str(d)
+
+
+def _nested_value(d):
+  """Returns a nested tuple/list/dict of strings, each suffixed with `d`."""
+  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
+
+
+def _make_mirrored():
+  """Creates two resource variables on GPU:0/CPU:0 wrapped in a MirroredVariable.
+
+  Returns:
+    A tuple `(v, devices, mirrored)`: the list of component variables
+    (initialized to 1. and 2.), the device strings, and the
+    `values.MirroredVariable` whose primary is the GPU:0 component.
+  """
+  v = []
+  index = {}
+  devices = ["/device:GPU:0", "/device:CPU:0"]
+  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
+    with ops.device(d):
+      v.append(variable_scope.get_variable(
+          name=n, initializer=init, use_resource=True))
+      index[d] = v[-1]
+  mirrored = values.MirroredVariable(index, v[0])
+  return v, devices, mirrored
+
+
+@test_util.with_c_api
+class RegroupAndSelectDeviceTest(test.TestCase):
+  """Tests `values.regroup`, `select_device` and `select_device_mirrored`."""
+
+  def _is_per_device(self, result, expected, klass=values.PerDevice):
+    """Asserts `result` is a `klass` holding `expected[i]` for GPU device i."""
+    self.assertIsInstance(result, klass)
+    # We canonicalize the devices to match the device strings returned
+    # by PerDevice, which also does device string canonicalization.
+    devices = [device_util.canonicalize(_device_str(i))
+               for i in range(len(expected))]
+    self.assertEqual(set(devices), set(result.devices))
+    for i, d in enumerate(devices):
+      self.assertEqual(expected[i], result.get(d))
+      self.assertEqual(expected[i], result.get(_device_str(i)))
+
+  def testNested(self):
+    # regroup() should turn a per-device dict of nested structures into a
+    # nested structure of PerDevice leaves.
+    result = values.regroup({_device_str(0): _nested_value("1"),
+                             _device_str(1): _nested_value("2")})
+    self.assertIsInstance(result, tuple)
+    self.assertEqual(3, len(result))
+    self._is_per_device(result[0], ["a1", "a2"])
+    self._is_per_device(result[2], ["h1", "h2"])
+
+    self.assertIsInstance(result[1], list)
+    self.assertEqual(3, len(result[1]))
+    self._is_per_device(result[1][0], ["b1", "b2"])
+    self._is_per_device(result[1][2], ["g1", "g2"])
+
+    self.assertIsInstance(result[1][1], dict)
+    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
+    self._is_per_device(result[1][1]["c"], ["d1", "d2"])
+    self._is_per_device(result[1][1]["e"], ["f1", "f2"])
+
+    # Also test that we can undo the merge using select_device()
+    self.assertEqual(_nested_value("1"),
+                     values.select_device(_device_str(0), result))
+    self.assertEqual(_nested_value("2"),
+                     values.select_device(_device_str(1), result))
+    # select_device_mirrored() should fail due to non-mirrored values
+    with self.assertRaises(TypeError):
+      values.select_device_mirrored(_device_str(0), result)
+    with self.assertRaises(TypeError):
+      values.select_device_mirrored(_device_str(1), result)
+
+  def testWrapClass(self):
+    # Normally a mirrored value would be the same across devices, but
+    # for a test it is convenient to be able to tell the values apart.
+    result = values.regroup({_device_str(0): _nested_value("1"),
+                             _device_str(1): _nested_value("2")},
+                            values.Mirrored)
+    self.assertIsInstance(result, tuple)
+    self.assertEqual(3, len(result))
+    self._is_per_device(result[0], ["a1", "a2"], values.Mirrored)
+    self._is_per_device(result[2], ["h1", "h2"], values.Mirrored)
+
+    self.assertIsInstance(result[1], list)
+    self.assertEqual(3, len(result[1]))
+    self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored)
+    self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored)
+
+    self.assertIsInstance(result[1][1], dict)
+    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
+    self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
+    self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored)
+
+    # Also test that we can undo the merge using select_device()
+    self.assertEqual(_nested_value("1"),
+                     values.select_device(_device_str(0), result))
+    self.assertEqual(_nested_value("2"),
+                     values.select_device(_device_str(1), result))
+    # Values are marked as mirrored, so select_device_mirrored() is allowed.
+    self.assertEqual(_nested_value("1"),
+                     values.select_device_mirrored(_device_str(0), result))
+    self.assertEqual(_nested_value("2"),
+                     values.select_device_mirrored(_device_str(1), result))
+
+  def testMirroredContainer(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+    # regroup() over the components of an existing MirroredVariable should
+    # return the MirroredVariable itself, not a new wrapper.
+    v, devices, mirrored = _make_mirrored()
+    result = values.regroup(dict(zip(devices, v)))
+    self.assertIs(mirrored, result)
+
+  def testSameId(self):
+    # When the same object (by identity) appears for every device, regroup()
+    # keeps it unwrapped instead of building a PerDevice.
+    foo = object()
+    result = values.regroup({_device_str(0): ("a", foo),
+                             _device_str(1): ("b", foo)})
+    self.assertIsInstance(result, tuple)
+    self.assertEqual(2, len(result))
+    self._is_per_device(result[0], ["a", "b"])
+    self.assertIs(foo, result[1])
+
+    # Test select_device(), should undo the merge done by regroup().
+    result_0 = values.select_device(_device_str(0), result)
+    self.assertIsInstance(result_0, tuple)
+    self.assertEqual(2, len(result_0))
+    self.assertEqual("a", result_0[0])
+    self.assertIs(foo, result_0[1])
+    result_1 = values.select_device(_device_str(1), result)
+    self.assertIsInstance(result_1, tuple)
+    self.assertEqual(2, len(result_1))
+    self.assertEqual("b", result_1[0])
+    self.assertIs(foo, result_1[1])
+
+  def testOneDevice(self):
+    result = values.regroup({_device_str(0): _nested_value("1")})
+    # On one device regroup() and select_device() are basically identity.
+    self.assertEqual(_nested_value("1"), result)
+    self.assertEqual(_nested_value("1"),
+                     values.select_device(_device_str(0), result))
+
+    # The one exception has to do with MirroredVariables.
+    d = "/device:CPU:0"
+    with ops.device(d):
+      v = variable_scope.get_variable(
+          name="v", initializer=1., use_resource=True)
+      index = {d: v}
+    mirrored = values.MirroredVariable(index, v)
+    result = values.regroup(index)
+    self.assertIs(mirrored, result)
+
+  def testNamedTupleEstimatorSpec(self):
+    # EstimatorSpec is a namedtuple subclass; regroup() should rebuild it
+    # field-by-field with PerDevice leaves.
+    # NOTE(review): assertEquals is a deprecated alias of assertEqual.
+    with context.graph_mode(), ops.Graph().as_default():
+      created_estimator_specs = {}
+      to_regroup = {}
+
+      for device_id in range(3):
+        spec = model_fn_lib.EstimatorSpec(
+            mode=model_fn_lib.ModeKeys.TRAIN,
+            loss=constant_op.constant(device_id / 2),
+            train_op=array_ops.identity(constant_op.constant(device_id)))
+        created_estimator_specs[device_id] = spec
+        to_regroup[_device_str(device_id)] = spec
+
+      merged_estimator_spec = values.regroup(to_regroup)
+
+      self.assertTrue(
+          isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
+      self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
+      for device_id in range(3):
+        d = _device_str(device_id)
+        self.assertEquals(created_estimator_specs[device_id].loss,
+                          merged_estimator_spec.loss.get(d))
+        self.assertEquals(created_estimator_specs[device_id].train_op,
+                          merged_estimator_spec.train_op.get(d))
+        # Scaffold is populated by `EstimatorSpec.__new__`.
+        self.assertEquals(created_estimator_specs[device_id].scaffold,
+                          merged_estimator_spec.scaffold.get(d))
+        # Also test that we can undo the merge using select_device()
+        self.assertEquals(created_estimator_specs[device_id],
+                          values.select_device(_device_str(device_id),
+                                               merged_estimator_spec))
+
+
+@test_util.with_c_api
+class PerDeviceDatasetTest(test.TestCase):
+  """Tests iteration over `values.PerDeviceDataset`, with and without
+  on-device prefetching."""
+
+  # Soft placement lets GPU-targeted ops fall back to CPU when needed.
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
+    """Checks exact per-device values when prefetching is disabled."""
+    per_device_dataset = values.PerDeviceDataset(
+        dataset, devices, prefetch_on_device=False)
+    iterator = per_device_dataset.make_one_shot_iterator()
+
+    for expected_value in expected_values:
+      next_element = iterator.get_next()
+      actual = self.evaluate([
+          values.select_device(d, next_element) for d in devices])
+      self.assertEqual(expected_value, actual)
+
+    # The iterator must signal exhaustion after the expected elements.
+    with self.assertRaises(errors.OutOfRangeError):
+      next_element = iterator.get_next()
+      self.evaluate([
+          values.select_device(d, next_element) for d in devices])
+
+  def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
+    """Checks the combined set of values when prefetching is enabled.
+
+    Prefetching is only exercised in graph mode (it is unsupported in eager).
+    """
+    if not context.executing_eagerly():
+      per_device_dataset = values.PerDeviceDataset(
+          dataset, devices, prefetch_on_device=True)
+      iterator = per_device_dataset.make_one_shot_iterator()
+
+      # With prefetching, we cannot guarantee which input ends up on which
+      # device, so we verify that the complete set seen on all devices is
+      # correct, and equal numbers are distributed to each device.
+      combined_actual = []
+      combined_expected = []
+      for expected_value in expected_values:
+        next_element = iterator.get_next()
+        combined_actual.extend(self.evaluate([
+            values.select_device(d, next_element) for d in devices]))
+        combined_expected.extend(expected_value)
+
+      self.assertEqual(set(combined_expected), set(combined_actual))
+
+      with self.assertRaises(errors.OutOfRangeError):
+        next_element = iterator.get_next()
+        self.evaluate([
+            values.select_device(d, next_element) for d in devices])
+
+  def _test_iterator(self, devices, dataset, expected_values):
+    """Runs both the no-prefetch and prefetch variants of the check."""
+    self._test_iterator_no_prefetch(devices, dataset, expected_values)
+    self._test_iterator_with_prefetch(devices, dataset, expected_values)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testOneDevice(self):
+    devices = ["/device:CPU:0"]
+    dataset = dataset_ops.Dataset.range(10)
+
+    expected_values = [[i] for i in range(10)]
+
+    self._test_iterator(devices, dataset, expected_values)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testMultipleDevices(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    devices = ["/device:CPU:0", "/device:GPU:0"]
+    dataset = dataset_ops.Dataset.range(10)
+
+    # Each step consumes one element per device.
+    expected_values = [[i, i+1] for i in range(0, 10, 2)]
+
+    self._test_iterator(devices, dataset, expected_values)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testTupleDataset(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    devices = ["/device:CPU:0", "/device:GPU:0"]
+    dataset1 = dataset_ops.Dataset.range(10)
+    dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
+    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+
+    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
+
+    self._test_iterator(devices, dataset, expected_values)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testUnevenDatasetBatches(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    devices = ["/device:CPU:0", "/device:GPU:0"]
+    # 11 elements over 2 devices: the odd remainder element is dropped.
+    dataset = dataset_ops.Dataset.range(11)
+
+    expected_values = [[i, i+1] for i in range(0, 10, 2)]
+    self._test_iterator(devices, dataset, expected_values)
+
+
+@test_util.with_c_api
+class MirroredVariableTest(test.TestCase):
+  """Tests properties and save/restore behavior of `values.MirroredVariable`.
+
+  Saving a mirrored variable writes the primary component's value; restoring
+  assigns that saved value to every component.
+  """
+
+  # Soft placement lets GPU-targeted ops fall back to CPU when needed.
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testProperties(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    v, _, mirrored = _make_mirrored()
+
+    # The mirrored wrapper reports the primary component's metadata.
+    # NOTE(review): assertEquals is a deprecated alias of assertEqual.
+    self.assertEquals(v[0].name, mirrored.name)
+    self.assertEquals(v[0].dtype, mirrored.dtype)
+    self.assertEquals(v[0].shape, mirrored.shape)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testVariableOnAnotherDevice(self):
+    v = variable_scope.get_variable(
+        name="v", initializer=[1.], use_resource=True)
+    index = {"/job:foo/device:CPU:0": v}
+    mirrored = values.MirroredVariable(index, v)
+
+    self.assertEquals(v.name, mirrored.name)
+    self.assertEquals(v.dtype, mirrored.dtype)
+    self.assertEquals(v.shape, mirrored.shape)
+
+  def _assign_mirrored(self, devices, v, new):
+    """Assigns new[i] to component variable v[i] on devices[i]."""
+    for d, var, n in zip(devices, v, new):
+      with ops.device(d):
+        self.evaluate(var.assign(n))
+
+  def _save_return_saver(self, sess, var):
+    """Saves `var` to a temp checkpoint; returns (save_path, saver)."""
+    saver = saver_lib.Saver(var_list=[var])
+    test_dir = self.get_temp_dir()
+    prefix = os.path.join(test_dir, "ckpt")
+    return saver.save(sess, prefix), saver
+
+  def _save(self, sess, var):
+    """Saves `var` to a temp checkpoint; returns only the save_path."""
+    save_path, _ = self._save_return_saver(sess, var)
+    return save_path
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveAndRestoreMirroredOneGraph(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    with self.test_session() as sess:
+      v, devices, mirrored = _make_mirrored()
+
+      # Overwrite the initial values.
+      self._assign_mirrored(devices, v, [3., 4.])
+
+      # Saves the current value of v[0], 3.
+      save_path, saver = self._save_return_saver(sess, mirrored)
+
+      # Change the values between save and restore.
+      self._assign_mirrored(devices, v, [5., 6.])
+
+      # Restores the saved value of 3. to both variables.
+      saver.restore(sess, save_path)
+      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
+
+  def _save_mirrored(self):
+    """Save variables with mirroring, returns save_path."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      v, devices, mirrored = _make_mirrored()
+
+      # Overwrite the initial values.
+      self._assign_mirrored(devices, v, [3., 4.])
+
+      # Saves the current value of v[0], 3.
+      save_path = self._save(sess, mirrored)
+
+      # Change the values between save and restore.
+      self._assign_mirrored(devices, v, [5., 6.])
+    return save_path
+
+  def _save_normal(self):
+    """Save variables without mirroring, returns save_path."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      var = variable_scope.get_variable(
+          name="v", initializer=1., use_resource=True)
+
+      # Overwrite the initial value.
+      self.evaluate(var.assign(3.))
+
+      # Saves the current value of var, 3.
+      save_path = self._save(sess, var)
+
+      # Change the values between save and restore.
+      self.evaluate(var.assign(5.))
+    return save_path
+
+  def _restore_normal(self, save_path):
+    """Restore to variables without mirroring in a fresh graph."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      var = variable_scope.get_variable(
+          name="v", initializer=7., use_resource=True)
+
+      # Overwrite the initial value.
+      self.evaluate(var.assign(8.))
+
+      # Restores the saved value of 3. to `var`.
+      saver = saver_lib.Saver(var_list=[var])
+      saver.restore(sess, save_path)
+      self.assertEqual(3., self.evaluate(var))
+
+  def _restore_mirrored(self, save_path):
+    """Restore to variables with mirroring in a fresh graph."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      v, devices, mirrored = _make_mirrored()
+
+      # Overwrite the initial values.
+      self._assign_mirrored(devices, v, [7., 8.])
+
+      # Restores the saved value of 3. to both variables.
+      saver = saver_lib.Saver(var_list=[mirrored])
+      saver.restore(sess, save_path)
+      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveMirroredRestoreMirrored(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_mirrored()
+    self._restore_mirrored(save_path)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveMirroredRestoreNormal(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_mirrored()
+    self._restore_normal(save_path)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveNormalRestoreMirrored(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_normal()
+    self._restore_mirrored(save_path)
+
+
+_devices = ["/device:GPU:0", "/device:CPU:0"]
+
+
+def _make_tower_local(method):
+  """Creates two resource variables on `_devices` wrapped in a TowerLocalVariable.
+
+  Args:
+    method: the reduce method for the tower-local variable, e.g. "sum" or
+      "mean".
+
+  Returns:
+    A tuple `(v, tower_local)`: the list of component variables (initialized
+    to 1. and 2.) and the `values.TowerLocalVariable` wrapping them.
+  """
+  v = []
+  index = {}
+  for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]):
+    with ops.device(d):
+      v.append(variable_scope.get_variable(
+          name=n, initializer=init, use_resource=True))
+      index[d] = v[-1]
+  tower_local = values.TowerLocalVariable(index, v[0], method)
+  return v, tower_local
+
+
+@test_util.with_c_api
+class TowerLocalVariableTest(test.TestCase):
+  """Tests properties and save/restore behavior of `values.TowerLocalVariable`.
+
+  Saving reduces the components with the variable's reduce method ("sum" or
+  "mean"); restoring distributes the saved value back: a "mean" restore
+  assigns it to every component, a "sum" restore divides it equally.
+  """
+
+  # Soft placement lets GPU-targeted ops fall back to CPU when needed.
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testProperties(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    v, tower_local = _make_tower_local("sum")
+
+    # The tower-local wrapper reports the primary component's metadata.
+    # NOTE(review): assertEquals is a deprecated alias of assertEqual.
+    self.assertEquals(v[0].name, tower_local.name)
+    self.assertEquals(v[0].dtype, tower_local.dtype)
+    self.assertEquals(v[0].shape, tower_local.shape)
+    self.assertEquals("sum", tower_local.reduce_method)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testVariableOnAnotherDevice(self):
+    v = variable_scope.get_variable(
+        name="v", initializer=[1.], use_resource=True)
+    index = {"/job:foo/device:CPU:0": v}
+    tower_local = values.TowerLocalVariable(index, v, "mean")
+
+    self.assertEquals(v.name, tower_local.name)
+    self.assertEquals(v.dtype, tower_local.dtype)
+    self.assertEquals(v.shape, tower_local.shape)
+    self.assertEquals("mean", tower_local.reduce_method)
+
+  def _assign_tower_local(self, devices, v, new):
+    """Assigns new[i] to component variable v[i] on devices[i]."""
+    for d, var, n in zip(devices, v, new):
+      with ops.device(d):
+        self.evaluate(var.assign(n))
+
+  def _save_return_saver(self, sess, var):
+    """Saves `var` to a temp checkpoint; returns (save_path, saver)."""
+    saver = saver_lib.Saver(var_list=[var])
+    test_dir = self.get_temp_dir()
+    prefix = os.path.join(test_dir, "ckpt")
+    return saver.save(sess, prefix), saver
+
+  def _save(self, sess, var):
+    """Saves `var` to a temp checkpoint; returns only the save_path."""
+    save_path, _ = self._save_return_saver(sess, var)
+    return save_path
+
+  def _dist_scope(self):
+    """Returns a MirroredStrategy scope over `_devices` for save/restore."""
+    return mirrored_strategy.MirroredStrategy(_devices).scope()
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveAndRestoreTowerLocalSumOneGraph(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    with self.test_session() as sess:
+      v, tower_local = _make_tower_local("sum")
+
+      # Overwrite the initial values.
+      self._assign_tower_local(_devices, v, [3., 4.])
+
+      with self._dist_scope():
+        # Saves the current value of v[0] + v[1], 7.
+        save_path, saver = self._save_return_saver(sess, tower_local)
+
+        # Change the values between save and restore.
+        self._assign_tower_local(_devices, v, [5., 6.])
+
+        # Restores the saved value of 7. which gets divided equally
+        # between the variables.
+        saver.restore(sess, save_path)
+        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveAndRestoreTowerLocalMeanOneGraph(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    with self.test_session() as sess:
+      v, tower_local = _make_tower_local("mean")
+
+      # Overwrite the initial values.
+      self._assign_tower_local(_devices, v, [3., 4.])
+
+      with self._dist_scope():
+        # Saves the current value of (v[0] + v[1])/2, 3.5.
+        save_path, saver = self._save_return_saver(sess, tower_local)
+
+        # Change the values between save and restore.
+        self._assign_tower_local(_devices, v, [5., 6.])
+
+        # Restores the saved value of 3.5 to both variables.
+        saver.restore(sess, save_path)
+        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
+
+  def _save_tower_local_mean(self):
+    """Save variables with mirroring, returns save_path."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      v, tower_local = _make_tower_local("mean")
+
+      # Overwrite the initial values.
+      self._assign_tower_local(_devices, v, [3., 4.])
+
+      with self._dist_scope():
+        # Saves the current value of (v[0] + v[1])/2, 3.5
+        save_path = self._save(sess, tower_local)
+
+        # Change the values between save and restore.
+        self._assign_tower_local(_devices, v, [5., 6.])
+    return save_path
+
+  def _save_tower_local_sum(self):
+    """Save variables with mirroring, returns save_path."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      v, tower_local = _make_tower_local("sum")
+
+      # Overwrite the initial values.
+      self._assign_tower_local(_devices, v, [1.5, 2.])
+
+      with self._dist_scope():
+        # Saves the current value of v[0] + v[1], 3.5
+        save_path = self._save(sess, tower_local)
+
+        # Change the values between save and restore.
+        self._assign_tower_local(_devices, v, [5., 6.])
+    return save_path
+
+  def _save_normal(self):
+    """Save variables without mirroring, returns save_path."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      var = variable_scope.get_variable(
+          name="v", initializer=1., use_resource=True)
+
+      # Overwrite the initial value.
+      self.evaluate(var.assign(3.5))
+
+      # Saves the current value of var, 3.5.
+      save_path = self._save(sess, var)
+
+      # Change the values between save and restore.
+      self.evaluate(var.assign(5.))
+    return save_path
+
+  def _restore_normal(self, save_path):
+    """Restore to variables without mirroring in a fresh graph."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      var = variable_scope.get_variable(
+          name="v", initializer=7., use_resource=True)
+
+      # Overwrite the initial value.
+      self.evaluate(var.assign(8.))
+
+      # Restores the saved value of 3.5 to `var`.
+      saver = saver_lib.Saver(var_list=[var])
+      saver.restore(sess, save_path)
+      self.assertEqual(3.5, self.evaluate(var))
+
+  def _restore_tower_local_mean(self, save_path):
+    """Restore to variables with mirroring in a fresh graph."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      v, tower_local = _make_tower_local("mean")
+
+      # Overwrite the initial values.
+      self._assign_tower_local(_devices, v, [7., 8.])
+
+      with self._dist_scope():
+        # Restores the saved value of 3.5 to both variables.
+        saver = saver_lib.Saver(var_list=[tower_local])
+        saver.restore(sess, save_path)
+        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
+
+  def _restore_tower_local_sum(self, save_path):
+    """Restore to variables with mirroring in a fresh graph."""
+    with self.test_session(graph=ops.Graph()) as sess:
+      v, tower_local = _make_tower_local("sum")
+
+      # Overwrite the initial values.
+      self._assign_tower_local(_devices, v, [7., 8.])
+
+      with self._dist_scope():
+        # Restores the saved value of 3.5 to both variables.
+        saver = saver_lib.Saver(var_list=[tower_local])
+        saver.restore(sess, save_path)
+        self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveTowerLocalRestoreTowerLocalMean(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_tower_local_mean()
+    self._restore_tower_local_mean(save_path)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveTowerLocalRestoreTowerLocalSum(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_tower_local_sum()
+    self._restore_tower_local_sum(save_path)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveTowerLocalMeanRestoreNormal(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_tower_local_mean()
+    self._restore_normal(save_path)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveTowerLocalSumRestoreNormal(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_tower_local_sum()
+    self._restore_normal(save_path)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveNormalRestoreTowerLocalMean(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_normal()
+    self._restore_tower_local_mean(save_path)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testSaveNormalRestoreTowerLocalSum(self):
+    if context.num_gpus() < 1 and context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test in eager mode.")
+
+    save_path = self._save_normal()
+    self._restore_tower_local_sum(save_path)
+
+
+# Standard TensorFlow test entry point.
+if __name__ == "__main__":
+  test.main()