aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-08-02 16:08:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 16:12:12 -0700
commit7b28e85b57d01c491554c6cf657a92b8344c901a (patch)
tree4c4d5825b8aa49650b0f98da7a0d6c3f359e666e
parentfcf12557c9f616695a8dab9d22beb897ae3c4f02 (diff)
Gradients of tfe.defun functions with loops in them.
PiperOrigin-RevId: 207183038
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.cc20
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer_test.cc23
-rw-r--r--tensorflow/core/grappler/utils.h3
-rw-r--r--tensorflow/python/eager/function.py26
-rw-r--r--tensorflow/python/eager/function_test.py17
-rw-r--r--tensorflow/python/ops/control_flow_ops.py127
6 files changed, 151 insertions, 65 deletions
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index 2bfbb5606e..f3a07be728 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -464,7 +464,25 @@ std::vector<int> GetStackPushNodesToConvert(
const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
if (IsStackPushOp(fanout_node)) {
- nodes_to_convert.push_back(fanout_idx);
+ // Check that the stack itself is not a node we want to preserve. This can
+ // happen when the graph we have contains only the forward pass for a loop
+ // (as when the forward and backward passes are split across different
+ // functions).
+ if (graph_view.has_node(fanout_node.input(0))) {
+ const NodeDef* stack_node =
+ &graph_view.node(graph_view.index(fanout_node.input(0)));
+ while (stack_node->op() != "Stack" && stack_node->op() != "StackV2" &&
+ stack_node->input_size() > 0 &&
+ graph_view.has_node(stack_node->input(0))) {
+ stack_node = &graph_view.node(graph_view.index(stack_node->input(0)));
+ }
+ if (nodes_to_preserve.find(stack_node->name()) ==
+ nodes_to_preserve.end()) {
+ nodes_to_convert.push_back(fanout_idx);
+ }
+ } else {
+ nodes_to_convert.push_back(fanout_idx);
+ }
} else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
op_types_to_traverse.find(fanout_node.op()) !=
op_types_to_traverse.end()) {
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
index f5fe28d4ba..81f40db8f0 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
@@ -536,6 +536,29 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
+TEST_F(LoopOptimizerTest, RemovePush_NoPopButStackLives) {
+ GrapplerItem item;
+ GraphDef& graph = item.graph;
+ AddSimpleNode("c", "Const", {}, &graph);
+ // Stack with corresponding push
+ AddSimpleNode("stack1", "StackV2", {}, &graph);
+ AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
+ // Stack with corresponding push behind Enter.
+ AddSimpleNode("stack2", "StackV2", {}, &graph);
+ AddEnterNode("enter2_c", "frame_name", false, 1, {"c"}, &graph);
+ AddEnterNode("enter2_stack2", "frame_name", false, 1, {"stack2"}, &graph);
+ AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, &graph);
+ item.keep_ops.push_back("stack1");
+ item.keep_ops.push_back("stack2");
+
+ LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ VerifyGraphsEqual(item.graph, output, __FUNCTION__);
+}
+
TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
GrapplerItem item;
GraphDef& graph = item.graph;
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index b297caa8d4..a9c34b6d08 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -239,6 +239,9 @@ class SimpleGraphView {
const GraphDef* graph() const { return graph_; }
inline int num_nodes() const { return index_to_name_.size(); }
+ inline bool has_node(const string& node_name) const {
+ return name_to_index_.find(node_name) != name_to_index_.end();
+ }
inline const int index(const string& node_name) const {
const auto& it = name_to_index_.find(node_name);
DCHECK(it != name_to_index_.end());
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index de55f999d7..8e8c028f60 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -132,11 +132,23 @@ class CapturingGraph(ops.Graph):
op_def=None,
compute_shapes=True,
compute_device=True):
- # TODO(apassos) this should do some form of alias analysis as ops which
- # forward the resources such as Identity and Switch can cause serialization
- # to fail.
+ # This capturing logic interacts poorly with control flow contexts which
+ # want to replace inputs of ops far too late in the process. This can lead
+ # the context to get confused and try to create an Enter for an Enter. We
+ # can detect this here and skip the additional Enter which can confuse loop
+ # validation logic.
+ if op_type == "Enter" and inputs[0].op.type == "Enter":
+ if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
+ return inputs[0].op
+ # Calling AddValue on the control flow contexts to force creation of the
+ # backward accumulators in the original graph before we create placeholders
+ # to capture the inputs.
+ ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
for i, inp in enumerate(inputs):
- inputs[i] = self.capture(inp)
+ if ctxt is not None and hasattr(ctxt, "AddValue"):
+ inp = ctxt.AddValue(inp)
+ inp = self.capture(inp)
+ inputs[i] = inp
return super(CapturingGraph, self).create_op(
op_type, inputs, dtypes, input_types, name, attrs, op_def,
compute_device=compute_device)
@@ -457,7 +469,6 @@ class GraphModeFunction(object):
self._func_name = name
self._function_def = defined_function
self._num_outputs = len(defined_function.signature.output_arg)
- self._ops = operations
self._python_func_outputs = python_func_outputs
self._python_returns = [python_func_outputs] if isinstance(
python_func_outputs,
@@ -501,8 +512,9 @@ class GraphModeFunction(object):
forward_name = _forward_name(self._func_name)
self._forward_fdef = _EagerDefinedFunction(
- forward_name, self._graph, self._ops, self._input_placeholders,
- filtered_outputs + list(extra_inputs), self._attrs)
+ forward_name, self._graph, self._graph.get_operations(),
+ self._input_placeholders, filtered_outputs + list(extra_inputs),
+ self._attrs)
all_inputs = self._out_grad_placeholders + list(extra_placeholders)
# Excluding input ops from the body as we do not intend to execute these
# operations when the function is executed.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index afd4bbf4f3..5efdecdbc6 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -227,6 +227,23 @@ class FunctionTest(test.TestCase):
y = f(x)
self.assertAllEqual(self.evaluate(t.gradient(y, x)), 2.0)
+ @test_util.run_in_graph_and_eager_modes()
+ def testGraphLoopGradient(self):
+ if context.executing_eagerly():
+ self.skipTest('TODO(apassos): support loops in defuns in eager')
+
+ @function.defun
+ def f(x):
+ return control_flow_ops.while_loop(lambda _, i: i < 2,
+ lambda x, i: (2*x, i + 1),
+ [x, 0])[0]
+
+ with backprop.GradientTape() as t:
+ x = constant_op.constant(1.0)
+ t.watch(x)
+ y = f(x)
+ self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0)
+
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index aeac61c005..c7061b36dd 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -817,11 +817,12 @@ class GradLoopState(object):
outer_forward_ctxt = forward_ctxt.outer_context
# Add the forward loop counter.
- if outer_forward_ctxt:
- outer_forward_ctxt.Enter()
- cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
- if outer_forward_ctxt:
- outer_forward_ctxt.Exit()
+ with forward_ctxt._graph.as_default(): # pylint: disable=protected-access
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Enter()
+ cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Exit()
self._forward_context = forward_ctxt
self._forward_index = forward_index
@@ -984,60 +985,61 @@ class GradLoopState(object):
for the stack can't be found.
"""
# curr_ctxt is the context that tf.gradients was called in.
- curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
- with ops.control_dependencies(None):
- if curr_ctxt:
- curr_ctxt.Enter()
- with ops.colocate_with(value):
- # We only need to pass maximum_iterations to the stack if
- # we're inside an XLA context.
- if not util.IsInXLAContext(value.op):
- max_size = constant_op.constant(-1, dtypes.int32)
- else:
- max_size = GetMaxSizeFromNestedMaximumIterations(
- value, self.forward_context)
- acc = gen_data_flow_ops.stack_v2(
- max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
- if curr_ctxt:
- curr_ctxt.Exit()
-
- # Make acc available in the forward context.
- enter_acc = self.forward_context.AddValue(acc)
-
- # Add the stack_push op in the context of value.op.
- swap_enabled = self.forward_context.swap_memory
- value_ctxt = util.GetOutputContext(value.op)
- if value_ctxt == self.forward_context:
- # value is not nested in the forward context.
- self.forward_context.Enter()
- push = gen_data_flow_ops.stack_push_v2(
- enter_acc, value, swap_memory=swap_enabled)
- self.forward_context.Exit()
- # Protect stack push and order it before forward_index.
- self.forward_index.op._add_control_input(push.op)
- else:
- # value is in a cond context within the forward context.
- if not isinstance(value_ctxt, CondContext):
- raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
- if dead_branch:
- # The special case for creating a zero tensor for a dead
- # branch of a switch. See ControlFlowState.ZerosLike().
- value_ctxt.outer_context.Enter()
+ with self._forward_index.graph.as_default():
+ curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ with ops.control_dependencies(None):
+ if curr_ctxt:
+ curr_ctxt.Enter()
+ with ops.colocate_with(value):
+ # We only need to pass maximum_iterations to the stack if
+ # we're inside an XLA context.
+ if not util.IsInXLAContext(value.op):
+ max_size = constant_op.constant(-1, dtypes.int32)
+ else:
+ max_size = GetMaxSizeFromNestedMaximumIterations(
+ value, self.forward_context)
+ acc = gen_data_flow_ops.stack_v2(
+ max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
+ if curr_ctxt:
+ curr_ctxt.Exit()
+
+ # Make acc available in the forward context.
+ enter_acc = self.forward_context.AddValue(acc)
+
+ # Add the stack_push op in the context of value.op.
+ swap_enabled = self.forward_context.swap_memory
+ value_ctxt = util.GetOutputContext(value.op)
+ if value_ctxt == self.forward_context:
+ # value is not nested in the forward context.
+ self.forward_context.Enter()
push = gen_data_flow_ops.stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
- value_ctxt.outer_context.Exit()
- push.op._set_control_flow_context(value_ctxt)
+ self.forward_context.Exit()
+ # Protect stack push and order it before forward_index.
+ self.forward_index.op._add_control_input(push.op)
else:
- value_ctxt.Enter()
- push = gen_data_flow_ops.stack_push_v2(
- enter_acc, value, swap_memory=swap_enabled)
- value_ctxt.Exit()
- # Protect stack push and order it before forward_sync.
- self.forward_sync._add_control_input(push.op)
- # Order stack push after the successor of forward_index
- add_op = self.forward_index.op.inputs[0].op
- push.op._add_control_input(add_op)
- return acc
+ # value is in a cond context within the forward context.
+ if not isinstance(value_ctxt, CondContext):
+ raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
+ if dead_branch:
+ # The special case for creating a zero tensor for a dead
+ # branch of a switch. See ControlFlowState.ZerosLike().
+ value_ctxt.outer_context.Enter()
+ push = gen_data_flow_ops.stack_push_v2(
+ enter_acc, value, swap_memory=swap_enabled)
+ value_ctxt.outer_context.Exit()
+ push.op._set_control_flow_context(value_ctxt)
+ else:
+ value_ctxt.Enter()
+ push = gen_data_flow_ops.stack_push_v2(
+ enter_acc, value, swap_memory=swap_enabled)
+ value_ctxt.Exit()
+ # Protect stack push and order it before forward_sync.
+ self.forward_sync._add_control_input(push.op)
+ # Order stack push after the successor of forward_index
+ add_op = self.forward_index.op.inputs[0].op
+ push.op._add_control_input(add_op)
+ return acc
def AddBackpropAccumulatedValue(self, history_value, value,
dead_branch=False):
@@ -2215,6 +2217,7 @@ class WhileContext(ControlFlowContext):
self._loop_exits = []
# The list of enter tensors for loop variables.
self._loop_enters = []
+ self._graph = ops.get_default_graph()
def _init_from_proto(self, context_def, import_scope=None):
"""Creates a new `WhileContext` from protocol buffer.
@@ -2268,6 +2271,7 @@ class WhileContext(ControlFlowContext):
op._set_attr("frame_name",
attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
# pylint: enable=protected-access
+ self._graph = ops.get_default_graph()
@property
def maximum_iterations(self):
@@ -2592,7 +2596,14 @@ class WhileContext(ControlFlowContext):
Returns:
The loop index.
"""
- one = constant_op.constant(1, name="b_count")
+ in_separate_functions = count.graph is not ops.get_default_graph()
+ if in_separate_functions:
+ # Brings the count into this graph
+ count = array_ops.identity(count)
+ else:
+ # TODO(apassos) XLA expects this constant to be created outside the loop,
+ # so doing that for now.
+ one = constant_op.constant(1, name="b_count")
self.Enter()
self.AddName(count.name)
@@ -2607,6 +2618,8 @@ class WhileContext(ControlFlowContext):
merge_count = merge([enter_count, enter_count])[0]
self._pivot_for_pred = merge_count
+ if in_separate_functions:
+ one = constant_op.constant(1, name="b_count")
pred = math_ops.greater_equal(merge_count, one)
self._pivot = loop_cond(pred, name="b_count")
switch_count = switch(merge_count, self._pivot)