 tensorflow/core/graph/graph.cc                              |  6
 tensorflow/core/graph/graph_constructor.cc                  |  2
 tensorflow/core/kernels/control_flow_ops.cc                 | 65
 tensorflow/core/ops/control_flow_ops.cc                     | 48
 tensorflow/core/ops/ops.pbtxt                               | 76
 tensorflow/python/BUILD                                     |  2
 tensorflow/python/kernel_tests/control_flow_ops_py_test.py  | 67
 tensorflow/python/ops/control_flow_grad.py                  | 47
 tensorflow/python/ops/control_flow_ops.py                   | 64
 tensorflow/python/ops/gradients.py                          |  9
 10 files changed, 332 insertions(+), 54 deletions(-)
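At a glance: this commit adds ref-typed variants of the remaining loop ops (RefMerge, RefExit, RefNextIteration; RefSwitch and RefEnter already existed), teaches the Merge/Exit/NextIteration kernels to forward ref inputs to ref outputs, and makes the Python control-flow helpers dispatch on dtype. The net effect is that a While loop can carry a variable reference through its back edge without dereferencing it. A minimal usage sketch, distilled from testWhileWithRefs_1 in the test changes below (contemporary TF 0.x API; only the session scaffolding is added):

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import gen_array_ops

    with tf.Session() as sess:
      x = tf.Variable(0).ref()          # ref-typed loop variable (tf.int32_ref)
      i = tf.constant(0)
      c = lambda i, x: tf.less(i, 100)

      def b(i, x):
        # _ref_identity preserves the ref dtype; a plain tf.identity would
        # dereference x and drop the ref type on the back edge.
        return (i + 1, gen_array_ops._ref_identity(x))

      r = control_flow_ops.While(c, b, [i, x], parallel_iterations=5)
      tf.initialize_all_variables().run()
      value_i, value_x = sess.run(r)    # 100 and 0, per the test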
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 4abff8922e..a29f800a26 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -83,10 +83,10 @@ void Node::Initialize(int id, int cost_id, Properties* props) { } while (0) SET_CLASS(NC_SWITCH, ts, "Switch", "RefSwitch"); - SET_CLASS(NC_MERGE, ts, "Merge", ""); + SET_CLASS(NC_MERGE, ts, "Merge", "RefMerge"); SET_CLASS(NC_ENTER, ts, "Enter", "RefEnter"); - SET_CLASS(NC_EXIT, ts, "Exit", ""); - SET_CLASS(NC_NEXT_ITERATION, ts, "NextIteration", ""); + SET_CLASS(NC_EXIT, ts, "Exit", "RefExit"); + SET_CLASS(NC_NEXT_ITERATION, ts, "NextIteration", "RefNextIteration"); SET_CLASS(NC_LOOP_COND, ts, "LoopCond", ""); SET_CLASS(NC_CONTROL_TRIGGER, ts, "ControlTrigger", ""); SET_CLASS(NC_SEND, ts, "_Send", "_HostSend"); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 874f6214f9..c96f04f2d7 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -35,7 +35,7 @@ namespace tensorflow { namespace { inline bool IsMerge(const NodeDef& node_def) { - return node_def.op() == "Merge"; + return node_def.op() == "Merge" || node_def.op() == "RefMerge"; } } // namespace diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 8b539a7751..f009b85b1b 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -192,7 +192,11 @@ class MergeOp : public OpKernel { } input_seen = true; - context->set_output(0, context->input(i)); + if (IsRefType(context->input_dtype(i))) { + context->forward_ref_input_to_ref_output(i, 0); + } else { + context->set_output(0, context->input(i)); + } Tensor* value_index = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), &value_index)); @@ -209,18 +213,26 @@ class MergeOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp); +REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp); #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("Merge") \ .Device(DEVICE_GPU) \ .TypeConstraint<type>("T") \ .HostMemory("value_index"), \ - MergeOp) + MergeOp); +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefMerge") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("value_index"), \ + MergeOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); REGISTER_GPU_KERNEL(bool); #undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. 
This kernel @@ -232,6 +244,13 @@ REGISTER_GPU_KERNEL(bool); .HostMemory("output") \ .HostMemory("value_index") \ .TypeConstraint<type>("T"), \ + MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("RefMerge") \ + .Device(DEVICE_GPU) \ + .HostMemory("inputs") \ + .HostMemory("output") \ + .HostMemory("value_index") \ + .TypeConstraint<type>("T"), \ MergeOp) REGISTER_GPU_HOST_KERNEL(int32); @@ -314,7 +333,11 @@ class ExitOp : public OpKernel { explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - context->set_output(0, context->input(0)); + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } } bool IsExpensive() override { return false; } @@ -325,15 +348,20 @@ class ExitOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp); +REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp); #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER( \ Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp); +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); REGISTER_GPU_KERNEL(bool); #undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -344,7 +372,13 @@ REGISTER_GPU_KERNEL(bool); .HostMemory("data") \ .HostMemory("output") \ .TypeConstraint<type>("T"), \ - ExitOp) + ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("RefExit") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ExitOp); REGISTER_GPU_HOST_KERNEL(int32); REGISTER_GPU_HOST_KERNEL(string); @@ -358,7 +392,11 @@ class NextIterationOp : public OpKernel { explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - context->set_output(0, context->input(0)); + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } } bool IsExpensive() override { return false; } @@ -370,10 +408,15 @@ class NextIterationOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU), NextIterationOp); +REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU), + NextIterationOp); -#define REGISTER_GPU_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ NextIterationOp) TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); @@ -390,6 +433,12 @@ REGISTER_GPU_KERNEL(bool); .HostMemory("data") \ .HostMemory("output") \ .TypeConstraint<type>("T"), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ NextIterationOp) REGISTER_GPU_HOST_KERNEL(int32); diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc index 06c7e91af3..7ae0012af7 100644 --- 
a/tensorflow/core/ops/control_flow_ops.cc +++ b/tensorflow/core/ops/control_flow_ops.cc @@ -97,6 +97,28 @@ output: Will be set to the available input tensor. value_index: The index of the chosen input tensor in `inputs`. )doc"); +REGISTER_OP("RefMerge") + .Input("inputs: Ref(N * T)") + .Output("output: Ref(T)") + .Output("value_index: int32") + .Attr("T: type") + .Attr("N: int >= 1") + .Doc(R"doc( +Forwards the value of an available tensor from `inputs` to `output`. + +`Merge` waits for at least one of the tensors in `inputs` to become available. +It is usually combined with `Switch` to implement branching. + +`Merge` forwards the first tensor to become available to `output`, and sets +`value_index` to its index in `inputs`. + +It is an error if more than one tensor in `inputs` is available. + +inputs: The input tensors, exactly one of which will become available. +output: Will be set to the available input tensor. +value_index: The index of the chosen input tensor in `inputs`. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("Enter") .Input("data: T") @@ -158,6 +180,19 @@ data: The tensor to be made available to the parent frame. output: The same tensor as `data`. )doc"); +REGISTER_OP("RefExit") + .Input("data: Ref(T)") + .Output("output: Ref(T)") + .Attr("T: type") + .Doc(R"doc( +Exits the current frame to its parent frame. + +Exit makes its input `data` available to the parent frame. + +data: The tensor to be made available to the parent frame. +output: The same tensor as `data`. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("NextIteration") .Input("data: T") @@ -170,6 +205,17 @@ data: The tensor to be made available to the next iteration. output: The same tensor as `data`. )doc"); +REGISTER_OP("RefNextIteration") + .Input("data: Ref(T)") + .Output("output: Ref(T)") + .Attr("T: type") + .Doc(R"doc( +Makes its input available to the next iteration. + +data: The tensor to be made available to the next iteration. +output: The same tensor as `data`. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("LoopCond") .Input("input: bool") @@ -180,7 +226,7 @@ Forwards the input to the output. This operator represents the loop termination condition used by the "pivot" switches of a loop. -input:= A boolean scalar, representing the branch predicate of the Switch op. +input: A boolean scalar, representing the branch predicate of the Switch op. output: The same tensor as `input`. )doc"); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index b54826269e..69247f02e1 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -3996,7 +3996,7 @@ op { name: "LoopCond" input_arg { name: "input" - description: "= A boolean scalar, representing the branch predicate of the Switch op." + description: "A boolean scalar, representing the branch predicate of the Switch op." type: DT_BOOL } output_arg { @@ -6124,6 +6124,27 @@ op { description: "The unique `frame_name` is used by the `Executor` to identify frames. If\n`is_constant` is true, `output` is a constant in the child frame; otherwise\nit may be changed in the child frame. At most `parallel_iterations` iterations\nare run in parallel in the child frame." } op { + name: "RefExit" + input_arg { + name: "data" + description: "The tensor to be made available to the parent frame."
+ type_attr: "T" + is_ref: true + } + output_arg { + name: "output" + description: "The same tensor as `data`." + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + summary: "Exits the current frame to its parent frame." + description: "Exit makes its input `data` available to the parent frame." +} +op { name: "RefIdentity" input_arg { name: "input" @@ -6142,6 +6163,59 @@ op { summary: "Return the same ref tensor as the input ref tensor." } op { + name: "RefMerge" + input_arg { + name: "inputs" + description: "The input tensors, exactly one of which will become available." + type_attr: "T" + number_attr: "N" + is_ref: true + } + output_arg { + name: "output" + description: "Will be set to the available input tensor." + type_attr: "T" + is_ref: true + } + output_arg { + name: "value_index" + description: "The index of the chosen input tensor in `inputs`." + type: DT_INT32 + } + attr { + name: "T" + type: "type" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + summary: "Forwards the value of an available tensor from `inputs` to `output`." + description: "`Merge` waits for at least one of the tensors in `inputs` to become available.\nIt is usually combined with `Switch` to implement branching.\n\n`Merge` forwards the first tensor to become available to `output`, and sets\n`value_index` to its index in `inputs`.\n\nIt is an error if more than one tensor in `inputs` is available." +} +op { + name: "RefNextIteration" + input_arg { + name: "data" + description: "The tensor to be made available to the next iteration." + type_attr: "T" + is_ref: true + } + output_arg { + name: "output" + description: "The same tensor as `data`." + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + summary: "Makes its input available to the next iteration."
+} +op { name: "RefSelect" input_arg { name: "index" diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 71840b31c2..4457ebbb2e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -435,7 +435,9 @@ tf_gen_op_wrapper_py( hidden = [ "Switch", "Merge", + "RefMerge", "Exit", + "RefExit", ], require_shape_functions = True, deps = [ diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index baed939f73..7c405e8212 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -29,8 +29,8 @@ from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_data_flow_ops -from tensorflow.python.ops import gradients from tensorflow.python.ops import logging_ops from tensorflow.python.pywrap_tensorflow import StatusNotOK @@ -68,6 +68,36 @@ def isum(s): class ControlFlowTest(tf.test.TestCase): + def testWhileWithRefsWithGradients_1(self): + with self.test_session() as sess: + x = tf.Variable(0).ref() + i = tf.constant(0) + c = lambda i, x: tf.less(i, 10) + + self.assertEqual(x.dtype, tf.int32_ref) + + # pylint: disable=protected-access + def body(i, x): + self.assertEqual(x.dtype, tf.int32_ref) + return (i+1, gen_array_ops._ref_identity(x)) + # pylint: enable=protected-access + + r = control_flow_ops.While(c, body, [i, x], parallel_iterations=5) + + grad_ys = [tf.Variable(73).ref()] + grad = tf.gradients([r[1]], [x], grad_ys=grad_ys) + + tf.initialize_all_variables().run() + + self.assertEqual(r[0].dtype, tf.int32) + self.assertEqual(r[1].dtype, tf.int32_ref) + + value_i, value_x, value_x_grad = sess.run(r + grad) + + self.assertEqual(10, value_i) + self.assertEqual(0, value_x) + self.assertEqual(73, value_x_grad) + def testRefIdentity(self): with self.test_session(): v = tf.Variable(7) @@ -99,7 +129,7 @@ class ControlFlowTest(tf.test.TestCase): v = tf.Variable(7) p = tf.constant(True) - v1 = control_flow_ops._SwitchRefOrTensor(v, p) + v1 = control_flow_ops._SwitchRefOrTensor(v.ref(), p) v2 = tf.assign(v1[1], 9) tf.initialize_all_variables().run() self.assertEqual(9, v2.eval()) @@ -171,7 +201,7 @@ class ControlFlowTest(tf.test.TestCase): dead_branch = tf.identity(switch_op[0]) with self.assertRaisesWithPredicateMatch( - StatusNotOK, lambda e: 'The tensor returned for' in str(e)): + StatusNotOK, lambda e: "The tensor returned for" in str(e)): dead_branch.eval() def testSwitchMergeIdentity_1(self): @@ -544,6 +574,30 @@ class ControlFlowTest(tf.test.TestCase): self.assertTrue(check_op_order(n.graph)) self.assertEqual(10000, result) + def testWhileWithRefs_1(self): + with self.test_session() as sess: + x = tf.Variable(0).ref() + i = tf.constant(0) + c = lambda i, x: tf.less(i, 100) + + self.assertEqual(x.dtype, tf.int32_ref) + + def b(i, x): + self.assertEqual(x.dtype, tf.int32_ref) + return (i+1, gen_array_ops._ref_identity(x)) + + r = control_flow_ops.While(c, b, [i, x], parallel_iterations=5) + + tf.initialize_all_variables().run() + + self.assertEqual(r[0].dtype, tf.int32) + self.assertEqual(r[1].dtype, tf.int32_ref) + + value_i, value_x = sess.run(r) + + self.assertEqual(100, value_i) + self.assertEqual(0, value_x) + def testWhile_2(self): with self.test_session(): s = tf.constant(0) @@ -737,8 +791,8 @@ class ControlFlowTest(tf.test.TestCase): n = 
tf.convert_to_tensor(10, name="n") one = tf.convert_to_tensor(1, name="one") c = lambda x: tf.less(x, n) - b = lambda x: tf.cond(tf.constant(True), lambda: tf.add(x, one), - lambda: tf.sub(x, one)) + b = lambda x: tf.cond( + tf.constant(True), lambda: tf.add(x, one), lambda: tf.sub(x, one)) r = control_flow_ops.While(c, b, [i]) result = r.eval() @@ -880,7 +934,7 @@ class ControlFlowTest(tf.test.TestCase): tf.initialize_all_variables().run() # Change condition to check var_b - def pred(i): + def pred(_): return tf.less(var_b, 10) # Change body to increment var_b @@ -1507,5 +1561,6 @@ class TupleTest(tf.test.TestCase): self.assertEquals(1, var.eval()) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py index a9986eb6bd..15f6ac8a8f 100644 --- a/tensorflow/python/ops/control_flow_grad.py +++ b/tensorflow/python/ops/control_flow_grad.py @@ -20,13 +20,13 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import,undefined-variable from tensorflow.python.ops.control_flow_ops import * from tensorflow.python.ops.gen_control_flow_ops import * -@ops.RegisterGradient("Switch") def _SwitchGrad(op, *grad): """Gradients for a Switch op is calculated using a Merge op. @@ -45,14 +45,17 @@ def _SwitchGrad(op, *grad): # the non-exit branch of the Switch, so update the second input # to the Merge. # TODO: Need to perform shape inference with this new input. - merge_op._update_input(1, next_iteration(grad[1])) + # pylint: disable=protected-access + merge_op._update_input(1, control_flow_ops._NextIteration(grad[1])) + # pylint: enable=protected-access return None, None else: # This is the first time this Switch is visited. It always comes # from the Exit branch, which is grad[0]. grad[1] is empty at this point. # Use grad[0] for both inputs to merge for now, but update the second # input of merge when we see this Switch the second time. 
- merge_op = merge([grad[0], grad[0]], name="b_switch")[0] + merge_fn = control_flow_ops._Merge # pylint: disable=protected-access + merge_op = merge_fn([grad[0], grad[0]], name="b_switch")[0] op.grad_state.switch_map[real_op] = merge_op.op return merge_op, None elif isinstance(ctxt, CondContext): @@ -71,9 +74,8 @@ def _SwitchGrad(op, *grad): return merge([false_grad, true_grad])[0], None -@ops.RegisterGradient("RefSwitch") -def _RefSwitchGrad(op, *grad): - return _SwitchGrad(op, *grad) +ops.RegisterGradient("Switch")(_SwitchGrad) +ops.RegisterGradient("RefSwitch")(_SwitchGrad) @ops.RegisterGradient("Merge") @@ -86,7 +88,9 @@ def _MergeGrad(op, grad, _): # pylint: enable=protected-access if isinstance(ctxt, WhileContext): grad_ctxt = op.grad_state.grad_context - return switch(grad, grad_ctxt.pivot) + # pylint: disable=protected-access + return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot) + # pylint: enable=protected-access elif isinstance(ctxt, CondContext): pred = ctxt.pred if isinstance(op, ControlFlowOpWrapper): @@ -108,11 +112,21 @@ def _MergeGrad(op, grad, _): real_pred = grad_state.AddBackPropAccumulatedValue(history_pred, pred) grad_state.history_map[pred.name] = real_pred pred = real_pred - return switch(grad, pred, name="cond_grad") + # pylint: disable=protected-access + return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad") + # pylint: enable=protected-access else: num_inputs = len(real_op.inputs) cond = [math_ops.equal(real_op.outputs[1], i) for i in xrange(num_inputs)] - return [switch(grad, cond[i])[1] for i in xrange(num_inputs)] + # pylint: disable=protected-access + return [control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1] + for i in xrange(num_inputs)] + # pylint: enable=protected-access + + +@ops.RegisterGradient("RefMerge") +def _RefMergeGrad(op, grad, _): + return _MergeGrad(op, grad, _) @ops.RegisterGradient("Exit") @@ -127,9 +141,13 @@ def _ExitGrad(op, grad): return None grad_ctxt = op.grad_state.grad_context grad_ctxt.AddName(grad.name) - return enter(grad, grad_ctxt.name, is_constant=False, - parallel_iterations=grad_ctxt.parallel_iterations, - name="b_exit") + enter_fn = control_flow_ops._Enter # pylint: disable=protected-access + return enter_fn(grad, grad_ctxt.name, is_constant=False, + parallel_iterations=grad_ctxt.parallel_iterations, + name="b_exit") + + +ops.RegisterGradient("RefExit")(_ExitGrad) @ops.RegisterGradient("NextIteration") @@ -141,6 +159,11 @@ def _NextIterationGrad(_, grad): return grad +@ops.RegisterGradient("RefNextIteration") +def _RefNextIterationGrad(_, grad): + return _NextIterationGrad(_, grad) + + @ops.RegisterGradient("Enter") def _EnterGrad(op, grad): """Gradients for an Enter are calculated using an Exit op. 
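Note the registration idiom used throughout control_flow_grad.py above: instead of duplicating each gradient function for its ref twin, the commit registers one function under both op names by invoking ops.RegisterGradient directly (it returns a decorator object that can be applied explicitly). A minimal illustration of the pattern, with hypothetical op names that are not part of this commit:

    from tensorflow.python.framework import ops

    def _PassThroughGrad(unused_op, grad):
      # One implementation serves both the plain and the Ref op type.
      return grad

    # RegisterGradient can be called explicitly rather than via @-syntax,
    # so the same function backs two registrations.
    ops.RegisterGradient("PassThrough")(_PassThroughGrad)
    ops.RegisterGradient("RefPassThrough")(_PassThroughGrad)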
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 83caf6c779..abb7a48f54 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -108,6 +108,20 @@ def _Identity(data, name=None): return gen_array_ops._ref_identity(data, name=name) +def _NextIteration(data, name=None): + if not data.dtype.is_ref_dtype: + return next_iteration(data, name=name) + else: + return ref_next_iteration(data, name=name) + + +def _Merge(values, name=None): + if all([v.dtype.is_ref_dtype for v in values]): + return gen_control_flow_ops._ref_merge(values, name) + else: + return gen_control_flow_ops._merge(values, name) + + def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, use_ref=True, name=None): """Creates or finds a child frame, and makes `data` available to it. @@ -148,7 +162,10 @@ def exit(data, name=None): Returns: The same tensor as `data`. """ - return gen_control_flow_ops._exit(data, name) + if data.dtype.is_ref_dtype: + return gen_control_flow_ops._ref_exit(data, name) + else: + return gen_control_flow_ops._exit(data, name) def switch(data, pred, dtype=None, name=None): @@ -215,20 +232,20 @@ def merge(inputs, name=None): dense_shape property. """ with ops.op_scope(inputs, name, "Merge") as name: - inputs = [ops.convert_to_tensor_or_indexed_slices(inp) for inp in inputs] + inputs = [ops.convert_to_tensor_or_indexed_slices(inp) + for inp in inputs] if all([isinstance(inp, ops.Tensor) for inp in inputs]): - return gen_control_flow_ops._merge(inputs, name=name) + return _Merge(inputs, name=name) else: inputs = math_ops._as_indexed_slices_list(inputs) - values, _ = gen_control_flow_ops._merge([inp.values for inp in inputs], - name=name) - indices, chosen_index = gen_control_flow_ops._merge( + values, _ = _Merge([inp.values for inp in inputs], name=name) + indices, chosen_index = _Merge( [inp.indices for inp in inputs], name="indices") if any(inp.dense_shape for inp in inputs): if not all(inp.dense_shape for inp in inputs): raise ValueError("Either all merged IndexedSlices must have a " "dense_shape, or none must have a dense_shape.") - dense_shape, _ = gen_control_flow_ops._merge( + dense_shape, _ = _Merge( [inp.dense_shape for inp in inputs], name="dense_shape") else: dense_shape = None @@ -255,7 +272,7 @@ def _SwitchRefOrTensor(data, pred, name="Switch"): Raises: TypeError: if data is not a Tensor or IndexedSlices """ - data = ops.convert_to_tensor_or_indexed_slices(data, name="data", as_ref=True) + data = ops.convert_to_tensor_or_indexed_slices(data, name="data") with ops.device(data.device): if isinstance(data, ops.Tensor): if not data.dtype.is_ref_dtype: @@ -402,6 +419,10 @@ def _IsLoopConstantEnter(op): return is_enter and op.get_attr("is_constant") +def _IsLoopExit(op): + return op.type == "Exit" or op.type == "RefExit" + + class GradLoopState(object): """The state used for constructing the gradient graph for a while loop. @@ -581,7 +602,7 @@ class GradLoopState(object): # Add the stack_push op in the context of value.op. value_ctxt = value.op._get_control_flow_context() - if value.op.type == "Exit": + if _IsLoopExit(value.op): value_ctxt = value_ctxt.outer_context if value_ctxt == self.forward_context: # value is not nested in the forward context. 
@@ -681,7 +702,7 @@ class GradLoopState(object): self._grad_context.Exit() outer_value = value.op.inputs[0] history_value = self._outer_grad_state.AddForwardAccumulator( - outer_value) + outer_value) self._grad_context.Enter() else: # Just use the input value of this Enter node. @@ -807,7 +828,7 @@ class ControlFlowState(object): outer_grad_ctxt = outer_grad_state.grad_context outer_grad_ctxt.Enter() real_val = outer_grad_state.AddBackPropAccumulatedValue( - history_val, val) + history_val, val) result = array_ops.zeros_like(real_val) outer_grad_ctxt.Exit() else: @@ -860,7 +881,7 @@ class ControlFlowState(object): grad_state.grad_context.Enter() # Create a zero tensor with the right shape. shape = grad_state.AddBackPropAccumulatedValue( - history_shape, zero_shape, dead_branch) + history_shape, zero_shape, dead_branch) result = array_ops.zeros(shape, val.dtype) return result @@ -883,7 +904,7 @@ def MaybeCreateControlFlowState(between_op_list, between_ops): """ loop_state = None for op in between_op_list: - if op.type == "Exit": + if _IsLoopExit(op): if loop_state is None: loop_state = ControlFlowState() loop_state.AddWhileContext(op, between_op_list, between_ops) @@ -892,7 +913,7 @@ def MaybeCreateControlFlowState(between_op_list, between_ops): def IsLoopSwitch(op): """Return true if `op` is the Switch for a While loop.""" - if op.type == "Switch": + if op.type == "Switch" or op.type == "RefSwitch": ctxt = op._get_control_flow_context() return ctxt and isinstance(ctxt, WhileContext) return False @@ -1286,7 +1307,7 @@ class WhileContext(ControlFlowContext): switch_n = switch(merge_n, self._pivot) index = math_ops.add(switch_n[1], 1) - next_n = next_iteration(index) + next_n = _NextIteration(index) merge_n.op._update_input(1, next_n) total_iterations = exit(switch_n[0], name="f_count") @@ -1326,7 +1347,7 @@ class WhileContext(ControlFlowContext): index = math_ops.sub(switch_count[1], one) self._pivot_for_body = index - next_count = next_iteration(index) + next_count = _NextIteration(index) merge_count.op._update_input(1, next_count) self.Exit() @@ -1366,7 +1387,7 @@ class WhileContext(ControlFlowContext): switch_acc = switch(merge_acc, self._pivot) add_acc = math_ops.add(switch_acc[1], value) - next_acc = next_iteration(add_acc) + next_acc = _NextIteration(add_acc) merge_acc.op._update_input(1, next_acc) acc_result = exit(switch_acc[0], name="b_acc") @@ -1385,8 +1406,7 @@ class WhileContext(ControlFlowContext): real_vars = [self._outer_context.AddValue(x) for x in loop_vars] with ops.control_dependencies(None): enter_vars = [_Enter(x, self._name, is_constant=False, - parallel_iterations=self._parallel_iterations, - use_ref=False) + parallel_iterations=self._parallel_iterations) for x in real_vars] for x in enter_vars: x.op._set_control_flow_context(self) # pylint: disable=protected-access @@ -1408,7 +1428,7 @@ class WhileContext(ControlFlowContext): if not isinstance(body_result, (list, _basetuple)): body_result = [body_result] result = ops.convert_n_to_tensor_or_indexed_slices(body_result) - next_vars = [next_iteration(x) for x in result] + next_vars = [_NextIteration(x) for x in result] # Add the back edges to complete the loop. 
assert len(merge_vars) == len(next_vars) @@ -1863,6 +1883,8 @@ ops.RegisterShape("Enter")(common_shapes.unchanged_shape) ops.RegisterShape("Exit")(common_shapes.unknown_shape) ops.RegisterShape("NextIteration")(common_shapes.unchanged_shape) ops.RegisterShape("RefEnter")(common_shapes.unchanged_shape) +ops.RegisterShape("RefExit")(common_shapes.unknown_shape) +ops.RegisterShape("RefNextIteration")(common_shapes.unchanged_shape) ops.RegisterShape("ControlTrigger")(common_shapes.no_outputs) ops.RegisterShape("NoOp")(common_shapes.no_outputs) @@ -1903,6 +1925,8 @@ def _MergeShape(op): else: return [tensor_shape.unknown_shape(), tensor_shape.scalar()] +ops.RegisterShape("RefMerge")(_MergeShape) + @ops.RegisterShape("RefSelect") def _RefSelectShape(op): diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index 28158d579d..a99a8ea2f5 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -408,6 +408,7 @@ def gradients(ys, for op in to_ops: # 'ready' handles the case where one output gradient relies on # another output's gradient. + # pylint: disable=protected-access ready = (pending_count[op._id] == 0) if ready and op._id not in to_ops_set: to_ops_set.add(op._id) @@ -439,11 +440,13 @@ def gradients(ys, loop_state.EnterGradWhileContext(op) out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method) grad_fn = None + # pylint: disable=protected-access is_func_call = ops.get_default_graph()._is_function(op.type) # pylint: enable=protected-access if not is_func_call and any(out_grads) and op._id not in stop_ops: + # pylint: enable=protected-access # A grad_fn must be defined, either as a function or as None # for ops that do not have gradients. try: @@ -472,8 +475,8 @@ def gradients(ys, if loop_state: wrapped_op = loop_state.MakeWrapper(op) if is_func_call: - # For function call ops, we add a 'SymbolicGradient' node to the - # graph to compute gradients. + # For function call ops, we add a 'SymbolicGradient' + # node to the graph to compute gradients. f_in = [x for x in op.inputs] + out_grads f_types = [x.dtype for x in op.inputs] # pylint: disable=protected-access @@ -501,6 +504,7 @@ def gradients(ys, loop_state.ExitGradWhileContext(op) # update pending count for the inputs of op. + # pylint: disable=protected-access for x in op.inputs: pending_count[x.op._id] -= 1 ready = (pending_count[x.op._id] == 0) @@ -513,6 +517,7 @@ def gradients(ys, pending_count[x._id] -= 1 if pending_count[x._id] is 0: queue.append(x) + # pylint: enable=protected-access return [_GetGrad(grads, x) for x in xs]
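With the kernel, op-registration, and gradient pieces in place, tf.gradients can run through a loop whose carried state is a reference. A sketch distilled from testWhileWithRefsWithGradients_1 in the test changes above; only the session scaffolding is added:

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import gen_array_ops

    with tf.Session() as sess:
      x = tf.Variable(0).ref()                  # loop state stays int32_ref
      i = tf.constant(0)
      c = lambda i, x: tf.less(i, 10)
      b = lambda i, x: (i + 1, gen_array_ops._ref_identity(x))

      r = control_flow_ops.While(c, b, [i, x], parallel_iterations=5)

      # Seed the backward pass with a ref-typed gradient, as the test does;
      # RefMerge/RefExit/RefNextIteration make this backward graph buildable.
      grad = tf.gradients([r[1]], [x], grad_ys=[tf.Variable(73).ref()])

      tf.initialize_all_variables().run()
      value_i, value_x, value_x_grad = sess.run(r + grad)
      # Expected per the test: 10, 0, 73.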