| author    | 2018-08-02 16:08:42 -0700 |
|-----------|---------------------------|
| committer | 2018-08-02 16:12:12 -0700 |
| commit    | 7b28e85b57d01c491554c6cf657a92b8344c901a (patch) |
| tree      | 4c4d5825b8aa49650b0f98da7a0d6c3f359e666e |
| parent    | fcf12557c9f616695a8dab9d22beb897ae3c4f02 (diff) |
Gradients of tfe.defun functions with loops in them.
PiperOrigin-RevId: 207183038
| -rw-r--r-- | tensorflow/core/grappler/optimizers/loop_optimizer.cc      |  20 |
|------------|-------------------------------------------------------------|-----|
| -rw-r--r-- | tensorflow/core/grappler/optimizers/loop_optimizer_test.cc |  23 |
| -rw-r--r-- | tensorflow/core/grappler/utils.h                            |   3 |
| -rw-r--r-- | tensorflow/python/eager/function.py                         |  26 |
| -rw-r--r-- | tensorflow/python/eager/function_test.py                    |  17 |
| -rw-r--r-- | tensorflow/python/ops/control_flow_ops.py                   | 127 |
6 files changed, 151 insertions, 65 deletions
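
In user-facing terms, this change makes gradients work through a `defun`-wrapped function whose body contains a `tf.while_loop`, when the surrounding code runs in graph mode. A minimal sketch of the newly supported pattern, adapted from the `testGraphLoopGradient` test added below (assuming the TF 1.10-era API, where `tfe.defun` is `tf.contrib.eager.defun`):

```python
import tensorflow as tf

tfe = tf.contrib.eager  # tfe.defun was still experimental at this revision

@tfe.defun
def f(x):
  # Two iterations of doubling, so f(x) == 4 * x.
  return tf.while_loop(lambda _, i: i < 2,
                       lambda x, i: (2 * x, i + 1),
                       [x, 0])[0]

with tf.Graph().as_default():
  x = tf.constant(1.0)
  with tf.GradientTape() as t:
    t.watch(x)
    y = f(x)
  grad = t.gradient(y, x)  # d(4x)/dx == 4
  with tf.Session() as sess:
    print(sess.run([y, grad]))  # expect [4.0, 4.0]
```

Eager execution of this case is still unsupported; the new test explicitly skips it.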
```diff
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index 2bfbb5606e..f3a07be728 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -464,7 +464,25 @@ std::vector<int> GetStackPushNodesToConvert(
     const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
     VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
     if (IsStackPushOp(fanout_node)) {
-      nodes_to_convert.push_back(fanout_idx);
+      // Check that the stack itself is not a node we want to preserve. This can
+      // happen when the graph we have contains only the forward pass for a loop
+      // (as when the forward and backward passes are split across different
+      // functions).
+      if (graph_view.has_node(fanout_node.input(0))) {
+        const NodeDef* stack_node =
+            &graph_view.node(graph_view.index(fanout_node.input(0)));
+        while (stack_node->op() != "Stack" && stack_node->op() != "StackV2" &&
+               stack_node->input_size() > 0 &&
+               graph_view.has_node(stack_node->input(0))) {
+          stack_node = &graph_view.node(graph_view.index(stack_node->input(0)));
+        }
+        if (nodes_to_preserve.find(stack_node->name()) ==
+            nodes_to_preserve.end()) {
+          nodes_to_convert.push_back(fanout_idx);
+        }
+      } else {
+        nodes_to_convert.push_back(fanout_idx);
+      }
     } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
                op_types_to_traverse.find(fanout_node.op()) !=
                    op_types_to_traverse.end()) {
```
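
The new branch walks backward from the push's first input along `input(0)` edges (through `Enter` nodes, for example) until it reaches the `Stack`/`StackV2` op, and only converts the push when that stack is not in `nodes_to_preserve`. A rough Python rendering of the same traversal over a `GraphDef`, for illustration only (`find_stack_node` is a hypothetical helper, not part of this change):

```python
from tensorflow.core.framework import graph_pb2

def find_stack_node(graph_def, tensor_name):
  """Follows input(0) edges until a Stack/StackV2 node is reached, if any."""
  by_name = {n.name: n for n in graph_def.node}
  node = by_name.get(tensor_name.split(":")[0])
  while node is not None and node.op not in ("Stack", "StackV2"):
    if not node.input:
      return None
    node = by_name.get(node.input[0].split(":")[0])
  return node  # None if the stack does not live in this graph

# Mirrors the "push behind Enter" case from the new unit test below.
gdef = graph_pb2.GraphDef()
gdef.node.add(name="stack2", op="StackV2")
gdef.node.add(name="enter2_stack2", op="Enter", input=["stack2"])
push = gdef.node.add(name="push2", op="StackPushV2",
                     input=["enter2_stack2", "enter2_c"])
assert find_stack_node(gdef, push.input[0]).name == "stack2"
```

If the stack found this way must be preserved (for instance it is a fetch because only the forward pass lives in this graph), the push is kept.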
+ AddSimpleNode("stack2", "StackV2", {}, &graph); + AddEnterNode("enter2_c", "frame_name", false, 1, {"c"}, &graph); + AddEnterNode("enter2_stack2", "frame_name", false, 1, {"stack2"}, &graph); + AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, &graph); + item.keep_ops.push_back("stack1"); + item.keep_ops.push_back("stack2"); + + LoopOptimizer optimizer; + EnableOnlyStackPushRemoval(&optimizer); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + VerifyGraphsEqual(item.graph, output, __FUNCTION__); +} + TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { GrapplerItem item; GraphDef& graph = item.graph; diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index b297caa8d4..a9c34b6d08 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -239,6 +239,9 @@ class SimpleGraphView { const GraphDef* graph() const { return graph_; } inline int num_nodes() const { return index_to_name_.size(); } + inline bool has_node(const string& node_name) const { + return name_to_index_.find(node_name) != name_to_index_.end(); + } inline const int index(const string& node_name) const { const auto& it = name_to_index_.find(node_name); DCHECK(it != name_to_index_.end()); diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index de55f999d7..8e8c028f60 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -132,11 +132,23 @@ class CapturingGraph(ops.Graph): op_def=None, compute_shapes=True, compute_device=True): - # TODO(apassos) this should do some form of alias analysis as ops which - # forward the resources such as Identity and Switch can cause serialization - # to fail. + # This capturing logic interacts poorly with control flow contexts which + # want to replace inputs of ops far too late in the process. This can lead + # the context to get confused and try to create an Enter for an Enter. We + # can detect this here and skip the additional Enter which can confuse loop + # validation logic. + if op_type == "Enter" and inputs[0].op.type == "Enter": + if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s: + return inputs[0].op + # Calling AddValue on the control flow contexts to force creation of the + # backward accumulators in the original graph before we create placeholders + # to capture the inputs. 
```diff
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index de55f999d7..8e8c028f60 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -132,11 +132,23 @@ class CapturingGraph(ops.Graph):
                 op_def=None,
                 compute_shapes=True,
                 compute_device=True):
-    # TODO(apassos) this should do some form of alias analysis as ops which
-    # forward the resources such as Identity and Switch can cause serialization
-    # to fail.
+    # This capturing logic interacts poorly with control flow contexts which
+    # want to replace inputs of ops far too late in the process. This can lead
+    # the context to get confused and try to create an Enter for an Enter. We
+    # can detect this here and skip the additional Enter which can confuse loop
+    # validation logic.
+    if op_type == "Enter" and inputs[0].op.type == "Enter":
+      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
+        return inputs[0].op
+    # Calling AddValue on the control flow contexts to force creation of the
+    # backward accumulators in the original graph before we create placeholders
+    # to capture the inputs.
+    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
     for i, inp in enumerate(inputs):
-      inputs[i] = self.capture(inp)
+      if ctxt is not None and hasattr(ctxt, "AddValue"):
+        inp = ctxt.AddValue(inp)
+      inp = self.capture(inp)
+      inputs[i] = inp
     return super(CapturingGraph, self).create_op(
         op_type, inputs, dtypes, input_types, name, attrs, op_def,
         compute_device=compute_device)
@@ -457,7 +469,6 @@ class GraphModeFunction(object):
     self._func_name = name
     self._function_def = defined_function
     self._num_outputs = len(defined_function.signature.output_arg)
-    self._ops = operations
     self._python_func_outputs = python_func_outputs
     self._python_returns = [python_func_outputs] if isinstance(
         python_func_outputs,
@@ -501,8 +512,9 @@ class GraphModeFunction(object):
 
     forward_name = _forward_name(self._func_name)
     self._forward_fdef = _EagerDefinedFunction(
-        forward_name, self._graph, self._ops, self._input_placeholders,
-        filtered_outputs + list(extra_inputs), self._attrs)
+        forward_name, self._graph, self._graph.get_operations(),
+        self._input_placeholders, filtered_outputs + list(extra_inputs),
+        self._attrs)
     all_inputs = self._out_grad_placeholders + list(extra_placeholders)
     # Excluding input ops from the body as we do not intend to execute these
     # operations when the function is executed.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index afd4bbf4f3..5efdecdbc6 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -227,6 +227,23 @@ class FunctionTest(test.TestCase):
       y = f(x)
     self.assertAllEqual(self.evaluate(t.gradient(y, x)), 2.0)
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testGraphLoopGradient(self):
+    if context.executing_eagerly():
+      self.skipTest('TODO(apassos): support loops in defuns in eager')
+
+    @function.defun
+    def f(x):
+      return control_flow_ops.while_loop(lambda _, i: i < 2,
+                                         lambda x, i: (2*x, i + 1),
+                                         [x, 0])[0]
+
+    with backprop.GradientTape() as t:
+      x = constant_op.constant(1.0)
+      t.watch(x)
+      y = f(x)
+    self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0)
+
   def testDefunCapturedInt32(self):
     x = constant_op.constant(1, dtype=dtypes.int32)
```
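
A recurring theme in the control_flow_ops.py changes below: when the forward loop lives in a `defun` graph and the gradient is built elsewhere, gradient-construction code can no longer assume the forward loop's ops belong to the current default graph, so it switches to the owning graph explicitly. The pattern in isolation, as a sketch (`create_in_owning_graph` is a hypothetical helper, not part of this change):

```python
import tensorflow as tf

def create_in_owning_graph(t):
  """Creates a new op in the graph that owns `t`, whatever graph is current."""
  with t.graph.as_default():
    return tf.identity(t)

g1 = tf.Graph()
with g1.as_default():
  a = tf.constant(1.0)

with tf.Graph().as_default():  # some other graph is the default
  b = create_in_owning_graph(a)
  assert b.graph is g1
```

This is what `with forward_ctxt._graph.as_default():` and `with self._forward_index.graph.as_default():` accomplish in the diff, with `WhileContext` now remembering its graph in `self._graph`.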
```diff
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index aeac61c005..c7061b36dd 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -817,11 +817,12 @@ class GradLoopState(object):
       outer_forward_ctxt = forward_ctxt.outer_context
 
     # Add the forward loop counter.
-    if outer_forward_ctxt:
-      outer_forward_ctxt.Enter()
-    cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
-    if outer_forward_ctxt:
-      outer_forward_ctxt.Exit()
+    with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
+      if outer_forward_ctxt:
+        outer_forward_ctxt.Enter()
+      cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
+      if outer_forward_ctxt:
+        outer_forward_ctxt.Exit()
     self._forward_context = forward_ctxt
     self._forward_index = forward_index
 
@@ -984,60 +985,61 @@ class GradLoopState(object):
         for the stack can't be found.
     """
     # curr_ctxt is the context that tf.gradients was called in.
-    curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
-    with ops.control_dependencies(None):
-      if curr_ctxt:
-        curr_ctxt.Enter()
-      with ops.colocate_with(value):
-        # We only need to pass maximum_iterations to the stack if
-        # we're inside an XLA context.
-        if not util.IsInXLAContext(value.op):
-          max_size = constant_op.constant(-1, dtypes.int32)
-        else:
-          max_size = GetMaxSizeFromNestedMaximumIterations(
-              value, self.forward_context)
-        acc = gen_data_flow_ops.stack_v2(
-            max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
-      if curr_ctxt:
-        curr_ctxt.Exit()
-
-      # Make acc available in the forward context.
-      enter_acc = self.forward_context.AddValue(acc)
-
-      # Add the stack_push op in the context of value.op.
-      swap_enabled = self.forward_context.swap_memory
-      value_ctxt = util.GetOutputContext(value.op)
-      if value_ctxt == self.forward_context:
-        # value is not nested in the forward context.
-        self.forward_context.Enter()
-        push = gen_data_flow_ops.stack_push_v2(
-            enter_acc, value, swap_memory=swap_enabled)
-        self.forward_context.Exit()
-        # Protect stack push and order it before forward_index.
-        self.forward_index.op._add_control_input(push.op)
-      else:
-        # value is in a cond context within the forward context.
-        if not isinstance(value_ctxt, CondContext):
-          raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
-        if dead_branch:
-          # The special case for creating a zero tensor for a dead
-          # branch of a switch. See ControlFlowState.ZerosLike().
-          value_ctxt.outer_context.Enter()
+    with self._forward_index.graph.as_default():
+      curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
+      with ops.control_dependencies(None):
+        if curr_ctxt:
+          curr_ctxt.Enter()
+        with ops.colocate_with(value):
+          # We only need to pass maximum_iterations to the stack if
+          # we're inside an XLA context.
+          if not util.IsInXLAContext(value.op):
+            max_size = constant_op.constant(-1, dtypes.int32)
+          else:
+            max_size = GetMaxSizeFromNestedMaximumIterations(
+                value, self.forward_context)
+          acc = gen_data_flow_ops.stack_v2(
+              max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
+        if curr_ctxt:
+          curr_ctxt.Exit()
+
+        # Make acc available in the forward context.
+        enter_acc = self.forward_context.AddValue(acc)
+
+        # Add the stack_push op in the context of value.op.
+        swap_enabled = self.forward_context.swap_memory
+        value_ctxt = util.GetOutputContext(value.op)
+        if value_ctxt == self.forward_context:
+          # value is not nested in the forward context.
+          self.forward_context.Enter()
          push = gen_data_flow_ops.stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
-          value_ctxt.outer_context.Exit()
-          push.op._set_control_flow_context(value_ctxt)
+          self.forward_context.Exit()
+          # Protect stack push and order it before forward_index.
+          self.forward_index.op._add_control_input(push.op)
         else:
-          value_ctxt.Enter()
-          push = gen_data_flow_ops.stack_push_v2(
-              enter_acc, value, swap_memory=swap_enabled)
-          value_ctxt.Exit()
-          # Protect stack push and order it before forward_sync.
-          self.forward_sync._add_control_input(push.op)
-      # Order stack push after the successor of forward_index
-      add_op = self.forward_index.op.inputs[0].op
-      push.op._add_control_input(add_op)
-      return acc
+          # value is in a cond context within the forward context.
+          if not isinstance(value_ctxt, CondContext):
+            raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
+          if dead_branch:
+            # The special case for creating a zero tensor for a dead
+            # branch of a switch. See ControlFlowState.ZerosLike().
+            value_ctxt.outer_context.Enter()
+            push = gen_data_flow_ops.stack_push_v2(
+                enter_acc, value, swap_memory=swap_enabled)
+            value_ctxt.outer_context.Exit()
+            push.op._set_control_flow_context(value_ctxt)
+          else:
+            value_ctxt.Enter()
+            push = gen_data_flow_ops.stack_push_v2(
+                enter_acc, value, swap_memory=swap_enabled)
+            value_ctxt.Exit()
+            # Protect stack push and order it before forward_sync.
+            self.forward_sync._add_control_input(push.op)
+        # Order stack push after the successor of forward_index
+        add_op = self.forward_index.op.inputs[0].op
+        push.op._add_control_input(add_op)
+        return acc
 
   def AddBackpropAccumulatedValue(self, history_value, value,
                                   dead_branch=False):
@@ -2215,6 +2217,7 @@ class WhileContext(ControlFlowContext):
     self._loop_exits = []
     # The list of enter tensors for loop variables.
     self._loop_enters = []
+    self._graph = ops.get_default_graph()
 
   def _init_from_proto(self, context_def, import_scope=None):
     """Creates a new `WhileContext` from protocol buffer.
@@ -2268,6 +2271,7 @@ class WhileContext(ControlFlowContext):
       op._set_attr("frame_name",
                    attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
     # pylint: enable=protected-access
+    self._graph = ops.get_default_graph()
 
   @property
   def maximum_iterations(self):
@@ -2592,7 +2596,14 @@ class WhileContext(ControlFlowContext):
     Returns:
       The loop index.
     """
-    one = constant_op.constant(1, name="b_count")
+    in_separate_functions = count.graph is not ops.get_default_graph()
+    if in_separate_functions:
+      # Brings the count into this graph
+      count = array_ops.identity(count)
+    else:
+      # TODO(apassos) XLA expects this constant to be created outside the
+      # loop, so doing that for now.
+      one = constant_op.constant(1, name="b_count")
 
     self.Enter()
     self.AddName(count.name)
@@ -2607,6 +2618,8 @@ class WhileContext(ControlFlowContext):
     merge_count = merge([enter_count, enter_count])[0]
     self._pivot_for_pred = merge_count
 
+    if in_separate_functions:
+      one = constant_op.constant(1, name="b_count")
     pred = math_ops.greater_equal(merge_count, one)
     self._pivot = loop_cond(pred, name="b_count")
     switch_count = switch(merge_count, self._pivot)
```
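
One subtlety in the last two hunks: in the separate-functions case the backprop counter's `1` constant is created only after `merge_count` exists, so it lands inside the loop being built rather than outside it (XLA prefers the outside-the-loop placement, hence the TODO on the single-graph path). Note also that `array_ops.identity(count)` can bring a tensor from another graph here only because op creation inside a `defun` goes through the `CapturingGraph` logic changed above, which captures external tensors. The guard itself is just a graph-identity check; a minimal standalone sketch (graph names hypothetical):

```python
import tensorflow as tf

forward_graph = tf.Graph()
with forward_graph.as_default():
  count = tf.constant(0, name="f_count")  # stands in for the forward count

with tf.Graph().as_default():
  # Mirrors `count.graph is not ops.get_default_graph()` in the diff.
  in_separate_functions = count.graph is not tf.get_default_graph()
  assert in_separate_functions
```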