diff options
author | 2018-03-29 15:28:24 -0700 | |
---|---|---|
committer | 2018-03-29 15:32:01 -0700 | |
commit | 6f5d7a97cd2c0741ddfa756853ce5321377b5d53 (patch) | |
tree | e79afd91cd68bc9ed75bfe278511312da3918fe6 /tensorflow/contrib/distribute/python | |
parent | 40f8291db5c0b05b31d7bbe23b847cdbb2408718 (diff) |
Add tf.contrib.distribute, which defines classes DistributionStrategy
and MirroredStrategy, and related functionality.
Also add tf.contrib.optimizer_v2, an update to the Optimizer API.
RELNOTES: Can now pass tf.contrib.distribute.MirroredStrategy() to
tf.estimator.RunConfig() to run an Estimator model on multiple GPUs
on one machine.
PiperOrigin-RevId: 190996247
Diffstat (limited to 'tensorflow/contrib/distribute/python')
26 files changed, 5673 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD new file mode 100644 index 0000000000..4dfd3f7228 --- /dev/null +++ b/tensorflow/contrib/distribute/python/BUILD @@ -0,0 +1,431 @@ +# Implementation of a prototype TF distributed computation library. + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# TODO(priyag): Figure out testonly issues that are preventing us from +# including our tests in pip for now. + +py_library( + name = "values", + srcs = ["values.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":prefetching_ops_v2", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/python:array_ops", + "//tensorflow/python:checkpointable", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "values_test", + srcs = ["values_test.py"], + additional_deps = [ + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python:errors", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_library( + name = "mirrored_strategy", + srcs = ["mirrored_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":shared_variable_creator", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:device", + "//tensorflow/python:framework_ops", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + "@six_archive//:six", + ], +) + +py_library( + name = "one_device_strategy", + srcs = ["one_device_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":values", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_library( + name = "strategy_test_lib", + testonly = 1, + srcs = ["strategy_test_lib.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "combinations", + testonly = 1, + srcs = ["combinations.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":mirrored_strategy", + ":one_device_strategy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "combinations_test", + srcs = ["combinations_test.py"], + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + "//tensorflow/python/eager:test", + ], +) + +py_test( + name = "mirrored_strategy_test", + srcs = ["mirrored_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":mirrored_strategy", + ":strategy_test_lib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_test( + name = "one_device_strategy_test", + srcs = ["one_device_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":one_device_strategy", + ":strategy_test_lib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:test", + ], +) + +cuda_py_test( + name = "mirrored_strategy_multigpu_test", + srcs = ["mirrored_strategy_multigpu_test.py"], + additional_deps = [ + ":mirrored_strategy", + ":values", + ":strategy_test_lib", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "guitar", + "no_pip", + "multi_and_single_gpu", + # Do not perform the extra analysis on this test, because it is already + # performed for the `:mirrored_strategy_test` target. + "no_oss", + "noasan", + "notap", + "notsan", + ], +) + +py_library( + name = "step_fn", + srcs = ["step_fn.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:training", + "//tensorflow/python/eager:backprop", + ], +) + +cuda_py_test( + name = "minimize_loss_test", + srcs = ["minimize_loss_test.py"], + additional_deps = [ + ":combinations", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/ops/losses", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "optimizer_v2_test", + srcs = ["optimizer_v2_test.py"], + additional_deps = [ + ":combinations", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "single_loss_example", + srcs = ["single_loss_example.py"], + deps = [ + ":step_fn", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +cuda_py_test( + name = "step_fn_test", + srcs = ["step_fn_test.py"], + additional_deps = [ + ":single_loss_example", + ":combinations", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "monitor", + srcs = ["monitor.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "monitor_test", + srcs = ["monitor_test.py"], + additional_deps = [ + ":combinations", + ":monitor", + ":one_device_strategy", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "shared_variable_creator", + srcs = ["shared_variable_creator.py"], + visibility = ["//tensorflow:internal"], +) + +py_test( + name = "shared_variable_creator_test", + srcs = ["shared_variable_creator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":shared_variable_creator", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:test", + ], +) + +py_binary( + name = "simple_estimator_example", + srcs = ["simple_estimator_example.py"], + deps = [ + ":mirrored_strategy", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_library( + name = "cross_tower_utils", + srcs = ["cross_tower_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/nccl:nccl_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +py_library( + name = "cross_tower_ops", + srcs = ["cross_tower_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":cross_tower_utils", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_test( + name = "cross_tower_ops_test", + srcs = ["cross_tower_ops_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + ":cross_tower_ops", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "prefetching_ops_v2", + srcs = ["prefetching_ops_v2.py"], + deps = [ + "//tensorflow/contrib/data/python/ops:contrib_op_loader", + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +cuda_py_test( + name = "prefetching_ops_v2_test", + srcs = ["prefetching_ops_v2_test.py"], + additional_deps = [ + ":prefetching_ops_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py new file mode 100644 index 0000000000..dd8e7c4376 --- /dev/null +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -0,0 +1,293 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Facilities for creating multiple test combinations. + +Here is an example of testing various optimizers in Eager and Graph mode: + +class AdditionExample(test.TestCase, parameterized.TestCase): + @combinations.generate( + combinations.combine(mode=["graph", "eager"], + optimizer=[AdamOptimizer(), + GradientDescentOptimizer()])) + def testOptimizer(self, optimizer): + ... f(optimizer)... + +This will run `testOptimizer` 4 times with the specified optimizers: 2 in +Eager and 2 in Graph mode. +The test will be provided with arguments that match the arguments of combine +by name. It is necessary to request all arguments, except for `mode`, which is +optional. + +`combine()` function is available for creating a cross product of various +options. `times()` function exists for creating a product of N `combine()`-ed +results. See below. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import sys +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.optimizer_v2 import adam as adam_v2 +from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.training import adam +from tensorflow.python.training import gradient_descent +from tensorflow.python.util import tf_inspect + + +GPU_TEST = "test_gpu" in sys.argv[0] + + +def generate(combinations): + """A decorator for generating test cases of a test method or a test class. + + Args: + combinations: a list of dictionaries created using combine() and times(). + + Restrictions: + -- there should always be a "mode" argument. Accepted values are "eager" + and "graph". + -- arguments of the test method must match by name to get the corresponding + value of the combination. Tests must accept all arguments (except "mode", + which is optional). + -- distribution argument is special. It is meant for passing instances of + DistributionStrategy. Each instance is to be passed as `(<int>, + <DistributionStrategy>)` tuple, where <int> is the number of required + GPUs. If the required number of GPUs for the DistributionStrategy isn't + available then the test case is going to be skipped. + + Returns: + a decorator that will cause the test method to be run under the specified + conditions. + + Raises: + ValueError - if "mode" argument wasn't either "eager" or "graph. + """ + + def decorator(test_function): + """The decorator to be returned.""" + + # Generate good test names that can be used with --test_filter. + for combination in combinations: + # We use OrderedDicts in `combine()` and `times()` to ensure stable + # order of keys in each dictionary. + assert isinstance(combination, OrderedDict) + name = "".join([ + "_{}_{}".format( + "".join(filter(str.isalnum, key)), + "".join(filter(str.isalnum, str(value)))) + for key, value in combination.items() + ]) + combination.update({"testcase_name": "_test{}".format(name)}) + + @parameterized.named_parameters(*combinations) + def decorated(self, **kwargs): + """A wrapped test method that sets up `test_function`.""" + assert "mode" in kwargs + mode = kwargs["mode"] + + if "distribution" in kwargs: + distribution = kwargs["distribution"] + kwargs["distribution"] = distribution.strategy + if not distribution.required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < distribution.required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(distribution.required_gpus, context.num_gpus())) + + requested_arguments = tf_inspect.getfullargspec(test_function).args + missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( + set(requested_arguments + ["mode"])) + if missing_arguments: + raise ValueError("The test is missing arguments {} .".format( + missing_arguments)) + + kwargs_to_pass = {} + for arg in requested_arguments: + if arg == "self": + kwargs_to_pass[arg] = self + else: + kwargs_to_pass[arg] = kwargs[arg] + + if mode == "eager": + with context.eager_mode(), ops.Graph().as_default(): + test_function(**kwargs_to_pass) + elif mode == "graph": + with context.graph_mode(), ops.Graph().as_default(): + test_function(**kwargs_to_pass) + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph' and not {}".format( + mode)) + + return decorated + return decorator + + +def combine(**kwargs): + """Generate combinations based on its keyword arguments. + + Two sets of returned combinations can be concatenated using +. Their product + can be computed using `times()`. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + if not kwargs: + return [OrderedDict()] + + sort_by_key = lambda k: k[0][0] + kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) + first = list(kwargs.items())[0] + + rest = dict(list(kwargs.items())[1:]) + rest_combined = combine(**rest) + + key = first[0] + values = first[1] + + return [ + OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) + for v in values + for combined in rest_combined + ] + + +def times(*combined): + """Generate a product of N sets of combinations. + + times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4]) + + Args: + *combined: N lists of dictionaries that specify combinations. + + Returns: + a list of dictionaries for each combination. + + Raises: + ValueError: if some of the inputs have overlapping keys. + """ + assert combined + + if len(combined) == 1: + return combined[0] + + first = combined[0] + rest_combined = times(*combined[1:]) + + combined_results = [] + for a in first: + for b in rest_combined: + if set(a.keys()).intersection(set(b.keys())): + raise ValueError("Keys need to not overlap: {} vs {}".format( + a.keys(), b.keys())) + + combined_results.append(OrderedDict(list(a.items()) + list(b.items()))) + return combined_results + + +class NamedObject(object): + """A class that translates an object into a good test name.""" + + def __init__(self, name, obj): + self._name = name + self._obj = obj + + def __getattr__(self, name): + return getattr(self._obj, name) + + def __call__(self, *args, **kwargs): + return self._obj(*args, **kwargs) + + def __repr__(self): + return self._name + + +class NamedDistribution(object): + """Translates DistributionStrategy and its data into a good name.""" + + def __init__(self, name, distribution, required_gpus): + self._distribution = distribution + self._name = name + self._required_gpus = required_gpus + + def __repr__(self): + return self._name + + @property + def strategy(self): + return self._distribution + + @property + def required_gpus(self): + return self._required_gpus + + +one_device_strategy = NamedDistribution( + "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), + None) +mirrored_strategy_with_gpu_and_cpu = NamedDistribution( + "MirroredCPUAndGPU", + mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1) +mirrored_strategy_with_two_gpus = NamedDistribution( + "Mirrored2GPUs", + mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2) + +adam_optimizer_v1_fn = NamedObject( + "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) +gradient_descent_optimizer_v1_fn = NamedObject( + "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) + +adam_optimizer_v2_fn = NamedObject( + "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1)) +gradient_descent_optimizer_v2_fn = NamedObject( + "GradientDescentV2", + lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) + +graph_and_eager_modes = ["graph", "eager"] + + +def distributions_and_v1_optimizers(): + """A common set of combination with DistributionStrategies and Optimizers.""" + return combine( + distribution=[ + one_device_strategy, mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus + ], + optimizer_fn=[adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]) + + +def distributions_and_v2_optimizers(): + """DistributionStrategies and V2 Optimizers.""" + return combine( + distribution=[ + one_device_strategy, mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus + ], + optimizer_fn=[adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]) diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py new file mode 100644 index 0000000000..219b24160f --- /dev/null +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for some testing utils from strategy_test_lib.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.eager import test + + +class TestingCombinationsTest(test.TestCase): + + def test_combine(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 1, + "b": 3 + }, { + "a": 2, + "b": 2 + }, { + "a": 2, + "b": 3 + }], combinations.combine(a=[1, 2], b=[2, 3])) + + def test_add(self): + self.assertEqual( + [{ + "a": 1 + }, { + "a": 2 + }, { + "b": 2 + }, { + "b": 3 + }], + combinations.combine(a=[1, 2]) + + combinations.combine(b=[2, 3])) + + def test_times(self): + c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"]) + c2 = combinations.combine(mode=["eager"], loss=["callable"]) + c3 = combinations.combine(distribution=["d1", "d2"]) + c4 = combinations.times(c3, c1 + c2) + self.assertEqual([ + OrderedDict([("distribution", "d1"), ("loss", "callable"), + ("mode", "graph")]), + OrderedDict([("distribution", "d1"), ("loss", "tensor"), + ("mode", "graph")]), + OrderedDict([("distribution", "d1"), ("loss", "callable"), + ("mode", "eager")]), + OrderedDict([("distribution", "d2"), ("loss", "callable"), + ("mode", "graph")]), + OrderedDict([("distribution", "d2"), ("loss", "tensor"), + ("mode", "graph")]), + OrderedDict([("distribution", "d2"), ("loss", "callable"), + ("mode", "eager")]) + ], c4) + + def test_times_variable_arguments(self): + c1 = combinations.combine(mode=["graph", "eager"]) + c2 = combinations.combine(optimizer=["adam", "gd"]) + c3 = combinations.combine(distribution=["d1", "d2"]) + c4 = combinations.times(c3, c1, c2) + self.assertEqual([ + OrderedDict([("distribution", "d1"), ("mode", "graph"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d1"), ("mode", "graph"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d1"), ("mode", "eager"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d1"), ("mode", "eager"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d2"), ("mode", "graph"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d2"), ("mode", "graph"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d2"), ("mode", "eager"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d2"), ("mode", "eager"), + ("optimizer", "gd")]) + ], c4) + self.assertEqual( + combinations.combine( + mode=["graph", "eager"], + optimizer=["adam", "gd"], + distribution=["d1", "d2"]), c4) + + def test_overlapping_keys(self): + c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"]) + c2 = combinations.combine(mode=["eager"], loss=["callable"]) + with self.assertRaisesRegexp(ValueError, ".*Keys.+overlap.+"): + _ = combinations.times(c1, c2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py new file mode 100644 index 0000000000..cb98351735 --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -0,0 +1,410 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes for different algortihms of reduction and broadcasting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import device_util + + +def _validate_destinations(destinations): + if not isinstance(destinations, + (value_lib.DistributedValues, six.string_types, list)): + raise ValueError("destinations must be one of a `DistributedValues` object," + " a device string, a list of device strings or None") + + if not destinations: + raise ValueError("destinations can not be empty") + + +def _validate_value_destination_pairs(value_destination_pairs): + # pylint: disable=g-missing-docstring + if not value_destination_pairs: return False + if not isinstance(value_destination_pairs, (list, tuple)): return False + if not all([isinstance(pair, tuple) for pair in value_destination_pairs]): + return False + if not all([isinstance(v[0], value_lib.PerDevice) + for v in value_destination_pairs]): + return False + return True + + +def _get_devices_from(destinations): + if isinstance(destinations, value_lib.DistributedValues): + return list(destinations.devices) + elif isinstance(destinations, six.string_types): + return [device_util.canonicalize(destinations)] + else: + return [ + device_util.canonicalize(destination) for destination in destinations + ] + + +def _devices_match(left, right): + return set(_get_devices_from(left)) == set(_get_devices_from(right)) + + +def _all_devices_match(value_destination_pairs): + if not all([d is None or _devices_match(v, d) + for v, d in value_destination_pairs]): + return False + if not all([_devices_match(v, value_destination_pairs[0][0]) + for v, _ in value_destination_pairs[1:]]): + return False + return True + + +def _simple_broadcast(tensor, destinations): + index = {} + devices = _get_devices_from(destinations) + for d in devices: + with ops.device(d): + index[d] = array_ops.identity(tensor) + return value_lib.Mirrored(index) + + +def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, + method_string): + # pylint: disable=g-missing-docstring + all_values = [] + count = 0 + for v in per_device_value._index.values(): # pylint: disable=protected-access + if isinstance(v, value_lib.MapOutput): + v_list = v.get() + if not v_list: + continue + count += len(v_list) + # Sum within each device before aggregating across devices. + v = math_ops.add_n(v_list) + else: + count += 1 + all_values.append(v) + if not all_values: + raise ValueError("`per_device_value` must be non-empty") + + with ops.device(reduce_to_device): + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + if method_string == "sum": + reduced = accumulation_fn(all_values) + elif method_string == "mean": + reduced = accumulation_fn(all_values) / count + else: + raise ValueError("`method_string` must be 'sum' or 'mean'") + return reduced + + +class CrossTowerOps(object): + """Base class for cross-tower reduction and broadcasting algorithms.""" + + def __init__(self): + pass + + def reduce(self, method_string, per_device_value, destinations=None): + """Reduce `per_device_value` to `destinations`. + + It runs the reduction operation defined by `method_string` and put the + result on `destinations`. + + Args: + method_string: either 'sum' or 'mean' specifying the reduction method. + per_device_value: a PerDevice object. + destinations: the reduction destinations. + + Returns: + a Mirrored object. + + Raises: + ValueError: if per_device_value is not a PerDevice object. + """ + if not isinstance(per_device_value, value_lib.PerDevice): + raise ValueError("`per_device_value` must be a `PerDevice` object.") + if destinations is not None: + _validate_destinations(destinations) + return self._reduce(method_string, per_device_value, destinations) + + def batch_reduce(self, method_string, value_destination_pairs): + """Reduce PerDevice objects in a batch. + + Reduce each first element in `value_destination_pairs` to each second + element which indicates the destinations. + + Args: + method_string: either 'sum' or 'mean' specifying the reduction method. + value_destination_pairs: a list or a tuple of tuples of PerDevice objects + and destinations. If a destionation is None, then the destinations + are set to match the devices of the input PerDevice object. + + Returns: + a list of Mirrored objects. + + Raises: + ValueError: if `value_destination_pairs` is not a list or a tuple of + tuples of PerDevice objects and destinations + """ + if not _validate_value_destination_pairs(value_destination_pairs): + raise ValueError("`value_destination_pairs` must be a list or a tuple of " + "tuples of PerDevice objects and destinations") + for _, d in value_destination_pairs: + if d is not None: + _validate_destinations(d) + + return self._batch_reduce(method_string, value_destination_pairs) + + def broadcast(self, tensor, destinations): + """Broadcast the `tensor` to destinations. + + Args: + tensor: the tensor to broadcast. + destinations: the broadcast destinations. + + Returns: + a Mirrored object. + """ + _validate_destinations(destinations) + return self._broadcast(tensor, destinations) + + def _reduce(self, method_string, per_device_value, destinations): + raise NotImplementedError( + "_reduce method must be implemented in descendants.") + + def _batch_reduce(self, method_string, value_destination_pairs): + raise NotImplementedError( + "_batch_reduce method must be implemented in descendants.") + + def _broadcast(self, tensor, destinations): + return _simple_broadcast(tensor, destinations) + + +class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): + """Always do reduction to one device first and then do broadcasting. + + Batch reduction is done by reduction on each element one by one. + """ + + def __init__(self, reduce_to_device=None, accumulation_fn=math_ops.add_n): + """Constructor. + + Args: + reduce_to_device: the intermediate device to reduce to. If None, reduce + to the first device in `destinations` of the reduce() method. + accumulation_fn: a function that does accumulation. + """ + self.reduce_to_device = reduce_to_device + self.accumulation_fn = accumulation_fn + super(ReductionToOneDeviceCrossTowerOps, self).__init__() + + def _reduce(self, method_string, per_device_value, destinations): + devices = _get_devices_from(destinations or per_device_value) + reduce_to_device = self.reduce_to_device or devices[0] + reduced = _simple_reduce(per_device_value, reduce_to_device, + self.accumulation_fn, method_string) + return self.broadcast(reduced, devices) + + def _batch_reduce(self, method_string, value_destination_pairs): + return [self._reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs] + + +def _group_value_by_device(per_device_values): + """Group values into sublists by their devices. + + This grouping is needed to call the allreduce library. + + Args: + per_device_values: a list of PerDevice obejcts. + + Returns: + a list of lists, each sublist has components for its corresponding device of + PerDevice objects, paired with a None. + """ + destinations = per_device_values[0].devices + grouped = [[] for _ in range(len(destinations))] + for per_device_value in per_device_values: + # pylint: disable=protected-access + for i, v in enumerate(per_device_value._index.values()): + assert per_device_value.devices == destinations + grouped[i].append((v, None)) + return grouped + + +def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): + """Ungroup results from allreduce and make Mirrored objects. + + Each allreduce result would be divided by the number of destinations before + Mirrored objects are created if method_string is "mean". + """ + index = [{} for _ in range(len(grouped_reduced[0]))] + for d, per_device_reduced in enumerate(grouped_reduced): + for i, (v, _) in enumerate(per_device_reduced): + if method_string == "mean": + index[i][destinations[d]] = v / len(destinations) + else: + index[i][destinations[d]] = v + return [value_lib.Mirrored(v) for v in index] + + +class AllReduceCrossTowerOps(CrossTowerOps): + """Reduction using all reduce.""" + + def __init__(self, all_reduce_alg="nccl", gradient_repacking=1): + """Initialize this subclass of CrossTowerOps with allreduce. + + Gradients would be repacked for more efficient cross-device transportation. + + Args: + all_reduce_alg: the allreduce algorithm to use, currently only "nccl" or + "hierarchical_copy" are supported. + gradient_repacking: If zero, no gradient repacking would be done. If + non-zero value it specifies the number of split packs that will be + formed. + """ + self.all_reduce_alg = all_reduce_alg + self.gradient_repacking = gradient_repacking + super(AllReduceCrossTowerOps, self).__init__() + + def _reduce(self, method_string, per_device_value, destinations): + if ((destinations is None or _devices_match(per_device_value, destinations)) + and not context.executing_eagerly()): + return self._batch_all_reduce(method_string, [per_device_value])[0] + else: + devices = _get_devices_from(destinations or per_device_value) + reduce_to_device = devices[0] + reduced = _simple_reduce(per_device_value, reduce_to_device, + math_ops.add_n, method_string) + return self.broadcast(reduced, devices) + + def _batch_reduce(self, method_string, value_destination_pairs): + if (_all_devices_match(value_destination_pairs) and + not context.executing_eagerly()): + return self._batch_all_reduce(method_string, + [v[0] for v in value_destination_pairs]) + else: + if not context.executing_eagerly(): + logging.warning("Efficient batch_reduce is not supported if " + "destinations are different.") + return [ + self._reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, method_string, per_device_values): + """All reduce algorithm in a batch.""" + logging.info("batch_all_reduce invoked for batches size = %d with algorithm" + " = %s and gradient repacking = %d", len(per_device_values), + self.all_reduce_alg, self.gradient_repacking) + destinations = per_device_values[0].devices + grouped = _group_value_by_device(per_device_values) + if self.gradient_repacking == 0: + if self.all_reduce_alg == "nccl": + reduced = cross_tower_utils.aggregate_gradients_using_nccl(grouped) + else: + # TODO(yuefengz): check that gpu ids in `destinations` are in ascending + # order. + reduced = ( + cross_tower_utils.aggregate_gradients_using_hierarchical_copy( + destinations, grouped)) + else: + device_grad_packs = [] + all_tower_shapes = [] + all_tower_sizes = [] + for tower_grads_and_vars in grouped: + with ops.colocate_with(tower_grads_and_vars[0][0]): + # Flatten all the grads. + flat_grads = [ + array_ops.reshape(g, [-1]) for g, _ in tower_grads_and_vars + ] + # Remember the original shape of all the grads. + tower_shapes = [array_ops.shape(g) for g, _ in tower_grads_and_vars] + # Remember the original sizes of all the grads. + tower_sizes = [array_ops.size(g) for g, _ in tower_grads_and_vars] + # Concat all the flat grads into a big flat tensor. + concat_grads = array_ops.concat(flat_grads, 0) + + # Split the big tensor into num_splits packs. In cases where the + # total size is not divisible num_splits, the last pack gets + # more elements. + # TODO(zhengxq): it is possible to optimize away the additional + # data movement by copying along the original variable boundary. + # TODO(zhengxq): it is also possible to optimize away all the concat + # as well. + num_splits = self.gradient_repacking + total_grad_size = array_ops.size(concat_grads) + split_size = total_grad_size // num_splits + split_size_last = total_grad_size - split_size * (num_splits - 1) + split_sizes = [split_size] * (num_splits - 1) + [split_size_last] + grad_packs = array_ops.split(concat_grads, split_sizes) + + # Ready to aggregate the repacked gradients, with fake variables. + # TODO(zhengxq): It is hacky to have to use fake variables. + # We should remove the need for variables in + # aggregate_gradients_using*. + device_grad_packs.append(zip(grad_packs, [None] * num_splits)) + all_tower_shapes.append(tower_shapes) + all_tower_sizes.append(tower_sizes) + + # The actual aggregation of the repacked gradients. Note that they are + # sharded among different aggregation trees. So it is important to + # strike the balance on num_splits. + if self.all_reduce_alg == "nccl": + summed_device_grad_packs = ( + cross_tower_utils.aggregate_gradients_using_nccl(device_grad_packs)) + else: + summed_device_grad_packs = ( + cross_tower_utils.aggregate_gradients_using_hierarchical_copy( + destinations, device_grad_packs)) + + aggregated_device_grads = [] + for (summed_tower_grad_packs, tower_grads_and_vars, tower_shapes, + tower_sizes) in zip(summed_device_grad_packs, grouped, + all_tower_shapes, all_tower_sizes): + # pylint: enable=line-too-long + # Reverse the packing operations in the previous steps. Form the + # summed gradients back into their original shapes. + with ops.colocate_with(summed_tower_grad_packs[0][0]): + # Form a list of the summed grad packs. + device_grad_packs = [g for g, _ in summed_tower_grad_packs] + + # Concat them back into a big flat tensor. + device_grads_concat = array_ops.concat(device_grad_packs, 0) + + # Split the tensors back into their original sizes. + grads_with_sizes = array_ops.split(device_grads_concat, tower_sizes) + + # Reshape the tensors back into their original shapes. + grads_with_shapes = [ + array_ops.reshape(grad, shape) + for shape, grad in zip(tower_shapes, grads_with_sizes) + ] + + # Form the list with the original list of variables. + summed_tower_grads = [ + (g, v) + for g, (_, v) in zip(grads_with_shapes, tower_grads_and_vars) + ] + aggregated_device_grads.append(summed_tower_grads) + reduced = aggregated_device_grads + return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, + method_string) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py new file mode 100644 index 0000000000..bb43147f5e --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -0,0 +1,185 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CrossTowerOps.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _make_per_device(values, devices): + devices = cross_tower_ops_lib._get_devices_from(devices) + assert len(values) == len(devices) + index = {} + for d, v in zip(devices, values): + with ops.device(d): + placed_v = array_ops.identity(v) + index[d] = placed_v + return value_lib.PerDevice(index) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def _fake_mirrored(value, devices): + """Create a faked Mirrored object for testing. + + All components of the returned Mirrored have the same objects, which is not + true in reality. + """ + devices = cross_tower_ops_lib._get_devices_from(devices) + return value_lib.Mirrored( + {d: v for d, v in zip(devices, [value] * len(devices))}) + + +_cpu_device = "/device:CPU:0" + + +class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): + + def _assert_value_equal(self, left, right): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_value_equal(l, r) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(left.devices, right.devices) + if context.executing_eagerly(): + self.assertEqual([v.numpy() for v in left._index.values()], + list(right._index.values())) + else: + with self.test_session() as sess: + self.assertEqual( + sess.run(list(left._index.values())), list(right._index.values())) + + # TODO(yuefengz): decouple the num_gpus check from distribution in + # combinations module so that we can pass in devices instead of a distribution + # strategy. + reduction_to_one_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "DefaultReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "ReductionToCPUDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_cpu_device)), + combinations.NamedObject( + "AccumulateNCrossTowerOp", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + accumulation_fn=math_ops.accumulate_n)), + ], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject("AllReduce", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "nccl", 1)), + combinations.NamedObject("HierarchicalCopy", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 8)), + combinations.NamedObject("AllReduceNoGradientRepacking", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "nccl", 0)), + combinations.NamedObject("HierarchicalCopyNoGradientRepacking", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 0)) + ], + distribution=[ + combinations.mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + + @combinations.generate(reduction_to_one_combinations + allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + devices = distribution.worker_devices + + values = [constant_op.constant(float(d)) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = (len(devices) - 1.) / 2. + + values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = mean + 1. + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_value_equal( + cross_tower_ops.reduce("mean", per_device, destinations=destinations), + _fake_mirrored(mean, destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce( + "mean", per_device_2, destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce("sum", per_device, destinations=destinations), + _fake_mirrored(mean * len(devices), destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce( + "sum", per_device_2, destinations=destinations), + _fake_mirrored(mean_2 * len(devices), destinations or per_device)) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_value_equal( + cross_tower_ops.batch_reduce( + "mean", [(per_device, d1), (per_device_2, d2)]), + [_fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2)]) + self._assert_value_equal( + cross_tower_ops.batch_reduce( + "sum", [(per_device, d1), (per_device_2, d2)]), + [_fake_mirrored(mean * len(devices), d1 or per_device), + _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)]) + + # test broadcast() + for destinations in all_destinations: + if destinations is None: + continue + else: + self._assert_value_equal( + cross_tower_ops.broadcast(constant_op.constant(1.), destinations), + _fake_mirrored(1., destinations)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py new file mode 100644 index 0000000000..93acd835d7 --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -0,0 +1,153 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for cross_tower_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import nccl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def aggregate_gradients_using_nccl(tower_grads): + """Aggregate gradients using nccl allreduce.""" + agg_all_g_and_v = [] + for single_g_and_v in zip(*tower_grads): + single_grads = [g for g, _ in single_g_and_v] + agg_grads = nccl.all_sum(single_grads) + agg_all_g_and_v.append( + [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)]) + + agg_all_g_and_v = list(zip(*agg_all_g_and_v)) + + return agg_all_g_and_v + + +def aggregate_gradients_using_hierarchical_copy(avail_devices, tower_grads): + """Aggregate gradients using hierarchical copies. + + Args: + avail_devices: available GPU devices. + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over towers. The inner list is over individual gradients. + + Returns: + The list of (aggregated_gradient, variable), where the gradient has been + summed across all towers and the variable is chosen from the first tower. + """ + # This only works for DGX-1 type of machine topology + # Device peer to peer matrix + # DMA: 0 1 2 3 4 5 6 7 + # 0: Y Y Y Y Y N N N + # 1: Y Y Y Y N Y N N + # 2: Y Y Y Y N N Y N + # 3: Y Y Y Y N N N Y + # 4: Y N N N Y Y Y Y + # 5: N Y N N Y Y Y Y + # 6: N N Y N Y Y Y Y + # 7: N N N Y Y Y Y Y + agg_grads = [] + num_devices = len(avail_devices) + # In the special case of DGX-1 machine topology, the two groups have equal + # size. + group_size = num_devices // 2 + for i, single_grads in enumerate(zip(*tower_grads)): + group_0_main_device = i % num_devices + group_1_main_device = (group_0_main_device + group_size) % num_devices + if group_0_main_device < group_size: + group_0_begin = 0 + group_1_begin = group_size + else: + group_0_begin = group_size + group_1_begin = 0 + + # Aggregate the first group. + group_0_device_grads = single_grads[group_0_begin: + group_0_begin + group_size] + with ops.device(avail_devices[group_0_main_device]): + group_0_agg_grads, _ = aggregate_single_gradient_using_copy( + group_0_device_grads, False, False) + + # Aggregate the second group. + group_1_device_grads = single_grads[group_1_begin: + group_1_begin + group_size] + with ops.device(avail_devices[group_1_main_device]): + group_1_agg_grads, _ = aggregate_single_gradient_using_copy( + group_1_device_grads, False, False) + + # Aggregate between the groups. + with ops.device(avail_devices[group_0_main_device]): + (agg_total_grads, _), _ = aggregate_single_gradient_using_copy( + [group_0_agg_grads, group_1_agg_grads], False, False) + + # Broadcast the result back into the root of each group. + with ops.device(avail_devices[group_0_main_device]): + group_0_agg_grads_bcast = array_ops.identity(agg_total_grads) + with ops.device(avail_devices[group_1_main_device]): + group_1_agg_grads_bcast = array_ops.identity(agg_total_grads) + + agg_grads_bcast = [] + for j in range(len(single_grads)): + with ops.device(avail_devices[j]): + # Broadcast the result back to each member in the group from the root. + if (group_0_main_device < group_size) == (j < group_size): + src_device_grad = group_0_agg_grads_bcast + else: + src_device_grad = group_1_agg_grads_bcast + agg_grads_bcast.append(array_ops.identity(src_device_grad)) + + agg_grads.append( + [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)]) + + agg_grads = list(zip(*agg_grads)) + + return agg_grads + + +def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, + check_inf_nan): + """Calculate the average gradient for a shared variable across all towers. + + Note that this function provides a synchronization point across all towers. + + Args: + grad_and_vars: A list or tuple of (gradient, variable) tuples. Each + (gradient, variable) pair within the outer list represents the gradient + of the variable calculated for a single tower, and the number of pairs + equals the number of towers. + use_mean: if True, mean is taken, else sum of gradients is taken. + check_inf_nan: check grads for nans and infs. + + Returns: + The tuple ([(average_gradient, variable),], has_nan_or_inf) where the + gradient has been averaged across all towers. The variable is chosen from + the first tower. The has_nan_or_inf indicates the grads has nan or inf. + """ + grads = [g for g, _ in grad_and_vars] + grad = math_ops.add_n(grads) + + if use_mean and len(grads) > 1: + grad = array_ops.multiply(grad, 1.0 / len(grads)) + + v = grad_and_vars[0][1] + if check_inf_nan: + has_nan_or_inf = array_ops.logical_not( + array_ops.reduce_all(array_ops.is_finite(grads))) + return (grad, v), has_nan_or_inf + else: + return (grad, v), None diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py new file mode 100644 index 0000000000..0fa90df79b --- /dev/null +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -0,0 +1,279 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for running legacy optimizer code with DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example +from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.ops.losses import losses_impl + + +class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testTrainNetwork(self, distribution, optimizer_fn, + use_callable_loss=True): + with distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, + use_bias=True, + use_callable_loss=use_callable_loss) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers() + + combinations.distributions_and_v2_optimizers(), + combinations.combine(mode=["graph", "eager"]))) + def testOptimizerInsideModelFn(self, distribution, optimizer_fn): + created_variables = [] + trainable_variables = [] + + def appending_creator(next_creator, *args, **kwargs): + v = next_creator(*args, **kwargs) + created_variables.append(v.name) + if "trainable" in kwargs and kwargs["trainable"]: + trainable_variables.append(v.name) + return v + + # Creator scope needs to be set before it's used inside + # `distribution.scope`. + with variable_scope.variable_creator_scope( + appending_creator), distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, + use_bias=True, + use_callable_loss=True, + create_optimizer_inside_model_fn=True) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + run_step() + + def get_expected_variables(optimizer_fn, num_parameter_devices): + variables_map = { + "GradientDescent": ["dense/kernel", "dense/bias"], + "Adam": [ + "dense/kernel", "dense/bias", "beta1_power", "beta2_power", + "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", + "dense/bias/Adam_1" + ] + } + variables = variables_map[optimizer_fn().get_name()] + variables.extend([ + v + "/replica_{}".format(replica) + for v in variables + for replica in range(1, num_parameter_devices) + ]) + return set([v + ":0" for v in variables]) + + self.assertEqual( + get_expected_variables(optimizer_fn, + len(distribution.parameter_devices)), + set(created_variables)) + + @combinations.generate( + combinations.times(combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=["graph", "eager"], + momentum=[0.8, 0.9, 0.99], + renorm=[False, True]))) + def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, + renorm): + """Verifies that moving mean updates are reduced across towers.""" + with distribution.scope(): + num_towers = len(distribution.worker_devices) + model_fn, dataset, batchnorm = batchnorm_example( + optimizer_fn, + batch_per_epoch=num_towers, + momentum=momentum, + renorm=renorm) + + # Disable prefetching since that makes the specific input on each device + # to be non deterministic, and this test relies on specific input being + # on each device. + if isinstance(distribution, mirrored_strategy.MirroredStrategy): + distribution._prefetch_on_device = False + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return control_flow_ops.group( + distribution.unwrap( + distribution.call_for_each_tower( + model_fn, + iterator.get_next(), + run_concurrently=batchnorm.built)) + + ops.get_collection(ops.GraphKeys.UPDATE_OPS)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + expected_moving_means = [0.] * 8 + + def averaged_batch_mean(i): + # Each batch has shape [16, 8] where the ith element in jth list is + # (8 * j + i + tower_id * 100). So the batch mean in each tower is + # (60 + i + tower_id * 100). So here comes its batch mean over all + # towers: + return 60. + i + (num_towers - 1.) / 2. * 100. + + for _ in range(10): + run_step() + moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean)) + + # We make sure that the moving_mean is updated as if the sample mean is + # calculated over all towers. + for i, expected_moving_mean in enumerate(expected_moving_means): + expected_moving_means[i] -= (( + expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) + self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) + + @combinations.generate( + combinations.times( + combinations.combine( + distribution=[combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn], + loss_reduction=[losses_impl.Reduction.SUM, + losses_impl.Reduction.MEAN, + losses_impl.Reduction.SUM_OVER_BATCH_SIZE, + losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, + use_callable_loss): + with distribution.scope(): + all_vars = [] + + def model_fn(x, y): + + def loss_fn(): + # Use fixed initialization to make the steps deterministic. + w = variable_scope.get_variable("w", initializer=[[2.]]) + all_vars.append(w) + predict = math_ops.matmul(x, w) + return losses_impl.mean_squared_error( + y, predict, reduction=loss_reduction) + + optimizer = optimizer_fn() # GradientDescent with 0.2 learning rate + + if use_callable_loss: + return optimizer.minimize(loss_fn) + else: + return optimizer.minimize(loss_fn()) + + features = dataset_ops.Dataset.from_tensors([[2.], [7.]]) + labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) + dataset = dataset_ops.Dataset.zip((features, labels)).repeat() + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, *iterator.get_next(), run_concurrently=False)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + run_step() + + self.assertEqual(distribution.num_towers, len(all_vars)) + v = all_vars[0] + self.assertTrue(all([v is vi for vi in all_vars[1:]])) + weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) + # Our model is: + # predict = x * w + # loss = (predict - y)^2 + # dloss/dpredict = 2*(predict - y) + # dloss/dw = 2 * x^T @ (predict - y) + # For our batch size of 2, assuming sum loss reduction: + # x = [2, 7] + # y = [6, 21] + # w_initial = 2 + # predict = [4, 14] + # predict - y = [-2, -7] + # dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106 + # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2 + # with sum loss reduction, or 10.6 with mean. + if loss_reduction == losses_impl.Reduction.SUM: + # Note that the "distribution.num_towers" factor will go away once + # we split the input across towers, instead of pulling a complete + # batch of input per tower. + self.assertNear(weight, 2 + 21.2 * distribution.num_towers, 0.0001) + else: + # One of the mean loss reductions. + self.assertNear(weight, 2 + 10.6, 0.0001) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py new file mode 100644 index 0000000000..8cf83c52d8 --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -0,0 +1,486 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class MirroredStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +import six + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import shared_variable_creator +from tensorflow.contrib.distribute.python import values +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.eager import context +from tensorflow.python.eager import tape +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import coordinator +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib + + +# TODO(josh11b): Replace asserts in this file with if ...: raise ... + + +def _cpu_device(device): + cpu_device = tf_device.DeviceSpec.from_string(device) + cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) + return cpu_device.to_string() + + +class _RequestedStop(Exception): + pass + + +class MirroredStrategy(distribute_lib.DistributionStrategy): + """Mirrors vars to distribute across multiple devices on a single machine. + + This strategy uses one tower per device and sync replication. + """ + + def __init__(self, + devices=None, + num_gpus=None, + cross_tower_ops=None, + prefetch_on_device=None): + super(MirroredStrategy, self).__init__() + # Convert `num_gpus` into `devices`, shouldn't specify both. + if devices is None: + if num_gpus is None: + num_gpus = context.num_gpus() + devices = ["/device:GPU:%d" % d for d in range(num_gpus)] + elif num_gpus is not None: + raise ValueError("Must only specify one of `devices` and `num_gpus`.") + + assert devices, "Must specify at least one device." + assert len(set(devices)) == len(devices), ( + "No duplicates allowed in `devices` argument.") + # TODO(josh11b): Require at least 2 devices? + self._devices = devices + self._canonical_device_set = set( + [device_util.canonicalize(d) for d in devices]) + self._device_index = values.PerDevice( + dict((d, i) for i, d in enumerate(devices))) + self.cross_tower_ops = ( + cross_tower_ops or + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) + self._prefetch_on_device = prefetch_on_device + + def _create_variable(self, next_creator, *args, **kwargs): + """Create a mirrored variable. See `DistributionStrategy.scope`.""" + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + + tower_local = kwargs.pop("tower_local_reduce_method", None) + if tower_local is not None: + kwargs["trainable"] = False + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = {} + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + kwargs["name"] = "%s/replica_%d" % (var0name, i) + # Initialize replicas with the same value: + if context.executing_eagerly(): + initial_value = index[devices[0]].value() + else: + initial_value = index[devices[0]].initial_value + kwargs["initial_value"] = array_ops.identity(initial_value) + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + assert not isinstance(v, values.DistributedVariable) + index[d] = v + + if tower_local is None: + result = values.MirroredVariable(index, index[devices[0]]) + else: + result = values.TowerLocalVariable( + index, index[devices[0]], tower_local) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + def distribute_dataset(self, dataset): + per_device_dataset = values.PerDeviceDataset( + dataset, self._devices, self._prefetch_on_device) + return per_device_dataset.make_one_shot_iterator() + + def _broadcast(self, tensor, destinations): + # TODO(josh11b): In eager mode, use one thread per device, or async mode. + return self.cross_tower_ops.broadcast(tensor, destinations or self._devices) + + def _call_for_each_tower(self, fn, *args, **kwargs): + """Run `fn` in separate threads, once per tower/worker device. + + Args: + fn: function to run (will be run once per device, each in its own thread). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. + + Returns: + Merged return value of `fn` across all towers. + + Raises: + RuntimeError: If fn() calls get_tower_context().merge_call() a different + number of times for when called for different devices. + """ + run_concurrently = kwargs.pop("run_concurrently", True) + if not context.executing_eagerly(): + # Lots of TF library code isn't thread-safe in graph mode, and + # there is little to be gained by turning on multithreading when + # constructing a graph. + run_concurrently = False + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + elif run_concurrently is None: + run_concurrently = True + + coord = coordinator.Coordinator( + clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + + # TODO(isaprykin): Create these threads once instead of during every run() + # call. + threads = [] + for index, d in enumerate(self._devices): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = MirroredStrategy._MirroredTowerThread( + self, coord, d, variable_creator_fn, fn, + *values.select_device(d, args), **values.select_device(d, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredTowerThread + # (`MTT`) threads. The execution waits until + # `MTT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_tower_context().merge_call()` is called. If `fn` is + # complete, then `MTT.done` is set to True. Otherwise, arguments + # of `get_tower_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_tower_context().merge_call` are then set to `MTT.merge_result`. + # Each such `get_tower_context().merge_call` call returns the + # `MTT.merge_result` for that thread when `MTT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some towers made a different number of " + "tower_context().merge_call() calls.") + # get_tower_context().merge_call() case + merge_args = values.regroup( + {t.device: t.merge_args for t in threads}) + merge_kwargs = values.regroup( + {t.device: t.merge_kwargs for t in threads}) + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) + for t in threads: + t.merge_result = values.select_device(t.device, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup({t.device: t.main_result for t in threads}) + + def map(self, map_over, fn, *args, **kwargs): + # TODO(josh11b): In eager mode, use one thread per device. + index = {} + i = 0 + for m in map_over: + d = self._devices[i % len(self._devices)] + with ops.device(d): + l = index.get(d, []) + l.append(fn(m, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs))) + index[d] = l + # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput + # in addition to PerDevice data. + return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) + + def _reduce(self, method_string, value, destinations): + if len(self._devices) == 1 and not isinstance(value, values.PerDevice): + value = values.PerDevice({self._devices[0]: value}) + assert isinstance(value, values.PerDevice) + return self.cross_tower_ops.reduce( + method_string, value, destinations=destinations) + + def _batch_reduce(self, method_string, value_destination_pairs): + return self.cross_tower_ops.batch_reduce(method_string, + value_destination_pairs) + + def _update(self, var, fn, *args, **kwargs): + # TODO(josh11b): Also support TowerLocalVariables here? If so, args and + # kwargs don't need to be mirrored. + assert isinstance(var, values.MirroredVariable) + # TODO(josh11b): In eager mode, use one thread per device. + updates = {} + for d, v in var._index.items(): # pylint: disable=protected-access + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + updates[d] = fn(v, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.regroup(updates, values.Mirrored) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + assert isinstance(colocate_with, list) + # TODO(josh11b): In eager mode, use one thread per device. + updates = {} + for d in colocate_with: + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + updates[d] = fn(*values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.regroup(updates, values.Mirrored) + + def _fetch(self, val, destination, fn): + """Return a copy of `val` or `fn(val)` on `destination`.""" + assert isinstance(destination, six.string_types) + if isinstance(val, values.TowerLocalVariable): + val = self.reduce(val.reduce_method, val, destinations=destination) + with ops.device(destination): + return fn(self.unwrap(val)[0]) + + assert isinstance(val, values.Mirrored), ( + "val = %s (type %s)" % (val, val.__class__.__name__)) + if val.on_device(destination): + with ops.device(destination): + # Use an identity here to make sure we are returning a tensor + # instead of e.g. a variable object. + return array_ops.identity(fn(val.get(destination))) + device = None + for d in self._devices: + if val.on_device(d): + device = d + break + assert device is not None, ( + "Could not find destination %s in list of devices %s." % + (destination, val.devices)) + with ops.device(device): + v = fn(val.get(device)) + with ops.device(destination): + return array_ops.identity(v) + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + if set(val.devices) == self._canonical_device_set: + return [val.get(device=d) for d in self._devices] + return [val.get(device=d) for d in sorted(val.devices)] + return [val] + + @property + def is_single_tower(self): + return len(self._devices) == 1 + + @property + def num_towers(self): + return len(self._devices) + + def _worker_device_index(self): + return self._device_index + + @property + def worker_devices(self): + # Make a copy to prevent users from accidentally mutating our copy. + return list(self._devices) + + @property + def parameter_devices(self): + return list(self._devices) + + def non_slot_devices(self, var_list): + del var_list + return list(self._devices) + + def _get_devices_from(self, colocate_with=None): + if colocate_with is None: + return self._devices + elif isinstance(colocate_with, values.DistributedValues): + # pylint: disable=protected-access + return list(colocate_with._index.keys()) + elif isinstance(colocate_with, six.string_types): + return [colocate_with] + else: + return colocate_with + + class _MirroredTowerThread(threading.Thread): + """A thread that runs() a function on a device.""" + + def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, + **kwargs): + super(MirroredStrategy._MirroredTowerThread, self).__init__() # pylint: disable=protected-access + self.coord = coord + self.distribution = dist + self.device = device + self.tower_id = dist.worker_devices.index(device) + self.variable_creator_fn = variable_creator_fn + # State needed to run and return the results of `fn`. + self.main_fn = fn + self.main_args = args + self.main_kwargs = kwargs + self.main_result = None + self.done = False + # State needed to run the next merge_call() (if any) requested via + # TowerContext. + self.merge_fn = None + self.merge_args = None + self.merge_kwargs = None + self.merge_result = None + # We use a thread.Event for the main thread to signal when this + # thread should start running (`should_run`), and another for + # this thread to transfer control back to the main thread + # (`has_paused`, either when it gets to a + # `get_tower_context().merge_call` or when `fn` returns). In + # either case the event starts cleared, is signaled by calling + # set(). The receiving thread waits for the signal by calling + # wait() and then immediately clearing the event using clear(). + self.should_run = threading.Event() + self.has_paused = threading.Event() + # These fields have to do with inheriting various contexts from the + # parent thread: + # pylint: disable=protected-access + self.context_mode = context.context()._eager_context.mode + if not context.context()._context_handle: + context.context()._initialize_handle_and_devices() + self.context_device_policy = ( + pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( + context.context()._context_handle)) + self.graph = ops.get_default_graph() + self._variable_creator_stack = self.graph._variable_creator_stack[:] + self._captured_var_scope = variable_scope.get_variable_scope() + # Adding a "/" at end lets us re-enter this scope later. + self._captured_name_scope = self.graph.get_name_scope() + if self._captured_name_scope: + self._captured_name_scope += "/" + if self.tower_id > 0: + if not self._captured_name_scope: + self._captured_name_scope = "" + self._captured_name_scope += "tower_%d/" % self.tower_id + + def run(self): + # pylint: disable=protected-access + self.graph._variable_creator_stack = self._variable_creator_stack + self.should_run.wait() + self.should_run.clear() + try: + if self.coord.should_stop(): + return + with self.coord.stop_on_exception(), \ + context.context()._mode(self.context_mode), \ + context.context().device_policy(self.context_device_policy), \ + self.graph.as_default(), \ + MirroredTowerContext(self.distribution, self.tower_id), \ + ops.device(self.device), \ + ops.name_scope(self._captured_name_scope), \ + variable_scope.variable_scope( + self._captured_var_scope, reuse=self.tower_id > 0), \ + variable_scope.variable_creator_scope(self.variable_creator_fn): + self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) + self.done = True + finally: + self.has_paused.set() + + +class MirroredTowerContext(distribute_lib.TowerContext): + """TowerContext used in MirroredStrategy.call_for_each_tower(). + + Opened in `_MirroredTowerThread`, to allow the user to invoke + `MirroredStrategy`'s specific implementation of `merge_call()`, + which works by delegating the function and its arguments to + the main thread (the one that invoked + `MirroredStrategy.call_for_each_tower()`). + """ + + def _merge_call(self, fn, *args, **kwargs): + """Delegate to the main thread to actually perform merge_call().""" + t = threading.current_thread() # a _MirroredTowerThread + t.merge_fn = fn + t.merge_args = args + t.merge_kwargs = kwargs + t.has_paused.set() + t.should_run.wait() + t.should_run.clear() + if t.coord.should_stop(): + raise _RequestedStop() + return t.merge_result + + @property + def device(self): + distribute_lib.require_tower_context(self) + return self._distribution_strategy.worker_devices[self._tower_id] diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py new file mode 100644 index 0000000000..9e9f06da8e --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -0,0 +1,435 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multi-GPU tests for MirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import distribute as distribute_lib + +GPU_TEST = "test_gpu" in sys.argv[0] + + +class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + devices = ["/device:CPU:0", "/device:GPU:0"] + if GPU_TEST: + self.assertGreater(context.num_gpus(), 0) + if context.num_gpus() > 1: + devices = ["/device:GPU:0", "/device:GPU:1"] + print(self.id().split(".")[-1], "devices:", ", ".join(devices)) + return mirrored_strategy.MirroredStrategy(devices) + + def testMinimizeLossEager(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + soft_placement = not GPU_TEST + print("testMinimizeLossGraph soft_placement:", soft_placement) + self._test_minimize_loss_graph( + self._get_distribution_strategy(), soft_placement=soft_placement) + + def testMapReduce(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_tower_id(self._get_distribution_strategy()) + + def testNumTowers(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self.assertEqual(2, self._get_distribution_strategy().num_towers) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testRunRegroupError(self): + + def run_fn(device_id): + # Generates a list with different lengths on different devices. + # Will fail in _regroup() (if more than one device). + return list(range(device_id)) + + dist = self._get_distribution_strategy() + with dist.scope(), self.assertRaises(AssertionError): + dist.call_for_each_tower(run_fn, dist.worker_device_index) + + @test_util.run_in_graph_and_eager_modes() + def testReduceToCpu(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + def run_fn(device_id): + return device_id + + dist = self._get_distribution_strategy() + with dist.scope(): + result = dist.call_for_each_tower(run_fn, dist.worker_device_index) + reduced = dist.reduce("sum", result, destinations="/device:CPU:0") + unwrapped = dist.unwrap(reduced) + self.assertEqual(1, len(unwrapped)) + expected = sum(range(len(dist.worker_devices))) + self.assertEqual(expected, self.evaluate(unwrapped[0])) + + +@test_util.with_c_api +class MirroredStrategyVariableCreationTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Enough GPUs not available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSingleVariable(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + # This variable should be created only once across the threads because of + # special variable_creator functions used by `dist.call_for_each_tower`. + v = variable_scope.variable(1.0, name="foo") + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEquals("foo:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnnamedVariable(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v = variable_scope.variable(1.0) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + # Default name of "Variable" will be used. + self.assertEquals("Variable:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleVariables(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + vs = [] + for i in range(5): + vs.append(variable_scope.variable(1.0, name="foo" + str(i))) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return vs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + for i, v in enumerate(result): + self.assertIsInstance(v, values.MirroredVariable) + self.assertEquals("foo" + str(i) + ":0", v.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleVariablesWithSameCanonicalName(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + vs = [] + vs.append(variable_scope.variable(1.0, name="foo/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) + vs.append(variable_scope.variable(1.0, name="foo/bar_1")) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return vs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + for v in result: + self.assertIsInstance(v, values.MirroredVariable) + self.assertEquals(4, len(result)) + self.assertEquals("foo/bar:0", result[0].name) + self.assertEquals("foo_1/bar:0", result[1].name) + self.assertEquals("foo_1/bar_1:0", result[2].name) + self.assertEquals("foo/bar_1:0", result[3].name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableWithSameCanonicalNameAcrossThreads(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(device_id): + v = variable_scope.variable(1.0, name="foo_" + str(device_id)) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + # The resulting mirrored variable will use the name from the first device. + self.assertEquals("foo_0:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithLayers(self): + self._skip_eager_if_gpus_less_than(1) + def model_fn(features): + with variable_scope.variable_scope("common"): + layer1 = core.Dense(1) + layer1(features) + layer2 = core.Dense(1) + layer2(features) + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + layer3 = core.Dense(1) + layer3(features) + return [(layer1.kernel, layer1.bias), + (layer2.kernel, layer2.bias), + (layer3.kernel, layer3.bias)] + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + features = dist.distribute_dataset(features).get_next() + + with dist.scope(): + result = dist.call_for_each_tower( + model_fn, features, run_concurrently=False) + suffixes = ["", "_1", "_2"] + for (kernel, bias), suffix in zip(result, suffixes): + self.assertIsInstance(kernel, values.MirroredVariable) + self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertIsInstance(bias, values.MirroredVariable) + self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithGetVariableAndVariableScope(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v0 = variable_scope.get_variable("var-thread0", [1]) + with variable_scope.variable_scope("common"): + v1 = variable_scope.get_variable("var-thread1", [1]) + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + v2 = variable_scope.get_variable("var-thread2", [1]) + + return v0, v1, v2 + + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with variable_scope.variable_scope("main"): + v = variable_scope.get_variable("var-main0", [1]) + self.assertEquals("main/var-main0:0", v.name) + + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(3, len(result)) + v0, v1, v2 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEquals("main/var-thread0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEquals("main/common/var-thread1:0", v1.name) + self.assertIsInstance(v2, values.MirroredVariable) + self.assertEquals("main/common/var-thread2:0", v2.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testThreeDevices(self): + self._skip_eager_if_gpus_less_than(2) + + def model_fn(): + v = variable_scope.variable(1.0, name="foo") + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEquals("foo:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNonMatchingVariableCreation(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(name): + v = variable_scope.variable(1.0, name=name) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + names = values.DistributedValues({ + "/device:CPU:0": "foo", + "/device:GPU:0": "bar" + }) + with self.assertRaises(RuntimeError): + _ = dist.call_for_each_tower(model_fn, names, run_concurrently=False) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTowerLocalVariable(self): + self._skip_eager_if_gpus_less_than(1) + + all_v_sum = {} + all_v_mean = {} + + def model_fn(device_id): + tower_context = distribute_lib.get_tower_context() + with tower_context.tower_local_var_scope("sum"): + v_sum = variable_scope.variable(1.0) + with tower_context.tower_local_var_scope("mean"): + v_mean = variable_scope.variable(4.0) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) + updates = [v_sum.assign_add(2.0 + device_id), + v_mean.assign(6.0 * device_id)] + all_v_sum[device_id] = v_sum + all_v_mean[device_id] = v_mean + return updates, v_sum, v_mean + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + # Create "sum" and "mean" versions of TowerLocalVariables. + ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False) + # Should see the same wrapping instance in all towers. + self.assertIs(all_v_sum[0], ret_v_sum) + self.assertIs(all_v_mean[0], ret_v_mean) + for i in range(1, dist.num_towers): + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Apply updates + self.evaluate(variables.global_variables_initializer()) + self.evaluate([y for x in ret_ops for y in dist.unwrap(x)]) + expected_sum = 0.0 + expected_mean = 0.0 + for i, d in enumerate(dist.worker_devices): + # Test access within a device scope, should see different values. + with ops.device(d): + v_sum_value = self.evaluate(ret_v_sum.read_value()) + v_mean_value = self.evaluate(ret_v_mean.read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected + + # fetch() should return the value you get by applying the + # reduction across all towers. + self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) + expected_mean /= len(dist.worker_devices) + self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + + # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not + # testing this in eager mode. + + def testNameScope(self): + def model_fn(): + with ops.name_scope("foo"): + a = constant_op.constant(1.0, name="a") + distribute_lib.get_tower_context().merge_call(lambda _: _) + b = constant_op.constant(1.0, name="b") + return a, b + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = dist.unwrap(v) + self.assertEquals("main/foo/" + name + ":0", v0.name) + self.assertEquals("main/tower_1/foo/" + name + ":0", v1.name) + + def testWithDefaultName(self): + def model_fn(): + with ops.name_scope(None, "foo"): + a = constant_op.constant(1.0, name="a") + distribute_lib.get_tower_context().merge_call(lambda _: _) + b = constant_op.constant(2.0, name="b") + return a, b + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = dist.unwrap(v) + self.assertEquals("foo/" + name + ":0", v0.name) + self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py new file mode 100644 index 0000000000..a1ef0ecc77 --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -0,0 +1,91 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class MirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import distribute as distribute_lib + + +@test_util.with_c_api +class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return mirrored_strategy.MirroredStrategy(["/device:CPU:0"]) + + def testMinimizeLossEager(self): + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testMapReduce(self): + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + self._test_tower_id(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + +@test_util.with_c_api +class VariableCreatorStackTest(test.TestCase): + + def testCreatorStacksAreThreadLocal(self): + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + + def model_fn(device_id): + assert isinstance(device_id, int) + def thread_creator_fn(next_creator, *args, **kwargs): + return next_creator(*args, **kwargs) + ":thread_" + str(device_id) + + with variable_scope.variable_creator_scope(thread_creator_fn): + # Create a variable in this scope. + v = variable_scope.variable(1.0) + + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + def main_thread_creator(next_creator, *args, **kwargs): + # We are not using the underlying next_creator for test purposes. + del next_creator, args, kwargs + return "main_thread" + + with context.graph_mode(), \ + dist.scope(), \ + variable_scope.variable_creator_scope(main_thread_creator): + result = dist.call_for_each_tower(model_fn, dist.worker_device_index) + result = dist.unwrap(result) + expected = ["main_thread:thread_0", "main_thread:thread_1"] + self.assertEquals(expected, result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py new file mode 100644 index 0000000000..fe80bb4df5 --- /dev/null +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -0,0 +1,61 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Monitor is responsible for training, checkpointing and recovery.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.ops import variables + + +class Monitor(object): + """Executes training steps, recovers and checkpoints. + + Note that this class is particularly preliminary, experimental, and + expected to change. + """ + # TODO(isaprykin): Support step functions that need multiple session calls. + # TODO(isaprykin): Support extra arguments to the step function. + # TODO(isaprykin): Support recovery, checkpointing and summaries. + + def __init__(self, step_callable, session=None): + """Initialize the Monitor with components for executing training steps. + + Args: + step_callable: a training `Step` that's capable of signaling when done. + session: a `Session` instance that's needed for graph mode. + + Raises: + ValueError: if `session` was provided for eager mode or not provided for + graph mode. + """ + if context.executing_eagerly(): + if session is not None: + raise ValueError("Should not provide a `session` in Eager mode.") + self._run_step = step_callable + else: + if session is None: + raise ValueError("Should provide a `session` in Graph mode.") + self._run_step = session.make_callable(step_callable()) + session.run(variables.global_variables_initializer()) + + def run_steps(self, num_steps=None): + step = 0 + done = False + while done is not None and (num_steps is None or step < num_steps): + done = self._run_step() + step += 1 diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py new file mode 100644 index 0000000000..8277e1e791 --- /dev/null +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Monitor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import monitor as monitor_lib +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.training import gradient_descent + + +class MonitorTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=combinations.graph_and_eager_modes))) + def testTrainNetwork(self, distribution, optimizer_fn): + with distribution.scope(): + single_loss_step, layer = single_loss_example(optimizer_fn, distribution) + + if context.executing_eagerly(): + monitor = monitor_lib.Monitor(single_loss_step, None) + else: + with self.test_session() as sess: + monitor = monitor_lib.Monitor(single_loss_step, sess) + + monitor.run_steps(1) + + self.assertEqual(1, len(layer.trainable_variables)) + mirrored_weight_variable = layer.trainable_variables[0] + start_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + start_error = abs(numpy.array(start_error) - 1) + + monitor.run_steps(9) + end_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + end_error = abs(numpy.array(end_error) - 1) + self.assertGreaterEqual(start_error, end_error) + + def testPassingASessionInEager(self): + distribution = one_device_strategy.OneDeviceStrategy( + "/device:CPU:0") + step_function, _ = single_loss_example( + lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) + + with self.test_session() as sess: + with self.assertRaisesRegexp(ValueError, "Should not provide"): + _ = monitor_lib.Monitor(step_function, sess) + + def testNotPassingASessionInGraph(self): + distribution = one_device_strategy.OneDeviceStrategy( + "/device:CPU:0") + step_function, _ = single_loss_example( + lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) + + with context.graph_mode(), ops.Graph().as_default(): + with self.assertRaisesRegexp(ValueError, "Should provide"): + _ = monitor_lib.Monitor(step_function, session=None) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py new file mode 100644 index 0000000000..39c49442b9 --- /dev/null +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -0,0 +1,148 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class OneDeviceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.eager.python import datasets +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import distribute as distribute_lib + + +# TODO(josh11b): Replace asserts in this file with if ...: raise ... + + +class OneDeviceStrategy(distribute_lib.DistributionStrategy): + """A distribution strategy for running on a single device.""" + # TODO(josh11b): Do we wrap values in types to generate errors if you are + # doing something that won't work with other DistributionStrategy + # implementations? + + def __init__(self, device): + super(OneDeviceStrategy, self).__init__() + self._device = device + + def _create_variable(self, next_creator, *args, **kwargs): + # No need to distinguish tower-local variables when not mirroring, + # we just enforce that they are not trainable. + if kwargs.pop("tower_local_reduce_method", None) is not None: + kwargs["trainable"] = False + + colocate_with = kwargs.pop("colocate_with", None) + if colocate_with is None: + with ops.device(self._device): + return next_creator(*args, **kwargs) + if isinstance(colocate_with, six.string_types): + with ops.device(colocate_with): + return next_creator(*args, **kwargs) + if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + isinstance(colocate_with[0], six.string_types)): + with ops.device(colocate_with[0]): + return next_creator(*args, **kwargs) + with ops.colocate_with(colocate_with): + return next_creator(*args, **kwargs) + + def distribute_dataset(self, dataset): + if context.executing_eagerly(): + return datasets.Iterator(dataset) + else: + return dataset.make_one_shot_iterator() + + def _broadcast(self, tensor, destinations): + return tensor + + def _call_for_each_tower(self, fn, *args, **kwargs): + # We don't run `fn` in multiple threads in OneDeviceStrategy. + kwargs.pop("run_concurrently", None) + with ops.device(self._device), _OneDeviceTowerContext(self): + return fn(*args, **kwargs) + + def map(self, map_over, fn, *args, **kwargs): + with ops.device(self._device): + return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) + + def _reduce(self, method_string, value, destinations): + if not isinstance(value, values.MapOutput): + return value + l = value.get() + assert l + with ops.device(self._device): + if method_string == "sum": + return math_ops.add_n(l) + elif method_string == "mean": + return math_ops.add_n(l) / len(l) + else: + assert False + + def _update(self, var, fn, *args, **kwargs): + with ops.device(self._device), distribute_lib.UpdateContext(self._device): + return fn(var, *args, **kwargs) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + del colocate_with + with ops.device(self._device), distribute_lib.UpdateContext(self._device): + return fn(*args, **kwargs) + + def _fetch(self, val, destination, fn): + """Return a copy of `val` or `fn(val)` on `destination`.""" + with ops.device(self._device): + v = fn(val) + with ops.device(destination): + return array_ops.identity(v) + + def _unwrap(self, value): + return [value] + + @property + def is_single_tower(self): + return True + + @property + def num_towers(self): + return 1 + + @property + def worker_devices(self): + return [self._device] + + @property + def parameter_devices(self): + return [self._device] + + def non_slot_devices(self, var_list): + del var_list + return [self._device] + + def _worker_device_index(self): + return 0 + + +class _OneDeviceTowerContext(distribute_lib.TowerContext): + + def __init__(self, distribution_strategy): + distribute_lib.TowerContext.__init__( + self, distribution_strategy, tower_id=0) + + @property + def device(self): + return self._distribution_strategy.worker_devices[0] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py new file mode 100644 index 0000000000..7101ed0756 --- /dev/null +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class OneDeviceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util + + +@test_util.with_c_api +class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return one_device_strategy.OneDeviceStrategy("/device:CPU:0") + + def testMinimizeLossEager(self): + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testMapReduce(self): + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + self._test_tower_id(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py new file mode 100644 index 0000000000..a0912b625f --- /dev/null +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for running legacy optimizer code with DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables + + +class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v2_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testTrainNetwork(self, distribution, optimizer_fn, + use_callable_loss=True): + with distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return control_flow_ops.group(distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built))) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py new file mode 100644 index 0000000000..b9ffd2f266 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Extension of prefetching_ops to support more than one device.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest as data_nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.util import nest + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for @{tf.data.Iterator} that prefetches to another device.""" + + def __init__(self, input_dataset, devices, buffer_size): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._devices = devices + input_iterator = input_dataset.make_one_shot_iterator() + input_iterator_handle = input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, input_iterator.output_types, input_iterator.output_shapes, + input_iterator.output_classes) + return remote_iterator.get_next() + + target_device = gen_dataset_ops.iterator_get_device( + input_iterator._iterator_resource) + self._buffering_resources = [] + for device in nest.flatten(self._devices): + with ops.device(device): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_prefetch_fn, + target_device=target_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size, + thread_pool_size=0) + self._buffering_resources.append(buffer_resource_handle) + + def get_next(self, name=None): + """See @{tf.data.Iterator.get_next}.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_result = [] + # TODO(priyag): This will fail if the input size (typically number of + # batches) is not divisible by number of devices. + # How do we handle that more gracefully / let the user know? + for buffer_resource in self._buffering_resources: + flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + buffer_resource, + output_types=data_nest.flatten(sparse.as_dense_types( + self.output_types, self.output_classes)), name=name) + + ret = sparse.deserialize_sparse_tensors( + data_nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + flat_result.append(ret) + + return nest.pack_sequence_as(self._devices, flat_result) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` whose iterator prefetches elements to other device(s).""" + + def __init__(self, input_dataset, devices, buffer_size): + self._input_dataset = input_dataset + self._devices = devices + self._buffer_size = buffer_size if buffer_size is not None else 1 + + def make_one_shot_iterator(self): + return _PrefetchToDeviceIterator(self._input_dataset, self._devices, + self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + raise NotImplementedError("`prefetch_to_devices()` is not currently " + "compatible with initializable iterators. Use " + "`make_one_shot_iterator()` instead.") + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_devices()` must be the last " + "transformation in a dataset pipeline.") + + # TODO(priyag): Fix the output types, shapes and classes to match the result + # of get_next (which has the additional nesting layer of devices now). + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_devices(devices, buffer_size=None): + """A transformation that prefetches dataset values to the given `devices`. + + NOTE: Although the transformation creates a @{tf.data.Dataset}, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + devices: A nested structure of devices on which to prefetch the data. It can + be a single device name, or a tuple or list of device names. + buffer_size: (Optional.) The number of elements to buffer on each device. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, devices, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py new file mode 100644 index 0000000000..8ed16f4607 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for prefetching_ops_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class PrefetchingOpsV2Test(test.TestCase): + + def testPrefetchToOneDevice(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToTwoDevicesInAList(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + output = [] + with self.test_session() as sess: + for _ in range(5): + result = sess.run(next_element) + self.assertEqual(2, len(result)) + output.extend(result) + self.assertEquals(set(range(10)), set(output)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator.py b/tensorflow/contrib/distribute/python/shared_variable_creator.py new file mode 100644 index 0000000000..aca9c7af05 --- /dev/null +++ b/tensorflow/contrib/distribute/python/shared_variable_creator.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to re-use variables created on first device on subsequent devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +_VARIABLE_UNIQUIFYING_REGEX = re.compile(r"_\d/") +_VARIABLE_UNIQUIFYING_REGEX_AT_END = re.compile(r"_\d$") + + +def _canonicalize_variable_name(name): + # If no name is specified, uses default name "Variable". + if name is None: + return "Variable" + # Replace all instances of "_<num>/" with "/" + name = _VARIABLE_UNIQUIFYING_REGEX.sub("/", name) + # Replace any instances of "_<num>" at the end of the string with "" + name = _VARIABLE_UNIQUIFYING_REGEX_AT_END.sub("", name) + return name + + +def make_fn(shared_variable_store, device_id): + """Construct the variable creator function for device `device_id`. + + Constructs custom variable creator functions for the given device. + On first device (device_id == 0), it creates the variable using the + `next_creator`, and stores it in the provided `shared_variable_store`. + On all other devices (device_id > 0), it tries to re-use the variable + already created with the same name. If no such variable exists, it throws an + error. + Additionally, we de-uniquify variable names before checking for matches. This + helps re-use variables which are intended to be the same but have different + names due to variable uniquificaton happening upstream. Since this might + mean we may have multiple variables with the same canonical name, we store + them in a list per canonical name and return them in the same order as well. + + Args: + shared_variable_store: A dictionary that we will use to store variables + created on the first device, and re-used by creators for other devices. + device_id: Integer index of the device whose creator should be + constructed. + + Returns: + An appropriate creator function based on device_id. + + """ + variable_scope_access_index = {} + assert isinstance(device_id, int) + + def create_new_variable(next_creator, *args, **kwargs): + """Create the variable using `next_creator` and store it.""" + canonical_name = _canonicalize_variable_name(kwargs.get("name")) + v = next_creator(*args, **kwargs) + + if canonical_name not in shared_variable_store: + shared_variable_store[canonical_name] = [] + shared_variable_store[canonical_name].append(v) + return v + + def reuse_variable(next_creator, *args, **kwargs): + """Re-use existing variable from store with same name (in order).""" + del next_creator, args + name = kwargs.get("name") + canonical_name = _canonicalize_variable_name(name) + + try: + variable_index = variable_scope_access_index.get(canonical_name, 0) + v = shared_variable_store[canonical_name][variable_index] + # TODO(priyag): Make this variable re-use more robust by adding checks + # that the requested shape and dtype match the existing variable. + variable_scope_access_index[canonical_name] = variable_index + 1 + return v + except (KeyError, IndexError): + raise RuntimeError( + "Tried to create variable {} with mismatching name on device {}". + format(name, device_id)) + + if device_id == 0: + return create_new_variable + else: + return reuse_variable diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py new file mode 100644 index 0000000000..713494d603 --- /dev/null +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -0,0 +1,75 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SharedVariableCreator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import shared_variable_creator +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope + + +class CanonicalizeVariableNameTest(test.TestCase): + + def _canonicalize(self, name): + return shared_variable_creator._canonicalize_variable_name(name) + + def testNoName(self): + self.assertEquals("Variable", self._canonicalize(None)) + + def testPatternInMiddle(self): + self.assertEquals("foo/bar/baz", self._canonicalize("foo_1/bar_1/baz")) + + def testPatternAtEnd(self): + self.assertEquals("foo", self._canonicalize("foo_1")) + + def testWrongPatterns(self): + self.assertEquals("foo_1:0", self._canonicalize("foo_1:0")) + self.assertEquals("foo1", self._canonicalize("foo1")) + self.assertEquals("foo_a", self._canonicalize("foo_a")) + + +@test_util.with_c_api +class SharedVariableCreatorTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testSharedVariable(self): + + shared_variable_store = {} + num_devices = 3 + creator_fns = [] + for i in range(num_devices): + creator_fn = shared_variable_creator.make_fn(shared_variable_store, i) + creator_fns.append(creator_fn) + + with variable_scope.variable_creator_scope(creator_fns[0]): + v0 = variable_scope.variable(1.0, name="foo") + + with variable_scope.variable_creator_scope(creator_fns[1]): + v1 = variable_scope.variable(1.0, name="foo") + + with variable_scope.variable_creator_scope(creator_fns[2]): + v2 = variable_scope.variable(1.0, name="foo") + + # v1 and v2 should be same as v0 + self.assertIs(v1, v0) + self.assertIs(v2, v0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/simple_estimator_example.py b/tensorflow/contrib/distribute/python/simple_estimator_example.py new file mode 100644 index 0000000000..7095d801ad --- /dev/null +++ b/tensorflow/contrib/distribute/python/simple_estimator_example.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple example to test the a DistributionStrategy with Estimators. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import app +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import training_util + + +def build_model_fn_optimizer(): + """Simple model_fn with optimizer.""" + # TODO(anjalisridhar): Move this inside the model_fn once OptimizerV2 is + # done? + optimizer = gradient_descent.GradientDescentOptimizer(0.2) + + def model_fn(features, labels, mode): # pylint: disable=unused-argument + """model_fn which uses a single unit Dense layer.""" + # You can also use the Flatten layer if you want to test a model without any + # weights. + layer = core.Dense(1, use_bias=True) + logits = layer(features) + + if mode == model_fn_lib.ModeKeys.PREDICT: + predictions = {"logits": logits} + return model_fn_lib.EstimatorSpec(mode, predictions=predictions) + + def loss_fn(): + y = array_ops.reshape(logits, []) - constant_op.constant(1.) + return y * y + + if mode == model_fn_lib.ModeKeys.EVAL: + return model_fn_lib.EstimatorSpec(mode, loss=loss_fn()) + + assert mode == model_fn_lib.ModeKeys.TRAIN + + global_step = training_util.get_global_step() + train_op = optimizer.minimize(loss_fn(), global_step=global_step) + return model_fn_lib.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op) + + return model_fn + + +def main(_): + distribution = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]) + config = run_config.RunConfig(distribute=distribution) + + def input_fn(): + features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + labels = dataset_ops.Dataset.from_tensors([1.]).repeat(10) + return dataset_ops.Dataset.zip((features, labels)) + + estimator = estimator_lib.Estimator( + model_fn=build_model_fn_optimizer(), config=config) + estimator.train(input_fn=input_fn, steps=10) + + eval_result = estimator.evaluate(input_fn=input_fn) + print("Eval result: {}".format(eval_result)) + + def predict_input_fn(): + predict_features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + return predict_features + + predictions = estimator.predict(input_fn=predict_input_fn) + # TODO(anjalsridhar): This returns a generator object, figure out how to get + # meaningful results here. + print("Prediction results: {}".format(predictions)) + + +if __name__ == "__main__": + app.run(main) diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py new file mode 100644 index 0000000000..cef5fd2f89 --- /dev/null +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple network to use in tests and examples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import step_fn +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.layers import core +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def single_loss_example(optimizer_fn, distribution, use_bias=False): + """Build a very simple network to use in tests and examples.""" + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + optimizer = optimizer_fn() + layer = core.Dense(1, use_bias=use_bias) + + def loss_fn(x): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer, + distribution) + + # Layer is returned for inspecting the kernels in tests. + return single_loss_step, layer + + +def minimize_loss_example(optimizer_fn, + use_bias=False, + use_callable_loss=True, + create_optimizer_inside_model_fn=False): + """Example of non-distribution-aware legacy code.""" + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + # An Optimizer instance is created either outside or inside model_fn. + outer_optimizer = None + if not create_optimizer_inside_model_fn: + outer_optimizer = optimizer_fn() + + layer = core.Dense(1, use_bias=use_bias) + + def model_fn(x): + """A very simple model written by the user.""" + + def loss_fn(): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + optimizer = outer_optimizer or optimizer_fn() + + if use_callable_loss: + return optimizer.minimize(loss_fn) + else: + return optimizer.minimize(loss_fn()) + + return model_fn, dataset, layer + + +def batchnorm_example(optimizer_fn, + batch_per_epoch=1, + momentum=0.9, + renorm=False): + """Example of non-distribution-aware legacy code with batch normalization.""" + # input shape is [16, 8], input values are increasing in both dimensions. + dataset = dataset_ops.Dataset.from_tensor_slices( + [[[float(x * 8 + y + z * 100) + for y in range(8)] + for x in range(16)] + for z in range(batch_per_epoch)]).repeat() + optimizer = optimizer_fn() + batchnorm = normalization.BatchNormalization( + renorm=renorm, momentum=momentum, fused=False) + + def model_fn(x): + + def loss_fn(): + y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1) + loss = math_ops.reduce_mean(y - constant_op.constant(1.)) + return loss + + # Callable loss. + return optimizer.minimize(loss_fn) + + return model_fn, dataset, batchnorm diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py new file mode 100644 index 0000000000..82514c64be --- /dev/null +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The step function abstraction represents a single training step.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import backprop +from tensorflow.python.training import optimizer as optimizer_lib + + +class Step(object): + """Interface for performing each step of a training algorithm.""" + + def __init__(self, distribution): + self._distribution = distribution + + @property + def distribution(self): + return self._distribution + + def __call__(self): + """Perform one step of this training algorithm.""" + return self.step(self.inputs()) + + def inputs(self): + """For the generating the input to be passed to `step()`.""" + raise NotImplementedError("must be implemented in descendants") + + def step(self, inputs): + """Perform the main computation of this training algorithm.""" + raise NotImplementedError("must be implemented in descendants") + + +class StandardInputStep(Step): + """Step with a standard implementation of input handling. + + Args: + input_dataset: a tf.data Dataset that provides input. + """ + + def __init__(self, input_dataset, distribution): + Step.__init__(self, distribution) + self._distributed_input = distribution.distribute_dataset(input_dataset) + + def inputs(self): + return self._distributed_input.get_next() + + +class StandardSingleLossStep(StandardInputStep): + """A step function that implements a training step for a feed forward network. + + An instance of this class is intended to be used as a callable: + + ```python + ... + step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer) + step.initialize(distribution) + + # Run a single training step on a given DistributionStrategy: + step(distribution) + ... + ``` + + Args: + input_dataset: a tf.data Dataset that provides input. + loss_fn: a function that returns loss. + optimizer: an optimizer that implements an update rule. + distribution: a `DistributionStrategy` object. + """ + + def __init__(self, input_dataset, loss_fn, optimizer, distribution): + StandardInputStep.__init__(self, input_dataset, distribution) + self._loss_fn = loss_fn + self._optimizer = optimizer + self._is_run_concurrently = False + + def step(self, inputs): + with self._distribution.scope(): + gradients_fn = backprop.implicit_grad(self._loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + + grads_and_vars = self.distribution.call_for_each_tower( + gradients_fn, inputs, run_concurrently=self._is_run_concurrently) + # If threads use layers, then we need to run the first step sequentially, + # so that layers.build() is not executed in parallel. Otherwise, multiple + # sets of mirrored variables are going to be created. + self._is_run_concurrently = True + return self._optimizer._distributed_apply( # pylint: disable=protected-access + self.distribution, grads_and_vars) diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py new file mode 100644 index 0000000000..75c5ec9659 --- /dev/null +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Step.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.ops import variables + + +class SingleLossStepTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=combinations.graph_and_eager_modes))) + def testTrainNetwork(self, distribution, optimizer_fn): + with distribution.scope(): + single_loss_step, layer = single_loss_example( + optimizer_fn, distribution, use_bias=True) + + if context.executing_eagerly(): + run_step = single_loss_step + else: + with self.test_session() as sess: + run_step = sess.make_callable(single_loss_step()) + self.evaluate(variables.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py new file mode 100644 index 0000000000..2b4ad9f146 --- /dev/null +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -0,0 +1,225 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for testing DistributionStrategy descendants.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer + + +class _TestException(Exception): + pass + + +# May be the argument to either distribution.call_for_each_tower() or +# get_tower_context().merge_call() +def _raise_exception_fn(_=None): + raise _TestException() + + +# Must be the argument to a distribution.call_for_each_tower() call, calls a +# get_tower_context().merge_call() that raises an exception. +def _merge_raises_fn(): + distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + + +# Must be the argument to a get_tower_context().merge_call() call, calls +# dist.call_for_each_tower() with a function that raises an exception. +def _call_raises_fn(dist): + dist.call_for_each_tower(_raise_exception_fn) + + +# Must be the argument to a distribution.call_for_each_tower() call, +# calls a get_tower_context().merge_call() that calls a +# call_for_each_tower() that raises an exception. +def _merge_call_raises_fn(): + distribute_lib.get_tower_context().merge_call(_call_raises_fn) + + +# Must be the argument to a get_tower_context().merge_call() call, calls +# dist.call_for_each_tower() with a function that calls a +# get_tower_context().merge_call() that raises an exception. +def _call_merge_raises_fn(dist): + dist.call_for_each_tower(_merge_raises_fn) + + +# Must be the argument to a distribution.call_for_each_tower() call, calls a +# get_tower_context().merge_call() that calls a call_for_each_tower() that +# calls a get_tower_context().merge_call() that raises an exception. +def _merge_call_merge_raises_fn(): + distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + + +class DistributionTestBase(test.TestCase): + """Some tests that should work with any DistributionStrategy.""" + + def _test_minimize_loss_eager(self, d): + with d.scope(): + l = core.Dense(1, use_bias=False) + + def loss(x): + # TODO(josh11b): What if this constant was instead a captured + # value? Would it need to be a value that has been passed + # through d.broadcast()? + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a + # common `implicit_grad` function and put it in DistributionStrategy. + grad_fn = backprop.implicit_grad(loss) + grad_fn = optimizer.get_filtered_grad_fn(grad_fn) + + def update(v, g): + return v.assign_sub(0.2 * g) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one, run_concurrently=l.built) + + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.fetch(v) + before_list.append(fetched) + # control_dependencies irrelevant but harmless in eager execution + with ops.control_dependencies([fetched]): + g = d.reduce("sum", g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.fetch(v)) + return before_list, after_list + + for i in range(10): + b, a = step() + if i == 0: + before, = b # pylint: disable=unbalanced-tuple-unpacking + after, = a # pylint: disable=unbalanced-tuple-unpacking + + error_before = abs(before.numpy() - 1) + error_after = abs(after.numpy() - 1) + # Error should go down + self.assertLess(error_after, error_before) + + def _test_minimize_loss_graph(self, d, soft_placement=False): + config = config_pb2.ConfigProto() + config.allow_soft_placement = soft_placement + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + with context.graph_mode(), \ + ops.Graph().as_default(), \ + self.test_session(config=config) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False) + + def loss(x): + # TODO(josh11b): What if this constant was instead a captured + # value? Would it need to be a value that has been passed + # through d.broadcast()? + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + grad_fn = backprop.implicit_grad(loss) + + def update(v, g): + return v.assign_sub(0.2 * g) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.fetch(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + g = d.reduce("sum", g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.fetch(v)) + return before_list, after_list + + before_out, after_out = step() + variables.global_variables_initializer().run() + for i in range(10): + b, a = sess.run((before_out, after_out)) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + + def _test_map_reduce(self, d, in_graph=None): + with d.scope(): + map_in = [constant_op.constant(i) for i in range(10)] + map_out = d.map(map_in, lambda x, y: x * y, 2) + observed = d.fetch(d.reduce("sum", map_out)) + expected = 90 # 2 * (0 + 1 + ... + 9) + self.assertEqual(expected, observed.numpy()) + + def _test_device_index(self, d): + with d.scope(): + expected_devices = [False] * len(d.worker_devices) + + def mark_devices_fn(device_id): + self.assertLess(device_id, len(d.worker_devices)) + self.assertFalse(expected_devices[device_id]) + expected_devices[device_id] = True + + d.call_for_each_tower(mark_devices_fn, d.worker_device_index) + self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + + def _test_tower_id(self, d): + with d.scope(): + expected_devices = [False] * len(d.worker_devices) + + def mark_devices_fn(): + tower_id = distribute_lib.get_tower_context().tower_id + self.assertLess(tower_id, len(d.worker_devices)) + self.assertFalse(expected_devices[tower_id]) + expected_devices[tower_id] = True + + d.call_for_each_tower(mark_devices_fn) + self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + + def _test_call_and_merge_exceptions(self, dist): + with dist.scope(): + with self.assertRaises(_TestException): + dist.call_for_each_tower(_raise_exception_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_raises_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_call_raises_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_call_merge_raises_fn) diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py new file mode 100644 index 0000000000..c1ba22ed5a --- /dev/null +++ b/tensorflow/contrib/distribute/python/values.py @@ -0,0 +1,575 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various classes representing distributed values. + +See go/tf-distribution-strategy. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import weakref + +import six + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.contrib.eager.python import datasets +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import saver +from tensorflow.python.util import nest + + +# pylint: disable=line-too-long +# TODO(josh11b): Should device values be strings or DeviceSpec objects +# Not sure DeviceSpec objects are usable as a dict key. +class DistributedValues(object): + """Holds a map from device to values. Either PerDevice or Mirrored.""" + + def __init__(self, index): + self._index = {device_util.canonicalize(key): value + for key, value in six.iteritems(index)} + + def get(self, device=None): + """Returns the value for the current device or raises a ValueError.""" + if device is None: + tower_context = distribute_lib.get_tower_context() + if tower_context: + device = tower_context.device + else: + device = distribute_lib.get_update_device() + if device is None: + device = device_util.current() + device = device_util.canonicalize(device) + try: + return self._index[device] + except KeyError: + raise ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())) + + def on_device(self, device): + device = device_util.canonicalize(device) + return device in self._index + + @property + def devices(self): + return self._index.keys() + + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + # TODO(josh11b): Possibly make an accessor for _index for use by + # DistributionStrategy implementations. + + +class DistributedDelegate(DistributedValues): + """A map from device to values; acts as the same type as the values.""" + + def __init__(self, index): + super(DistributedDelegate, self).__init__(index) + + def __getattr__(self, name): + return getattr(self.get(), name) + + # pylint: disable=multiple-statements + def __add__(self, o): return self.get() + o + def __radd__(self, o): return o + self.get() + def __sub__(self, o): return self.get() - o + def __rsub__(self, o): return o - self.get() + def __mul__(self, o): return self.get() * o + def __rmul__(self, o): return o * self.get() + def __truediv__(self, o): return self.get() / o + def __rtruediv__(self, o): return o / self.get() + def __floordiv__(self, o): return self.get() // o + def __rfloordiv__(self, o): return o // self.get() + def __mod__(self, o): return self.get() % o + def __rmod__(self, o): return o % self.get() + def __lt__(self, o): return self.get() < o + def __le__(self, o): return self.get() <= o + def __gt__(self, o): return self.get() > o + def __ge__(self, o): return self.get() >= o + def __and__(self, o): return self.get() & o + def __rand__(self, o): return o & self.get() + def __or__(self, o): return self.get() | o + def __ror__(self, o): return o | self.get() + def __xor__(self, o): return self.get() ^ o + def __rxor__(self, o): return o ^ self.get() + def __getitem__(self, o): return self.get()[o] + def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo) + def __rpow__(self, o): return pow(o, self.get()) + def __invert__(self): return ~self.get() + def __neg__(self): return -self.get() + def __abs__(self): return abs(self.get()) + + def __div__(self, o): + try: + return self.get().__div__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rdiv__(self, o): + try: + return self.get().__rdiv__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __matmul__(self, o): + try: + return self.get().__matmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rmatmul__(self, o): + try: + return self.get().__rmatmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + # TODO(josh11b): Even more operator overloads. + + +class PerDevice(DistributedValues): + """Holds a map from device to unsynchronized values.""" + pass + + +class Mirrored(DistributedValues): + """Holds a map from device to values which are kept in sync.""" + pass + + +def _assign_on_device(device, variable, tensor): + with ops.device(device): + return variable.assign(array_ops.identity(tensor)) + + +DistributedVarOp = collections.namedtuple( + "DistributedVarOp", ["name", "graph", "type"]) + + +class DistributedVariable(DistributedDelegate): + """Holds a map from device to variables.""" + # TODO(josh11b): Support changing the set of variables if e.g. if new + # devices are joining or a device is to leave. + + def __init__(self, index): + # Child class must set self._primary_var before calling + # super(...).__init__(index). + self._common_name = self._primary_var.name.split(":")[0] + super(DistributedVariable, self).__init__(index) + + @property + def initializer(self): + return control_flow_ops.group([v.initializer for v in self._index.values()]) + + @property + def graph(self): + return self._primary_var.graph + + @property + def _shared_name(self): + return self._common_name + + @property + def _unique_id(self): + return self._primary_var._unique_id # pylint: disable=protected-access + + @property + def name(self): + return self._primary_var.name + + @property + def dtype(self): + return self._primary_var.dtype + + @property + def shape(self): + return self._primary_var.shape + + def get_shape(self): + return self._primary_var.get_shape() + + @property + def op(self): + # We want cross-tower code that does some var.op.X calls + # to work (even if the current device isn't in self.devices), but + # other uses of var.op in a cross-tower context to fail. + if distribute_lib.get_cross_tower_context(): + return DistributedVarOp(self._primary_var.op.name, + self._primary_var.op.graph, + self._primary_var.op.type) + return self.get().op + + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + + +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) +# TODO(josh11b): ops.register_dense_tensor_like_type(DistributedVariable)? + + +class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): + """Class for defining how to restore a MirroredVariable.""" + + def __init__(self, mirrored_variable, primary_variable, name): + self._mirrored_variable = mirrored_variable + super(_MirroredSaveable, self).__init__(primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access + + +def _get_update_device(): + """Validate we are in update/update_non_slot() and return current device. + + This is used in MirroredVariable.assign* members, to make sure they + are only called via an update method, to make sure all components of the + variable are being updated in a consistent way. + + Returns: + A string device. + + Raises: + RuntimeError: If not in distribution.update()/.update_non_slot(). + """ + device = distribute_lib.get_update_device() + if device is None: + raise RuntimeError( + "Use DistributionStrategy.update() to modify a MirroredVariable.") + return device + + +class MirroredVariable(DistributedVariable, Mirrored, + checkpointable.CheckpointableBase): + """Holds a map from device to variables whose values are kept in sync.""" + + def __init__(self, index, primary_var): + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access + self._primary_var = primary_var + super(MirroredVariable, self).__init__(index) + + # We use _get_update_device() for the assign* methods to enforce + # that we are in an update() function. The arguments to update() are + # automatically unwrapped so the update() function would normally + # see regular variables, not MirroredVariables. However, the update + # function can still operate on wrapped MirroredVariables through + # object members, captured arguments, etc. This is more likely in an + # update_non_slot() function (like OptimizerV2._finish), which can + # update several non-slot variables in one call. + def assign_sub(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign_sub(*args, **kwargs) + + def assign_add(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign_add(*args, **kwargs) + + def assign(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign(*args, **kwargs) + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + MirroredVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _MirroredSaveable(self, self._primary_var, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + +class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): + """Class for defining how to restore a TowerLocalVariable.""" + + def __init__(self, tower_local_variable, name): + self._tower_local_variable = tower_local_variable + # We use a callable so that we don't have to evaluate this expression + # in the case where we are trying to restore instead of save. + def tensor(): + return distribute_lib.get_distribution_strategy().fetch( + tower_local_variable) + spec = saver.BaseSaverBuilder.SaveSpec( + tensor=tensor, + slice_spec="", + name=name, + dtype=tower_local_variable.dtype) + super(_TowerLocalSaveable, self).__init__(tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + if self._tower_local_variable.reduce_method == "sum": + tensor *= 1. / len(self._tower_local_variable.devices) + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access + + +class TowerLocalVariable(DistributedVariable, PerDevice, + checkpointable.CheckpointableBase): + """Holds a map from device to variables whose values are reduced on save.""" + + def __init__(self, index, primary_var, reduce_method): + self._primary_var = primary_var + self._reduce_method = reduce_method + super(TowerLocalVariable, self).__init__(index) + + def assign_sub(self, *args, **kwargs): + return self.get().assign_sub(*args, **kwargs) + + def assign_add(self, *args, **kwargs): + return self.get().assign_add(*args, **kwargs) + + def assign(self, *args, **kwargs): + return self.get().assign(*args, **kwargs) + + @property + def reduce_method(self): + return self._reduce_method + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + TowerLocalVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _TowerLocalSaveable(self, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + +def _devices_match(d1, d2): + return device_util.canonicalize(d1) == device_util.canonicalize(d2) + + +def regroup(per_device, wrap_class=PerDevice): + """Makes device->nest map into a nest of PerDevice/Mirrored values.""" + items = list(per_device.items()) + assert items + v0 = items[0][1] # First value + + if isinstance(v0, list): + for _, v in items[1:]: + assert isinstance(v, list) + assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" % + (len(v), len(v0), v, v0)) + return [regroup({k: v[i] for k, v in items}, wrap_class) + for i in range(len(v0))] + + if isinstance(v0, tuple): + for _, v in items[1:]: + assert isinstance(v, tuple) + assert len(v) == len(v0) + regrouped_tuple = tuple(regroup({k: v[i] for k, v in items}, wrap_class) + for i in range(len(v0))) + if hasattr(v0, "_fields"): + # This tuple is in fact a namedtuple! Create a new namedtuple instance + # and initialize it with the regrouped values: + assert hasattr(type(v0), "_make") + return type(v0)._make(regrouped_tuple) + else: + return regrouped_tuple + + if isinstance(v0, dict): + v0keys = set(v0.keys()) + for _, v in items[1:]: + assert isinstance(v, dict) + assert set(v.keys()) == v0keys + return {key: regroup({k: v[key] for k, v in items}, wrap_class) + for key in v0keys} + + # If exactly the same object across all devices, return it unwrapped. + same_id = True + for _, v in items[1:]: + if v is not v0: + same_id = False + break + # Consider three cases where same_id is true: + # * If v0 is a MirroredVariable (and same_id means it is the same + # across all devices), we want to return it. We check + # MirroredVariable specifically since it can look like it + # has a _mirrored_container member since its members do. + # * If v0 is a member of a mirrored variable, in which case + # hasattr(v0, "_mirrored_container") is true, we want to + # return the MirroredVariable that contains it using the + # _mirrored_container logic below. This case can trigger + # same_id when there is only one device. + # * In any other situation, same_id means we return v0. + if same_id and (isinstance(v0, MirroredVariable) or + not hasattr(v0, "_mirrored_container")): + return v0 + + # Detect the case where each device has a parallel component of the + # same MirroredVariable. In this case we want to return the + # containing MirroredVariable, after a bunch of sanity checking. + # In particular, each component should have the same container, + # and the devices of the variables should match the keys of the + # per-device dictionary. + # TODO(josh11b): Do we need similar logic for TowerLocalVariables? + if hasattr(v0, "_mirrored_container"): + # pylint: disable=protected-access + assert not isinstance(v0, MirroredVariable), ( + "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) + assert _devices_match(v0.device, items[0][0]), ( + "v0.device = %s, items = %s" % (v0.device, items)) + mirrored_container = v0._mirrored_container() + assert mirrored_container is not None + for d, v in items[1:]: + assert _devices_match(v.device, d), ( + "v.device = %s, d = %s, items = %s" % (v.device, d, items)) + assert mirrored_container is v._mirrored_container() + return mirrored_container + # pylint: enable=protected-access + + return wrap_class(per_device) + + +def select_device(device, structured): + """Specialize a nest of regular & per-device values for one device.""" + def _get(x): + return x.get(device) if isinstance(x, DistributedValues) else x + + return nest.map_structure(_get, structured) + + +def select_device_mirrored(device, structured): + """Specialize a nest of regular & mirrored values for one device.""" + def _get_mirrored(x): + if isinstance(x, DistributedValues): + if not isinstance(x, Mirrored): + raise TypeError( + "Expected value to be mirrored across towers: %s in %s." % + (x, structured)) + return x.get(device) + else: + return x + + return nest.map_structure(_get_mirrored, structured) + + +class PerDeviceDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" + + def __init__(self, iterator, devices, prefetch_on_device=None): + self._iterator = iterator + self._devices = devices + self._prefetch_on_device = prefetch_on_device + + def get_next(self, name=None): + """Scatter the input across devices.""" + if self._prefetch_on_device: + data_list = self._iterator.get_next(name=name) + index = dict(zip(self._devices, data_list)) + else: + batch = self._iterator.get_next(name=name) + index = {} + def get_ith(i): + return lambda x: x[i] + + for i, d in enumerate(self._devices): + index[d] = nest.map_structure(get_ith(i), batch) + if context.executing_eagerly(): + with ops.device(d): + index[d] = nest.map_structure(array_ops.identity, index[d]) + + return regroup(index) + + +class PerDeviceDataset(object): + """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" + + def __init__(self, dataset, devices, prefetch_on_device=None): + self._devices = devices + + # Default to using prefetching in graph mode, unless specified. + # TODO(priyag): Enable prefetching in eager mode. + self._prefetch_on_device = prefetch_on_device + if self._prefetch_on_device is None: + self._prefetch_on_device = not context.executing_eagerly() + assert not (self._prefetch_on_device and context.executing_eagerly()), ( + "Prefetching is only supported in graph mode currently") + + if self._prefetch_on_device: + self._dataset = dataset + else: + # TODO(priyag): If dropping remainder is not appropriate, find another + # approach to distributing the dataset when not possible to divide evenly. + # Possibly not an issue when we start using PartitionedDataset. + self._dataset = dataset.apply( + batching.batch_and_drop_remainder(len(devices))) + + def make_one_shot_iterator(self): + """Get a one time use iterator for the distributed PerDeviceDataset.""" + if self._prefetch_on_device: + on_device_dataset = self._dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) + dataset_iterator = on_device_dataset.make_one_shot_iterator() + elif context.executing_eagerly(): + dataset_iterator = datasets.Iterator(self._dataset) + else: + dataset_iterator = self._dataset.make_one_shot_iterator() + + return PerDeviceDataIterator( + dataset_iterator, self._devices, self._prefetch_on_device) + + +class MapOutput(object): + """Map can result in multiple outputs per device.""" + + def __init__(self, l): + self._l = l + + def get(self): + return self._l diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py new file mode 100644 index 0000000000..5c0d4b7d6c --- /dev/null +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -0,0 +1,807 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the distributed values library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import device_util +from tensorflow.python.training import saver as saver_lib + + +@test_util.with_c_api +class DistributedValuesTest(test.TestCase): + + def testGetEager(self): + with ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + with self.assertRaises(ValueError): + self.assertIsNone(v.get("/device:GPU:2")) + + def testGetGraph(self): + with context.graph_mode(), \ + ops.Graph().as_default(), \ + ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + with self.assertRaises(ValueError): + self.assertIsNone(v.get("/device:GPU:2")) + + def testCanonicalization(self): + canonical_cpu = ["/job:localhost/replica:0/task:0/device:CPU:0"] + v = values.DistributedValues({"": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/device:CPU:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/cpu:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/CPU:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + with self.assertRaises(AssertionError): + v = values.DistributedValues({"/device:cpu:0": 42}) + + +@test_util.with_c_api +class DistributedDelegateTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testGetAttr(self): + with ops.device("/device:CPU:0"): + + class Foo(object): + + def __init__(self, x): + self.x = x + + v = values.DistributedDelegate( + {"/device:CPU:0": Foo(7), "/device:GPU:0": Foo(8)}) + self.assertEqual(7, v.x) + with self.assertRaises(AttributeError): + _ = v.y + + @test_util.run_in_graph_and_eager_modes() + def testOperatorOverride(self): + with ops.device("/device:CPU:0"): + v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) + # v should act like int(7). + self.assertEqual(8, v + 1) + self.assertEqual(10, 3 + v) + self.assertEqual(14, v + v) + self.assertEqual(5, v - 2) + self.assertEqual(6, 13 - v) + self.assertEqual(0, v - v) + self.assertEqual(14, v * 2) + self.assertEqual(21, 3 * v) + self.assertEqual(49, v * v) + self.assertEqual(3.5, v / 2) + self.assertEqual(1.5, 10.5 / v) + self.assertEqual(3, v // 2) + self.assertEqual(2, 15 // v) + self.assertEqual(1, v % 2) + self.assertEqual(2, 16 % v) + self.assertTrue(v < 12) + self.assertTrue(v <= 12) + self.assertFalse(v > 12) + self.assertFalse(v >= 12) + self.assertFalse(12 < v) + self.assertFalse(12 <= v) + self.assertTrue(12 > v) + self.assertTrue(12 >= v) + self.assertEqual(3, v & 3) + self.assertEqual(3, 11 & v) + self.assertEqual(15, v | 8) + self.assertEqual(23, 16 | v) + self.assertEqual(4, v ^ 3) + self.assertEqual(12, 11 ^ v) + self.assertEqual(343, pow(v, 3)) + self.assertEqual(3, pow(v, 3, 10)) + self.assertEqual(128, pow(2, v)) + self.assertEqual(-7, -v) + self.assertEqual(~7, ~v) + self.assertEqual(7, abs(v)) + with self.assertRaises(TypeError): + _ = v[2] + + +def _device_str(d): + return "/device:GPU:" + str(d) + + +def _nested_value(d): + return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d) + + +def _make_mirrored(): + v = [] + index = {} + devices = ["/device:GPU:0", "/device:CPU:0"] + for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]): + with ops.device(d): + v.append(variable_scope.get_variable( + name=n, initializer=init, use_resource=True)) + index[d] = v[-1] + mirrored = values.MirroredVariable(index, v[0]) + return v, devices, mirrored + + +@test_util.with_c_api +class RegroupAndSelectDeviceTest(test.TestCase): + + def _is_per_device(self, result, expected, klass=values.PerDevice): + self.assertIsInstance(result, klass) + # We canonicalize the devices to match the device strings returned + # by PerDevice, which also does device string canonicalization. + devices = [device_util.canonicalize(_device_str(i)) + for i in range(len(expected))] + self.assertEqual(set(devices), set(result.devices)) + for i, d in enumerate(devices): + self.assertEqual(expected[i], result.get(d)) + self.assertEqual(expected[i], result.get(_device_str(i))) + + def testNested(self): + result = values.regroup({_device_str(0): _nested_value("1"), + _device_str(1): _nested_value("2")}) + self.assertIsInstance(result, tuple) + self.assertEqual(3, len(result)) + self._is_per_device(result[0], ["a1", "a2"]) + self._is_per_device(result[2], ["h1", "h2"]) + + self.assertIsInstance(result[1], list) + self.assertEqual(3, len(result[1])) + self._is_per_device(result[1][0], ["b1", "b2"]) + self._is_per_device(result[1][2], ["g1", "g2"]) + + self.assertIsInstance(result[1][1], dict) + self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) + self._is_per_device(result[1][1]["c"], ["d1", "d2"]) + self._is_per_device(result[1][1]["e"], ["f1", "f2"]) + + # Also test that we can undo the merge using select_device() + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device(_device_str(1), result)) + # select_device_mirrored() should fail due to non-mirrored values + with self.assertRaises(TypeError): + values.select_device_mirrored(_device_str(0), result) + with self.assertRaises(TypeError): + values.select_device_mirrored(_device_str(1), result) + + def testWrapClass(self): + # Normally a mirrored value would be the same across devices, but + # for a test it is convenient to be able to tell the values apart. + result = values.regroup({_device_str(0): _nested_value("1"), + _device_str(1): _nested_value("2")}, + values.Mirrored) + self.assertIsInstance(result, tuple) + self.assertEqual(3, len(result)) + self._is_per_device(result[0], ["a1", "a2"], values.Mirrored) + self._is_per_device(result[2], ["h1", "h2"], values.Mirrored) + + self.assertIsInstance(result[1], list) + self.assertEqual(3, len(result[1])) + self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored) + self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored) + + self.assertIsInstance(result[1][1], dict) + self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) + self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored) + self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored) + + # Also test that we can undo the merge using select_device() + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device(_device_str(1), result)) + # Values are marked as mirrored, so select_device_mirrored() is allowed. + self.assertEqual(_nested_value("1"), + values.select_device_mirrored(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device_mirrored(_device_str(1), result)) + + def testMirroredContainer(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + v, devices, mirrored = _make_mirrored() + result = values.regroup(dict(zip(devices, v))) + self.assertIs(mirrored, result) + + def testSameId(self): + foo = object() + result = values.regroup({_device_str(0): ("a", foo), + _device_str(1): ("b", foo)}) + self.assertIsInstance(result, tuple) + self.assertEqual(2, len(result)) + self._is_per_device(result[0], ["a", "b"]) + self.assertIs(foo, result[1]) + + # Test select_device(), should undo the merge done by regroup(). + result_0 = values.select_device(_device_str(0), result) + self.assertIsInstance(result_0, tuple) + self.assertEqual(2, len(result_0)) + self.assertEqual("a", result_0[0]) + self.assertIs(foo, result_0[1]) + result_1 = values.select_device(_device_str(1), result) + self.assertIsInstance(result_1, tuple) + self.assertEqual(2, len(result_1)) + self.assertEqual("b", result_1[0]) + self.assertIs(foo, result_1[1]) + + def testOneDevice(self): + result = values.regroup({_device_str(0): _nested_value("1")}) + # On one device regroup() and select_device() are basically identity. + self.assertEqual(_nested_value("1"), result) + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + + # The one exception has to do with MirroredVariables. + d = "/device:CPU:0" + with ops.device(d): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + index = {d: v} + mirrored = values.MirroredVariable(index, v) + result = values.regroup(index) + self.assertIs(mirrored, result) + + def testNamedTupleEstimatorSpec(self): + with context.graph_mode(), ops.Graph().as_default(): + created_estimator_specs = {} + to_regroup = {} + + for device_id in range(3): + spec = model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.TRAIN, + loss=constant_op.constant(device_id / 2), + train_op=array_ops.identity(constant_op.constant(device_id))) + created_estimator_specs[device_id] = spec + to_regroup[_device_str(device_id)] = spec + + merged_estimator_spec = values.regroup(to_regroup) + + self.assertTrue( + isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) + self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) + for device_id in range(3): + d = _device_str(device_id) + self.assertEquals(created_estimator_specs[device_id].loss, + merged_estimator_spec.loss.get(d)) + self.assertEquals(created_estimator_specs[device_id].train_op, + merged_estimator_spec.train_op.get(d)) + # Scaffold is populated by `EstimatorSpec.__new__`. + self.assertEquals(created_estimator_specs[device_id].scaffold, + merged_estimator_spec.scaffold.get(d)) + # Also test that we can undo the merge using select_device() + self.assertEquals(created_estimator_specs[device_id], + values.select_device(_device_str(device_id), + merged_estimator_spec)) + + +@test_util.with_c_api +class PerDeviceDatasetTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _test_iterator_no_prefetch(self, devices, dataset, expected_values): + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=False) + iterator = per_device_dataset.make_one_shot_iterator() + + for expected_value in expected_values: + next_element = iterator.get_next() + actual = self.evaluate([ + values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + self.evaluate([ + values.select_device(d, next_element) for d in devices]) + + def _test_iterator_with_prefetch(self, devices, dataset, expected_values): + if not context.executing_eagerly(): + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=True) + iterator = per_device_dataset.make_one_shot_iterator() + + # With prefetching, we cannot guarantee which input ends up on which + # device, so we verify that the complete set seen on all devices is + # correct, and equal numbers are distributed to each device. + combined_actual = [] + combined_expected = [] + for expected_value in expected_values: + next_element = iterator.get_next() + combined_actual.extend(self.evaluate([ + values.select_device(d, next_element) for d in devices])) + combined_expected.extend(expected_value) + + self.assertEqual(set(combined_expected), set(combined_actual)) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + self.evaluate([ + values.select_device(d, next_element) for d in devices]) + + def _test_iterator(self, devices, dataset, expected_values): + self._test_iterator_no_prefetch(devices, dataset, expected_values) + self._test_iterator_with_prefetch(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes() + def testOneDevice(self): + devices = ["/device:CPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleDevices(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTupleDataset(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnevenDatasetBatches(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(devices, dataset, expected_values) + + +@test_util.with_c_api +class MirroredVariableTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + @test_util.run_in_graph_and_eager_modes(config=config) + def testProperties(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + v, _, mirrored = _make_mirrored() + + self.assertEquals(v[0].name, mirrored.name) + self.assertEquals(v[0].dtype, mirrored.dtype) + self.assertEquals(v[0].shape, mirrored.shape) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableOnAnotherDevice(self): + v = variable_scope.get_variable( + name="v", initializer=[1.], use_resource=True) + index = {"/job:foo/device:CPU:0": v} + mirrored = values.MirroredVariable(index, v) + + self.assertEquals(v.name, mirrored.name) + self.assertEquals(v.dtype, mirrored.dtype) + self.assertEquals(v.shape, mirrored.shape) + + def _assign_mirrored(self, devices, v, new): + for d, var, n in zip(devices, v, new): + with ops.device(d): + self.evaluate(var.assign(n)) + + def _save_return_saver(self, sess, var): + saver = saver_lib.Saver(var_list=[var]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + return saver.save(sess, prefix), saver + + def _save(self, sess, var): + save_path, _ = self._save_return_saver(sess, var) + return save_path + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreMirroredOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [3., 4.]) + + # Saves the current value of v[0], 3. + save_path, saver = self._save_return_saver(sess, mirrored) + + # Change the values between save and restore. + self._assign_mirrored(devices, v, [5., 6.]) + + # Restores the saved value of 3. to both variables. + saver.restore(sess, save_path) + self.assertEqual([3., 3.], self.evaluate([v[0], v[1]])) + + def _save_mirrored(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [3., 4.]) + + # Saves the current value of v[0], 3. + save_path = self._save(sess, mirrored) + + # Change the values between save and restore. + self._assign_mirrored(devices, v, [5., 6.]) + return save_path + + def _save_normal(self): + """Save variables without mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(3.)) + + # Saves the current value of var, 3. + save_path = self._save(sess, var) + + # Change the values between save and restore. + self.evaluate(var.assign(5.)) + return save_path + + def _restore_normal(self, save_path): + """Restore to variables without mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=7., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(8.)) + + # Restores the saved value of 3. to `var`. + saver = saver_lib.Saver(var_list=[var]) + saver.restore(sess, save_path) + self.assertEqual(3., self.evaluate(var)) + + def _restore_mirrored(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [7., 8.]) + + # Restores the saved value of 3. to both variables. + saver = saver_lib.Saver(var_list=[mirrored]) + saver.restore(sess, save_path) + self.assertEqual([3., 3.], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveMirroredRestoreMirrored(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_mirrored() + self._restore_mirrored(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveMirroredRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_mirrored() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreMirrored(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_mirrored(save_path) + + +_devices = ["/device:GPU:0", "/device:CPU:0"] + + +def _make_tower_local(method): + v = [] + index = {} + for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): + with ops.device(d): + v.append(variable_scope.get_variable( + name=n, initializer=init, use_resource=True)) + index[d] = v[-1] + tower_local = values.TowerLocalVariable(index, v[0], method) + return v, tower_local + + +@test_util.with_c_api +class TowerLocalVariableTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + @test_util.run_in_graph_and_eager_modes(config=config) + def testProperties(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + v, tower_local = _make_tower_local("sum") + + self.assertEquals(v[0].name, tower_local.name) + self.assertEquals(v[0].dtype, tower_local.dtype) + self.assertEquals(v[0].shape, tower_local.shape) + self.assertEquals("sum", tower_local.reduce_method) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableOnAnotherDevice(self): + v = variable_scope.get_variable( + name="v", initializer=[1.], use_resource=True) + index = {"/job:foo/device:CPU:0": v} + tower_local = values.TowerLocalVariable(index, v, "mean") + + self.assertEquals(v.name, tower_local.name) + self.assertEquals(v.dtype, tower_local.dtype) + self.assertEquals(v.shape, tower_local.shape) + self.assertEquals("mean", tower_local.reduce_method) + + def _assign_tower_local(self, devices, v, new): + for d, var, n in zip(devices, v, new): + with ops.device(d): + self.evaluate(var.assign(n)) + + def _save_return_saver(self, sess, var): + saver = saver_lib.Saver(var_list=[var]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + return saver.save(sess, prefix), saver + + def _save(self, sess, var): + save_path, _ = self._save_return_saver(sess, var) + return save_path + + def _dist_scope(self): + return mirrored_strategy.MirroredStrategy(_devices).scope() + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreTowerLocalSumOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of v[0] + v[1], 7. + save_path, saver = self._save_return_saver(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + + # Restores the saved value of 7. which gets divided equally + # between the variables. + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreTowerLocalMeanOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of (v[0] + v[1])/2, 3.5. + save_path, saver = self._save_return_saver(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + + # Restores the saved value of 3.5 to both variables. + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + def _save_tower_local_mean(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of (v[0] + v[1])/2, 3.5 + save_path = self._save(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + return save_path + + def _save_tower_local_sum(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [1.5, 2.]) + + with self._dist_scope(): + # Saves the current value of v[0] + v[1], 3.5 + save_path = self._save(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + return save_path + + def _save_normal(self): + """Save variables without mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(3.5)) + + # Saves the current value of var, 3.5. + save_path = self._save(sess, var) + + # Change the values between save and restore. + self.evaluate(var.assign(5.)) + return save_path + + def _restore_normal(self, save_path): + """Restore to variables without mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=7., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(8.)) + + # Restores the saved value of 3.5 to `var`. + saver = saver_lib.Saver(var_list=[var]) + saver.restore(sess, save_path) + self.assertEqual(3.5, self.evaluate(var)) + + def _restore_tower_local_mean(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [7., 8.]) + + with self._dist_scope(): + # Restores the saved value of 3.5 to both variables. + saver = saver_lib.Saver(var_list=[tower_local]) + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + def _restore_tower_local_sum(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [7., 8.]) + + with self._dist_scope(): + # Restores the saved value of 3.5 to both variables. + saver = saver_lib.Saver(var_list=[tower_local]) + saver.restore(sess, save_path) + self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalRestoreTowerLocalMean(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_mean() + self._restore_tower_local_mean(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalRestoreTowerLocalSum(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_sum() + self._restore_tower_local_sum(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalMeanRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_mean() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalSumRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_sum() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreTowerLocalMean(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_tower_local_mean(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreTowerLocalSum(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_tower_local_sum(save_path) + + +if __name__ == "__main__": + test.main() |