author     Igor Ganichev <iga@google.com>                    2018-10-09 15:07:47 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-09 15:12:12 -0700
commit     5f69248a692f7b47ea11930621f4f19d0397fe8c
tree       e6ae69c17d798afc96ba83644bf2ce6656181856
parent     c1093a3757224257fed0f7a1959d0fc99d5c757f
Make defun work under distributed strategies.
The core of the change is to have the gradient tape capture distributed variables instead of plain ResourceVariables. In other words, we move the distribution awareness from defun down to the tape and rely on the distributed variable magic to provide us with the right variable at runtime. In tower context, we always watch the container (e.g. the MirroredVariable). In cross-tower context, we always watch all the components.

PiperOrigin-RevId: 216430530
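To make that rule concrete, the decision of what the tape watches can be written as a small helper. This is only an illustrative sketch mirroring the tape.py hunk below, not new API; the helper name variables_to_watch is hypothetical, and the distribution_strategy_context calls are the same ones used in the change itself.

    from tensorflow.python.training import distribution_strategy_context

    def variables_to_watch(variable):
      # Sketch of the watching rule: in tower (per-replica) context, watch the
      # distributed container (e.g. the MirroredVariable); in cross-tower
      # context, watch every component variable. For a plain ResourceVariable
      # both branches reduce to [variable].
      strategy = distribution_strategy_context.get_distribution_strategy()
      if distribution_strategy_context.get_tower_context():
        return [strategy.value_container(variable)]
      return list(strategy.unwrap(variable))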
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py                23
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py  58
-rw-r--r--  tensorflow/python/eager/backprop_test.py                                 24
-rw-r--r--  tensorflow/python/eager/function.py                                      53
-rw-r--r--  tensorflow/python/eager/tape.py                                          31
5 files changed, 128 insertions, 61 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index a32424b316..0f82508428 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -293,7 +293,8 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in index.values():
- l.remove(v)
+ if v in l:
+ l.remove(v)
g.add_to_collections(collections, result)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
@@ -461,16 +462,20 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
- if context.executing_eagerly():
- kwargs["initial_value"] = array_ops.identity(
- index[devices[0]].value())
- else:
- def initial_value_fn(device=d):
+ def initial_value_fn(device=d):
+ if context.executing_eagerly():
+ init_value = index[devices[0]].value()
+ return array_ops.identity(init_value)
+ else:
with ops.device(device):
- return array_ops.identity(index[devices[0]].initial_value)
- kwargs["initial_value"] = initial_value_fn
+ init_value = index[devices[0]].initial_value
+ return array_ops.identity(init_value)
+ kwargs["initial_value"] = initial_value_fn
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- v = next_creator(*args, **kwargs)
+ # Don't record operations (e.g. other variable reads) during
+ # variable creation.
+ with tape.stop_recording():
+ v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
return index
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index eeac528329..ed36639ce8 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import sys
+import numpy as np
+
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
@@ -34,7 +36,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
+from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
@@ -43,6 +48,8 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import server_lib
@@ -1245,6 +1252,22 @@ class MockModel(object):
return x
+class MiniModel(keras_training.Model):
+ """Minimal model for mnist.
+
+ Useful for testing and debugging on slow TPU simulators.
+ """
+
+ def __init__(self):
+ super(MiniModel, self).__init__(name="")
+ self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones",
+ bias_initializer="ones")
+
+ def call(self, inputs, training=True):
+ inputs = array_ops.ones([1, 10])
+ return self.fc(inputs)
+
+
class MirroredStrategyDefunTest(test.TestCase):
def _skip_eager_if_gpus_less_than(self, num_gpus):
@@ -1365,6 +1388,41 @@ class MirroredStrategyDefunTest(test.TestCase):
"GPU:0": 3.0 * 1.25})
self._call_and_check(fn1, [factors], expected_result, [fn1])
+ @test_util.run_in_graph_and_eager_modes()
+ def testTrain(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ cpu_dev = device_util.canonicalize("CPU:0")
+ gpu_dev = device_util.canonicalize("GPU:0")
+ devices = [cpu_dev, gpu_dev]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+
+ with dist.scope():
+ mock_model = MiniModel()
+ mock_model.call = function.defun(mock_model.call)
+
+ def loss_fn(ctx):
+ del ctx
+ return mock_model(array_ops.ones([1, 10]))
+
+ gradients_fn = backprop.implicit_grad(loss_fn)
+ gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
+ grads_and_vars = dist.call_for_each_tower(
+ gradients_fn, None, run_concurrently=False)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.25)
+ update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(update_ops)
+
+ updated_var_values = self.evaluate(mock_model.variables)
+ # All variables start at 1.0 and get two updates of 0.25.
+ self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
+ self.assertAllEqual([0.5], updated_var_values[1])
+
+
class MultiWorkerMirroredStrategyTest(
multi_worker_test_base.MultiWorkerTestBase,
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 7e5c9f3cb6..b1b20fafd2 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -258,6 +258,30 @@ class BackpropTest(test.TestCase):
loss += v * v
self.assertAllEqual(t.gradient(loss, v), 2.0)
+ def testAutomaticWatchedVariables(self):
+ with backprop.GradientTape() as t:
+ self.assertEqual(0, len(t.watched_variables()))
+ v = resource_variable_ops.ResourceVariable(1.0)
+ loss = v * v
+ self.assertAllEqual([v], t.watched_variables())
+
+ t.reset()
+ self.assertEqual(0, len(t.watched_variables()))
+ loss += v * v
+ self.assertAllEqual([v], t.watched_variables())
+
+ def testExplicitWatchedVariables(self):
+ with backprop.GradientTape() as t:
+ self.assertEqual(0, len(t.watched_variables()))
+ v = resource_variable_ops.ResourceVariable(1.0)
+ t.watch(v)
+ self.assertAllEqual([v], t.watched_variables())
+
+ t.reset()
+ self.assertEqual(0, len(t.watched_variables()))
+ t.watch(v)
+ self.assertAllEqual([v], t.watched_variables())
+
@test_util.assert_no_new_tensors
def testGradientNone(self):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ff138cad1e..f1a63adce1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -51,7 +51,6 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -202,6 +201,7 @@ class FuncGraph(ops.Graph):
# from the default graph even in eager mode. Maybe it should be part of the
# eager context?
self._distribution_strategy_stack = graph._distribution_strategy_stack
+ self._variable_creator_stack = graph._variable_creator_stack
# Inherit the graph key, since this is used for matching variables in
# optimizers.
self._graph_key = graph._graph_key
@@ -563,17 +563,6 @@ class Function(object):
self._func_graph.inputs, self._func_graph.outputs, self._attrs)
self._backward_graph_function = None
- # Map holding distributed variables, keyed by resource handle tensors.
- self._distributed_variables = {}
- strategy = distribution_strategy_context.get_distribution_strategy()
- for variable in self._func_graph.variables:
- # If variable is not distributed, unwrap returns [variable].
- component_variables = strategy.unwrap(variable)
- # Only update the dictionary when the variable is actually distributed.
- if (len(component_variables) > 1 or component_variables[0] != variable):
- for component_variable in component_variables:
- self._distributed_variables[component_variable.handle] = variable
-
def __call__(self, *args):
"""Executes the wrapped function.
@@ -602,7 +591,6 @@ class Function(object):
if v.trainable:
tape.variable_accessed(v)
- captures = self._resolve_captured_inputs()
tensor_inputs = []
for i, arg in enumerate(nest.flatten(args)):
if isinstance(arg, resource_variable_ops.ResourceVariable):
@@ -615,9 +603,10 @@ class Function(object):
raise ValueError("All inputs to `Function`s must be Tensors; "
"on invocation of %s, the %d-th input (%s) was not a "
"Tensor." % (self._func_graph.name, i, str(arg)))
- args = tensor_inputs + captures
+ args = tensor_inputs + self._captured_inputs
- if tape.should_record(tensor_inputs) or tape.should_record(captures):
+ if (tape.should_record(tensor_inputs) or
+ tape.should_record(self._captured_inputs)):
return self._backprop_call(args)
# Only need to override the gradient in graph mode and when we have outputs.
@@ -804,32 +793,6 @@ class Function(object):
args, backward_function)
return self._build_call_outputs(real_outputs)
- def _resolve_captured_inputs(self):
- """Resolve captured distributed variables to their current values.
-
- Some inputs can be distributed variables. Such variables yield a different
- component (i.e. actual tf.Variable) variables depending on the context of
- execution.
-
- Returns:
- a list of resolved captured input tensors.
- """
- if self._distributed_variables:
- # Loop over each captured input and check if it corresponds to something
- # distributed. If so, get its _distributed_container and fetch the
- # component appropriate for the current execution context.
- resolved_captured_inputs = self._captured_inputs[:]
- for i, captured_input in enumerate(self._captured_inputs):
- distributed_var = self._distributed_variables.get(captured_input, None)
- if distributed_var is not None:
- # distributed variables override __getattr__ and substitute the
- # right component variable. In here, `distributed_var.handle`
- # actually does the equivalent of
- # distributed_var.get_current_component_var().handle.
- resolved_captured_inputs[i] = distributed_var.handle
- return resolved_captured_inputs
- return self._captured_inputs
-
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -1010,14 +973,6 @@ def func_graph_from_py_func(name,
for x in _flatten(func_graph.structured_outputs)
if x is not None)
- # Some captured variables might be components of DistributedValues.
- # Instead of storing non-distributed component variables, we
- # store their distributed containers so we can retrieve the correct
- # component variables at call-time.
- strategy = distribution_strategy_context.get_distribution_strategy()
- for i, variable in enumerate(variables):
- # If variable is not distributed value_container returns itself.
- variables[i] = strategy.value_container(variable)
func_graph.variables = variables
# Register any other functions defined in the graph.
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 399d90223c..ade945f874 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -21,6 +21,15 @@ from __future__ import print_function
import contextlib
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+# There is a circular dependency between this, ops.py, and
+# distribution_strategy_context.
+# TODO(b/117329403): Remove this circular dependency.
+distribution_strategy_context = LazyLoader(
+ "distribute_lib", globals(),
+ "tensorflow.python.training."
+ "distribution_strategy_context")
class Tape(object):
@@ -52,12 +61,28 @@ def watch(tape, tensor):
def watch_variable(tape, variable):
"""Marks this variable to be watched by the given tape."""
- pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
def variable_accessed(variable):
- """Notifies all tapes in the stack that a variable has been accessed."""
- pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
+ """Notifies all tapes in the stack that a variable has been accessed.
+
+ Args:
+ variable: variable to be watched.
+ """
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
def pop_tape(tape):