aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py23
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py58
-rw-r--r--tensorflow/python/eager/backprop_test.py24
-rw-r--r--tensorflow/python/eager/function.py53
-rw-r--r--tensorflow/python/eager/tape.py31
5 files changed, 128 insertions, 61 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index a32424b316..0f82508428 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -293,7 +293,8 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in index.values():
- l.remove(v)
+ if v in l:
+ l.remove(v)
g.add_to_collections(collections, result)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
@@ -461,16 +462,20 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
- if context.executing_eagerly():
- kwargs["initial_value"] = array_ops.identity(
- index[devices[0]].value())
- else:
- def initial_value_fn(device=d):
+ def initial_value_fn(device=d):
+ if context.executing_eagerly():
+ init_value = index[devices[0]].value()
+ return array_ops.identity(init_value)
+ else:
with ops.device(device):
- return array_ops.identity(index[devices[0]].initial_value)
- kwargs["initial_value"] = initial_value_fn
+ init_value = index[devices[0]].initial_value
+ return array_ops.identity(init_value)
+ kwargs["initial_value"] = initial_value_fn
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- v = next_creator(*args, **kwargs)
+ # Don't record operations (e.g. other variable reads) during
+ # variable creation.
+ with tape.stop_recording():
+ v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
return index
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index eeac528329..ed36639ce8 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import sys
+import numpy as np
+
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
@@ -34,7 +36,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
+from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
@@ -43,6 +48,8 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import server_lib
@@ -1245,6 +1252,22 @@ class MockModel(object):
return x
+class MiniModel(keras_training.Model):
+ """Minimal model for mnist.
+
+ Useful for testing and debugging on slow TPU simulators.
+ """
+
+ def __init__(self):
+ super(MiniModel, self).__init__(name="")
+ self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones",
+ bias_initializer="ones")
+
+ def call(self, inputs, training=True):
+ inputs = array_ops.ones([1, 10])
+ return self.fc(inputs)
+
+
class MirroredStrategyDefunTest(test.TestCase):
def _skip_eager_if_gpus_less_than(self, num_gpus):
@@ -1365,6 +1388,41 @@ class MirroredStrategyDefunTest(test.TestCase):
"GPU:0": 3.0 * 1.25})
self._call_and_check(fn1, [factors], expected_result, [fn1])
+ @test_util.run_in_graph_and_eager_modes()
+ def testTrain(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ cpu_dev = device_util.canonicalize("CPU:0")
+ gpu_dev = device_util.canonicalize("GPU:0")
+ devices = [cpu_dev, gpu_dev]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+
+ with dist.scope():
+ mock_model = MiniModel()
+ mock_model.call = function.defun(mock_model.call)
+
+ def loss_fn(ctx):
+ del ctx
+ return mock_model(array_ops.ones([1, 10]))
+
+ gradients_fn = backprop.implicit_grad(loss_fn)
+ gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
+ grads_and_vars = dist.call_for_each_tower(
+ gradients_fn, None, run_concurrently=False)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.25)
+ update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(update_ops)
+
+ updated_var_values = self.evaluate(mock_model.variables)
+ # All variables start at 1.0 and get two updates of 0.25.
+ self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
+ self.assertAllEqual([0.5], updated_var_values[1])
+
+
class MultiWorkerMirroredStrategyTest(
multi_worker_test_base.MultiWorkerTestBase,
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 7e5c9f3cb6..b1b20fafd2 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -258,6 +258,30 @@ class BackpropTest(test.TestCase):
loss += v * v
self.assertAllEqual(t.gradient(loss, v), 2.0)
+ def testAutomaticWatchedVariables(self):
+ with backprop.GradientTape() as t:
+ self.assertEqual(0, len(t.watched_variables()))
+ v = resource_variable_ops.ResourceVariable(1.0)
+ loss = v * v
+ self.assertAllEqual([v], t.watched_variables())
+
+ t.reset()
+ self.assertEqual(0, len(t.watched_variables()))
+ loss += v * v
+ self.assertAllEqual([v], t.watched_variables())
+
+ def testExplicitWatchedVariables(self):
+ with backprop.GradientTape() as t:
+ self.assertEqual(0, len(t.watched_variables()))
+ v = resource_variable_ops.ResourceVariable(1.0)
+ t.watch(v)
+ self.assertAllEqual([v], t.watched_variables())
+
+ t.reset()
+ self.assertEqual(0, len(t.watched_variables()))
+ t.watch(v)
+ self.assertAllEqual([v], t.watched_variables())
+
@test_util.assert_no_new_tensors
def testGradientNone(self):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ff138cad1e..f1a63adce1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -51,7 +51,6 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -202,6 +201,7 @@ class FuncGraph(ops.Graph):
# from the default graph even in eager mode. Maybe it should be part of the
# eager context?
self._distribution_strategy_stack = graph._distribution_strategy_stack
+ self._variable_creator_stack = graph._variable_creator_stack
# Inherit the graph key, since this is used for matching variables in
# optimizers.
self._graph_key = graph._graph_key
@@ -563,17 +563,6 @@ class Function(object):
self._func_graph.inputs, self._func_graph.outputs, self._attrs)
self._backward_graph_function = None
- # Map holding distributed variables, keyed by resource handle tensors.
- self._distributed_variables = {}
- strategy = distribution_strategy_context.get_distribution_strategy()
- for variable in self._func_graph.variables:
- # If variable is not distributed, unwrap returns [variable].
- component_variables = strategy.unwrap(variable)
- # Only update the dictionary when the variable is actually distributed.
- if (len(component_variables) > 1 or component_variables[0] != variable):
- for component_variable in component_variables:
- self._distributed_variables[component_variable.handle] = variable
-
def __call__(self, *args):
"""Executes the wrapped function.
@@ -602,7 +591,6 @@ class Function(object):
if v.trainable:
tape.variable_accessed(v)
- captures = self._resolve_captured_inputs()
tensor_inputs = []
for i, arg in enumerate(nest.flatten(args)):
if isinstance(arg, resource_variable_ops.ResourceVariable):
@@ -615,9 +603,10 @@ class Function(object):
raise ValueError("All inputs to `Function`s must be Tensors; "
"on invocation of %s, the %d-th input (%s) was not a "
"Tensor." % (self._func_graph.name, i, str(arg)))
- args = tensor_inputs + captures
+ args = tensor_inputs + self._captured_inputs
- if tape.should_record(tensor_inputs) or tape.should_record(captures):
+ if (tape.should_record(tensor_inputs) or
+ tape.should_record(self._captured_inputs)):
return self._backprop_call(args)
# Only need to override the gradient in graph mode and when we have outputs.
@@ -804,32 +793,6 @@ class Function(object):
args, backward_function)
return self._build_call_outputs(real_outputs)
- def _resolve_captured_inputs(self):
- """Resolve captured distributed variables to their current values.
-
- Some inputs can be distributed variables. Such variables yield a different
- component (i.e. actual tf.Variable) variables depending on the context of
- execution.
-
- Returns:
- a list of resolved captured input tensors.
- """
- if self._distributed_variables:
- # Loop over each captured input and check if it corresponds to something
- # distributed. If so, get its _distributed_container and fetch the
- # component appropriate for the current execution context.
- resolved_captured_inputs = self._captured_inputs[:]
- for i, captured_input in enumerate(self._captured_inputs):
- distributed_var = self._distributed_variables.get(captured_input, None)
- if distributed_var is not None:
- # distributed variables override __getattr__ and substitute the
- # right component variable. In here, `distributed_var.handle`
- # actually does the equivalent of
- # distributed_var.get_current_component_var().handle.
- resolved_captured_inputs[i] = distributed_var.handle
- return resolved_captured_inputs
- return self._captured_inputs
-
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -1010,14 +973,6 @@ def func_graph_from_py_func(name,
for x in _flatten(func_graph.structured_outputs)
if x is not None)
- # Some captured variables might be components of DistributedValues.
- # Instead of storing non-distributed component variables, we
- # store their distributed containers so we can retrieve the correct
- # component variables at call-time.
- strategy = distribution_strategy_context.get_distribution_strategy()
- for i, variable in enumerate(variables):
- # If variable is not distributed value_container returns itself.
- variables[i] = strategy.value_container(variable)
func_graph.variables = variables
# Register any other functions defined in the graph.
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 399d90223c..ade945f874 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -21,6 +21,15 @@ from __future__ import print_function
import contextlib
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+# There is a circular dependency between this, ops.py, and
+# distribution_strategy_context.
+# TODO(b/117329403): Remove this circular dependency.
+distribution_strategy_context = LazyLoader(
+ "distribute_lib", globals(),
+ "tensorflow.python.training."
+ "distribution_strategy_context")
class Tape(object):
@@ -52,12 +61,28 @@ def watch(tape, tensor):
def watch_variable(tape, variable):
"""Marks this variable to be watched by the given tape."""
- pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
def variable_accessed(variable):
- """Notifies all tapes in the stack that a variable has been accessed."""
- pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
+ """Notifies all tapes in the stack that a variable has been accessed.
+
+ Args:
+ variable: variable to be watched.
+ """
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ if distribution_strategy_context.get_tower_context():
+ variables = [strategy.value_container(variable)]
+ else:
+ variables = strategy.unwrap(variable)
+ for var in variables:
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
def pop_tape(tape):