-rw-r--r--  tensorflow/core/graph/graph.cc                               6
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc                   2
-rw-r--r--  tensorflow/core/kernels/control_flow_ops.cc                 65
-rw-r--r--  tensorflow/core/ops/control_flow_ops.cc                     48
-rw-r--r--  tensorflow/core/ops/ops.pbtxt                               76
-rw-r--r--  tensorflow/python/BUILD                                      2
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py  67
-rw-r--r--  tensorflow/python/ops/control_flow_grad.py                  47
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py                   64
-rw-r--r--  tensorflow/python/ops/gradients.py                           9
10 files changed, 332 insertions(+), 54 deletions(-)
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 4abff8922e..a29f800a26 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -83,10 +83,10 @@ void Node::Initialize(int id, int cost_id, Properties* props) {
} while (0)
SET_CLASS(NC_SWITCH, ts, "Switch", "RefSwitch");
- SET_CLASS(NC_MERGE, ts, "Merge", "");
+ SET_CLASS(NC_MERGE, ts, "Merge", "RefMerge");
SET_CLASS(NC_ENTER, ts, "Enter", "RefEnter");
- SET_CLASS(NC_EXIT, ts, "Exit", "");
- SET_CLASS(NC_NEXT_ITERATION, ts, "NextIteration", "");
+ SET_CLASS(NC_EXIT, ts, "Exit", "RefExit");
+ SET_CLASS(NC_NEXT_ITERATION, ts, "NextIteration", "RefNextIteration");
SET_CLASS(NC_LOOP_COND, ts, "LoopCond", "");
SET_CLASS(NC_CONTROL_TRIGGER, ts, "ControlTrigger", "");
SET_CLASS(NC_SEND, ts, "_Send", "_HostSend");
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 874f6214f9..c96f04f2d7 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -35,7 +35,7 @@ namespace tensorflow {
namespace {
inline bool IsMerge(const NodeDef& node_def) {
- return node_def.op() == "Merge";
+ return node_def.op() == "Merge" || node_def.op() == "RefMerge";
}
} // namespace
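
Taken together, the two graph-layer changes above teach the runtime about the new ref variants: the node-class table lets the executor classify RefMerge, RefExit, and RefNextIteration exactly like their value counterparts, and the graph constructor must recognize RefMerge because Merge is the one node type whose inputs may not exist yet at construction time (loop back edges). The user-visible effect of the patch as a whole, sketched in the Python API of this era (a minimal sketch mirroring the new tests further down, not additional documented behavior):

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import gen_array_ops

    with tf.Session() as sess:
      x = tf.Variable(0).ref()    # a ref tensor: dtype is tf.int32_ref
      i = tf.constant(0)
      c = lambda i, x: tf.less(i, 10)
      b = lambda i, x: (i + 1, gen_array_ops._ref_identity(x))
      r = control_flow_ops.While(c, b, [i, x])
      tf.initialize_all_variables().run()
      # With RefMerge/RefExit/RefNextIteration in place, the loop variable
      # keeps its ref dtype through the loop instead of being dereferenced
      # at the loop-head Merge.
      assert r[1].dtype == tf.int32_ref
      value_i, value_x = sess.run(r)
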
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 8b539a7751..f009b85b1b 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -192,7 +192,11 @@ class MergeOp : public OpKernel {
}
input_seen = true;
- context->set_output(0, context->input(i));
+ if (IsRefType(context->input_dtype(i))) {
+ context->forward_ref_input_to_ref_output(i, 0);
+ } else {
+ context->set_output(0, context->input(i));
+ }
Tensor* value_index = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
&value_index));
@@ -209,18 +213,26 @@ class MergeOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
+REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Merge") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("value_index"), \
- MergeOp)
+ MergeOp);
+#define REGISTER_GPU_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefMerge") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("value_index"), \
+ MergeOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
+REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
+#undef REGISTER_GPU_REF_KERNEL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -232,6 +244,13 @@ REGISTER_GPU_KERNEL(bool);
.HostMemory("output") \
.HostMemory("value_index") \
.TypeConstraint<type>("T"), \
+ MergeOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefMerge") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("inputs") \
+ .HostMemory("output") \
+ .HostMemory("value_index") \
+ .TypeConstraint<type>("T"), \
MergeOp)
REGISTER_GPU_HOST_KERNEL(int32);
@@ -314,7 +333,11 @@ class ExitOp : public OpKernel {
explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
- context->set_output(0, context->input(0));
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ context->set_output(0, context->input(0));
+ }
}
bool IsExpensive() override { return false; }
@@ -325,15 +348,20 @@ class ExitOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp);
+REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
+#define REGISTER_GPU_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
+REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
+#undef REGISTER_GPU_REF_KERNEL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -344,7 +372,13 @@ REGISTER_GPU_KERNEL(bool);
.HostMemory("data") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
- ExitOp)
+ ExitOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefExit") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ ExitOp);
REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
@@ -358,7 +392,11 @@ class NextIterationOp : public OpKernel {
explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
- context->set_output(0, context->input(0));
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ context->set_output(0, context->input(0));
+ }
}
bool IsExpensive() override { return false; }
@@ -370,10 +408,15 @@ class NextIterationOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
NextIterationOp);
+REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
+ NextIterationOp);
-#define REGISTER_GPU_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
NextIterationOp)
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
@@ -390,6 +433,12 @@ REGISTER_GPU_KERNEL(bool);
.HostMemory("data") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
NextIterationOp)
REGISTER_GPU_HOST_KERNEL(int32);
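
The kernel-side pattern is identical in MergeOp, ExitOp, and NextIterationOp: when the incoming dtype is a ref type, forward the reference itself with forward_ref_input_to_ref_output instead of copying the value with set_output. Because the reference is forwarded rather than snapshotted, downstream ops can still assign through it. A sketch adapted from the Switch test updated in this patch:

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops

    with tf.Session():
      v = tf.Variable(7)
      p = tf.constant(True)
      v1 = control_flow_ops._SwitchRefOrTensor(v.ref(), p)  # RefSwitch
      v2 = tf.assign(v1[1], 9)      # assign through the forwarded ref
      tf.initialize_all_variables().run()
      assert v2.eval() == 9         # the write lands on v itself
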
diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc
index 06c7e91af3..7ae0012af7 100644
--- a/tensorflow/core/ops/control_flow_ops.cc
+++ b/tensorflow/core/ops/control_flow_ops.cc
@@ -97,6 +97,28 @@ output: Will be set to the available input tensor.
value_index: The index of the chosen input tensor in `inputs`.
)doc");
+REGISTER_OP("RefMerge")
+ .Input("inputs: Ref(N * T)")
+ .Output("output: Ref(T)")
+ .Output("value_index: int32")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Doc(R"doc(
+Forwards the value of an available tensor from `inputs` to `output`.
+
+`Merge` waits for at least one of the tensors in `inputs` to become available.
+It is usually combined with `Switch` to implement branching.
+
+`Merge` forwards the first tensor to become available to `output`, and sets
+`value_index` to its index in `inputs`.
+
+It is an error if more than one tensor in `inputs` is available.
+
+inputs: The input tensors, exactly one of which will become available.
+output: Will be set to the available input tensor.
+value_index: The index of the chosen input tensor in `inputs`.
+)doc");
+
// --------------------------------------------------------------------------
REGISTER_OP("Enter")
.Input("data: T")
@@ -158,6 +180,19 @@ data: The tensor to be made available to the parent frame.
output: The same tensor as `data`.
)doc");
+REGISTER_OP("RefExit")
+ .Input("data: Ref(T)")
+ .Output("output: Ref(T)")
+ .Attr("T: type")
+ .Doc(R"doc(
+Exits the current frame to its parent frame.
+
+Exit makes its input `data` available to the parent frame.
+
+data: The tensor to be made available to the parent frame.
+output: The same tensor as `data`.
+)doc");
+
// --------------------------------------------------------------------------
REGISTER_OP("NextIteration")
.Input("data: T")
@@ -170,6 +205,17 @@ data: The tensor to be made available to the next iteration.
output: The same tensor as `data`.
)doc");
+REGISTER_OP("RefNextIteration")
+ .Input("data: Ref(T)")
+ .Output("output: Ref(T)")
+ .Attr("T: type")
+ .Doc(R"doc(
+Makes its input available to the next iteration.
+
+data: The tensor to be made available to the next iteration.
+output: The same tensor as `data`.
+)doc");
+
// --------------------------------------------------------------------------
REGISTER_OP("LoopCond")
.Input("input: bool")
@@ -180,7 +226,7 @@ Forwards the input to the output.
This operator represents the loop termination condition used by the
"pivot" switches of a loop.
-input:= A boolean scalar, representing the branch predicate of the Switch op.
+input: A boolean scalar, representing the branch predicate of the Switch op.
output: The same tensor as `input`.
)doc");
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index b54826269e..69247f02e1 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -3996,7 +3996,7 @@ op {
name: "LoopCond"
input_arg {
name: "input"
- description: "= A boolean scalar, representing the branch predicate of the Switch op."
+ description: "A boolean scalar, representing the branch predicate of the Switch op."
type: DT_BOOL
}
output_arg {
@@ -6124,6 +6124,27 @@ op {
description: "The unique `frame_name` is used by the `Executor` to identify frames. If\n`is_constant` is true, `output` is a constant in the child frame; otherwise\nit may be changed in the child frame. At most `parallel_iterations` iterations\nare run in parallel in the child frame."
}
op {
+ name: "RefExit"
+ input_arg {
+ name: "data"
+ description: "The tensor to be made available to the parent frame."
+ type_attr: "T"
+ is_ref: true
+ }
+ output_arg {
+ name: "output"
+ description: "The same tensor as `data`."
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ summary: "Exits the current frame to its parent frame."
+ description: "Exit makes its input `data` available to the parent frame."
+}
+op {
name: "RefIdentity"
input_arg {
name: "input"
@@ -6142,6 +6163,59 @@ op {
summary: "Return the same ref tensor as the input ref tensor."
}
op {
+ name: "RefMerge"
+ input_arg {
+ name: "inputs"
+ description: "The input tensors, exactly one of which will become available."
+ type_attr: "T"
+ number_attr: "N"
+ is_ref: true
+ }
+ output_arg {
+ name: "output"
+ description: "Will be set to the available input tensor."
+ type_attr: "T"
+ is_ref: true
+ }
+ output_arg {
+ name: "value_index"
+ description: "The index of the chosen input tensor in `inputs`."
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Forwards the value of an available tensor from `inputs` to `output`."
+  description: "`Merge` waits for at least one of the tensors in `inputs` to become available.\nIt is usually combined with `Switch` to implement branching.\n\n`Merge` forwards the first tensor to become available to `output`, and sets\n`value_index` to its index in `inputs`.\n\nIt is an error if more than one tensor in `inputs` is available."
+}
+op {
+ name: "RefNextIteration"
+ input_arg {
+ name: "data"
+ description: "The tensor to be made available to the next iteration."
+ type_attr: "T"
+ is_ref: true
+ }
+ output_arg {
+ name: "output"
+ description: "The same tensor as `data`."
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ summary: "Makes its input available to the next iteration."
+}
+op {
name: "RefSelect"
input_arg {
name: "index"
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 71840b31c2..4457ebbb2e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -435,7 +435,9 @@ tf_gen_op_wrapper_py(
hidden = [
"Switch",
"Merge",
+ "RefMerge",
"Exit",
+ "RefExit",
],
require_shape_functions = True,
deps = [
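
Listing the new op names under `hidden` keeps their generated wrappers private, so the hand-written functions in control_flow_ops.py stay the public entry points. Roughly (a sketch of the effect, not the generated code):

    from tensorflow.python.ops import control_flow_ops, gen_control_flow_ops
    gen_control_flow_ops._ref_merge   # hidden generated wrapper (this patch)
    control_flow_ops.merge            # public wrapper that dispatches to it
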
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index baed939f73..7c405e8212 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -29,8 +29,8 @@ from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_data_flow_ops
-from tensorflow.python.ops import gradients
from tensorflow.python.ops import logging_ops
from tensorflow.python.pywrap_tensorflow import StatusNotOK
@@ -68,6 +68,36 @@ def isum(s):
class ControlFlowTest(tf.test.TestCase):
+ def testWhileWithRefsWithGradients_1(self):
+ with self.test_session() as sess:
+ x = tf.Variable(0).ref()
+ i = tf.constant(0)
+ c = lambda i, x: tf.less(i, 10)
+
+ self.assertEqual(x.dtype, tf.int32_ref)
+
+ # pylint: disable=protected-access
+ def body(i, x):
+ self.assertEqual(x.dtype, tf.int32_ref)
+ return (i+1, gen_array_ops._ref_identity(x))
+ # pylint: enable=protected-access
+
+ r = control_flow_ops.While(c, body, [i, x], parallel_iterations=5)
+
+ grad_ys = [tf.Variable(73).ref()]
+ grad = tf.gradients([r[1]], [x], grad_ys=grad_ys)
+
+ tf.initialize_all_variables().run()
+
+ self.assertEqual(r[0].dtype, tf.int32)
+ self.assertEqual(r[1].dtype, tf.int32_ref)
+
+ value_i, value_x, value_x_grad = sess.run(r + grad)
+
+ self.assertEqual(10, value_i)
+ self.assertEqual(0, value_x)
+ self.assertEqual(73, value_x_grad)
+
def testRefIdentity(self):
with self.test_session():
v = tf.Variable(7)
@@ -99,7 +129,7 @@ class ControlFlowTest(tf.test.TestCase):
v = tf.Variable(7)
p = tf.constant(True)
- v1 = control_flow_ops._SwitchRefOrTensor(v, p)
+ v1 = control_flow_ops._SwitchRefOrTensor(v.ref(), p)
v2 = tf.assign(v1[1], 9)
tf.initialize_all_variables().run()
self.assertEqual(9, v2.eval())
@@ -171,7 +201,7 @@ class ControlFlowTest(tf.test.TestCase):
dead_branch = tf.identity(switch_op[0])
with self.assertRaisesWithPredicateMatch(
- StatusNotOK, lambda e: 'The tensor returned for' in str(e)):
+ StatusNotOK, lambda e: "The tensor returned for" in str(e)):
dead_branch.eval()
def testSwitchMergeIdentity_1(self):
@@ -544,6 +574,30 @@ class ControlFlowTest(tf.test.TestCase):
self.assertTrue(check_op_order(n.graph))
self.assertEqual(10000, result)
+ def testWhileWithRefs_1(self):
+ with self.test_session() as sess:
+ x = tf.Variable(0).ref()
+ i = tf.constant(0)
+ c = lambda i, x: tf.less(i, 100)
+
+ self.assertEqual(x.dtype, tf.int32_ref)
+
+ def b(i, x):
+ self.assertEqual(x.dtype, tf.int32_ref)
+ return (i+1, gen_array_ops._ref_identity(x))
+
+ r = control_flow_ops.While(c, b, [i, x], parallel_iterations=5)
+
+ tf.initialize_all_variables().run()
+
+ self.assertEqual(r[0].dtype, tf.int32)
+ self.assertEqual(r[1].dtype, tf.int32_ref)
+
+ value_i, value_x = sess.run(r)
+
+ self.assertEqual(100, value_i)
+ self.assertEqual(0, value_x)
+
def testWhile_2(self):
with self.test_session():
s = tf.constant(0)
@@ -737,8 +791,8 @@ class ControlFlowTest(tf.test.TestCase):
n = tf.convert_to_tensor(10, name="n")
one = tf.convert_to_tensor(1, name="one")
c = lambda x: tf.less(x, n)
- b = lambda x: tf.cond(tf.constant(True), lambda: tf.add(x, one),
- lambda: tf.sub(x, one))
+ b = lambda x: tf.cond(
+ tf.constant(True), lambda: tf.add(x, one), lambda: tf.sub(x, one))
r = control_flow_ops.While(c, b, [i])
result = r.eval()
@@ -880,7 +934,7 @@ class ControlFlowTest(tf.test.TestCase):
tf.initialize_all_variables().run()
# Change condition to check var_b
- def pred(i):
+ def pred(_):
return tf.less(var_b, 10)
# Change body to increment var_b
@@ -1507,5 +1561,6 @@ class TupleTest(tf.test.TestCase):
self.assertEquals(1, var.eval())
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index a9986eb6bd..15f6ac8a8f 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -20,13 +20,13 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.control_flow_ops import *
from tensorflow.python.ops.gen_control_flow_ops import *
-@ops.RegisterGradient("Switch")
def _SwitchGrad(op, *grad):
"""Gradients for a Switch op is calculated using a Merge op.
@@ -45,14 +45,17 @@ def _SwitchGrad(op, *grad):
# the non-exit branch of the Switch, so update the second input
# to the Merge.
# TODO: Need to perform shape inference with this new input.
- merge_op._update_input(1, next_iteration(grad[1]))
+ # pylint: disable=protected-access
+ merge_op._update_input(1, control_flow_ops._NextIteration(grad[1]))
+ # pylint: enable=protected-access
return None, None
else:
# This is the first time this Switch is visited. It always comes
# from the Exit branch, which is grad[0]. grad[1] is empty at this point.
# Use grad[0] for both inputs to merge for now, but update the second
# input of merge when we see this Switch the second time.
- merge_op = merge([grad[0], grad[0]], name="b_switch")[0]
+ merge_fn = control_flow_ops._Merge # pylint: disable=protected-access
+ merge_op = merge_fn([grad[0], grad[0]], name="b_switch")[0]
op.grad_state.switch_map[real_op] = merge_op.op
return merge_op, None
elif isinstance(ctxt, CondContext):
@@ -71,9 +74,8 @@ def _SwitchGrad(op, *grad):
return merge([false_grad, true_grad])[0], None
-@ops.RegisterGradient("RefSwitch")
-def _RefSwitchGrad(op, *grad):
- return _SwitchGrad(op, *grad)
+ops.RegisterGradient("Switch")(_SwitchGrad)
+ops.RegisterGradient("RefSwitch")(_SwitchGrad)
@ops.RegisterGradient("Merge")
@@ -86,7 +88,9 @@ def _MergeGrad(op, grad, _):
# pylint: enable=protected-access
if isinstance(ctxt, WhileContext):
grad_ctxt = op.grad_state.grad_context
- return switch(grad, grad_ctxt.pivot)
+ # pylint: disable=protected-access
+ return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
+ # pylint: enable=protected-access
elif isinstance(ctxt, CondContext):
pred = ctxt.pred
if isinstance(op, ControlFlowOpWrapper):
@@ -108,11 +112,21 @@ def _MergeGrad(op, grad, _):
real_pred = grad_state.AddBackPropAccumulatedValue(history_pred, pred)
grad_state.history_map[pred.name] = real_pred
pred = real_pred
- return switch(grad, pred, name="cond_grad")
+ # pylint: disable=protected-access
+ return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
+ # pylint: enable=protected-access
else:
num_inputs = len(real_op.inputs)
cond = [math_ops.equal(real_op.outputs[1], i) for i in xrange(num_inputs)]
- return [switch(grad, cond[i])[1] for i in xrange(num_inputs)]
+ # pylint: disable=protected-access
+ return [control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
+ for i in xrange(num_inputs)]
+ # pylint: enable=protected-access
+
+
+@ops.RegisterGradient("RefMerge")
+def _RefMergeGrad(op, grad, _):
+ return _MergeGrad(op, grad, _)
@ops.RegisterGradient("Exit")
@@ -127,9 +141,13 @@ def _ExitGrad(op, grad):
return None
grad_ctxt = op.grad_state.grad_context
grad_ctxt.AddName(grad.name)
- return enter(grad, grad_ctxt.name, is_constant=False,
- parallel_iterations=grad_ctxt.parallel_iterations,
- name="b_exit")
+ enter_fn = control_flow_ops._Enter # pylint: disable=protected-access
+ return enter_fn(grad, grad_ctxt.name, is_constant=False,
+ parallel_iterations=grad_ctxt.parallel_iterations,
+ name="b_exit")
+
+
+ops.RegisterGradient("RefExit")(_ExitGrad)
@ops.RegisterGradient("NextIteration")
@@ -141,6 +159,11 @@ def _NextIterationGrad(_, grad):
return grad
+@ops.RegisterGradient("RefNextIteration")
+def _RefNextIterationGrad(_, grad):
+ return _NextIterationGrad(_, grad)
+
+
@ops.RegisterGradient("Enter")
def _EnterGrad(op, grad):
"""Gradients for an Enter are calculated using an Exit op.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 83caf6c779..abb7a48f54 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -108,6 +108,20 @@ def _Identity(data, name=None):
return gen_array_ops._ref_identity(data, name=name)
+def _NextIteration(data, name=None):
+ if not data.dtype.is_ref_dtype:
+ return next_iteration(data, name=name)
+ else:
+ return ref_next_iteration(data, name=name)
+
+
+def _Merge(values, name=None):
+ if all([v.dtype.is_ref_dtype for v in values]):
+ return gen_control_flow_ops._ref_merge(values, name)
+ else:
+ return gen_control_flow_ops._merge(values, name)
+
+
def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
use_ref=True, name=None):
"""Creates or finds a child frame, and makes `data` available to it.
@@ -148,7 +162,10 @@ def exit(data, name=None):
Returns:
The same tensor as `data`.
"""
- return gen_control_flow_ops._exit(data, name)
+ if data.dtype.is_ref_dtype:
+ return gen_control_flow_ops._ref_exit(data, name)
+ else:
+ return gen_control_flow_ops._exit(data, name)
def switch(data, pred, dtype=None, name=None):
@@ -215,20 +232,20 @@ def merge(inputs, name=None):
dense_shape property.
"""
with ops.op_scope(inputs, name, "Merge") as name:
- inputs = [ops.convert_to_tensor_or_indexed_slices(inp) for inp in inputs]
+ inputs = [ops.convert_to_tensor_or_indexed_slices(inp)
+ for inp in inputs]
if all([isinstance(inp, ops.Tensor) for inp in inputs]):
- return gen_control_flow_ops._merge(inputs, name=name)
+ return _Merge(inputs, name=name)
else:
inputs = math_ops._as_indexed_slices_list(inputs)
- values, _ = gen_control_flow_ops._merge([inp.values for inp in inputs],
- name=name)
- indices, chosen_index = gen_control_flow_ops._merge(
+ values, _ = _Merge([inp.values for inp in inputs], name=name)
+ indices, chosen_index = _Merge(
[inp.indices for inp in inputs], name="indices")
if any(inp.dense_shape for inp in inputs):
if not all(inp.dense_shape for inp in inputs):
raise ValueError("Either all merged IndexedSlices must have a "
"dense_shape, or none must have a dense_shape.")
- dense_shape, _ = gen_control_flow_ops._merge(
+ dense_shape, _ = _Merge(
[inp.dense_shape for inp in inputs], name="dense_shape")
else:
dense_shape = None
@@ -255,7 +272,7 @@ def _SwitchRefOrTensor(data, pred, name="Switch"):
Raises:
TypeError: if data is not a Tensor or IndexedSlices
"""
- data = ops.convert_to_tensor_or_indexed_slices(data, name="data", as_ref=True)
+ data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
with ops.device(data.device):
if isinstance(data, ops.Tensor):
if not data.dtype.is_ref_dtype:
@@ -402,6 +419,10 @@ def _IsLoopConstantEnter(op):
return is_enter and op.get_attr("is_constant")
+def _IsLoopExit(op):
+ return op.type == "Exit" or op.type == "RefExit"
+
+
class GradLoopState(object):
"""The state used for constructing the gradient graph for a while loop.
@@ -581,7 +602,7 @@ class GradLoopState(object):
# Add the stack_push op in the context of value.op.
value_ctxt = value.op._get_control_flow_context()
- if value.op.type == "Exit":
+ if _IsLoopExit(value.op):
value_ctxt = value_ctxt.outer_context
if value_ctxt == self.forward_context:
# value is not nested in the forward context.
@@ -681,7 +702,7 @@ class GradLoopState(object):
self._grad_context.Exit()
outer_value = value.op.inputs[0]
history_value = self._outer_grad_state.AddForwardAccumulator(
- outer_value)
+ outer_value)
self._grad_context.Enter()
else:
# Just use the input value of this Enter node.
@@ -807,7 +828,7 @@ class ControlFlowState(object):
outer_grad_ctxt = outer_grad_state.grad_context
outer_grad_ctxt.Enter()
real_val = outer_grad_state.AddBackPropAccumulatedValue(
- history_val, val)
+ history_val, val)
result = array_ops.zeros_like(real_val)
outer_grad_ctxt.Exit()
else:
@@ -860,7 +881,7 @@ class ControlFlowState(object):
grad_state.grad_context.Enter()
# Create a zero tensor with the right shape.
shape = grad_state.AddBackPropAccumulatedValue(
- history_shape, zero_shape, dead_branch)
+ history_shape, zero_shape, dead_branch)
result = array_ops.zeros(shape, val.dtype)
return result
@@ -883,7 +904,7 @@ def MaybeCreateControlFlowState(between_op_list, between_ops):
"""
loop_state = None
for op in between_op_list:
- if op.type == "Exit":
+ if _IsLoopExit(op):
if loop_state is None:
loop_state = ControlFlowState()
loop_state.AddWhileContext(op, between_op_list, between_ops)
@@ -892,7 +913,7 @@ def MaybeCreateControlFlowState(between_op_list, between_ops):
def IsLoopSwitch(op):
"""Return true if `op` is the Switch for a While loop."""
- if op.type == "Switch":
+ if op.type == "Switch" or op.type == "RefSwitch":
ctxt = op._get_control_flow_context()
return ctxt and isinstance(ctxt, WhileContext)
return False
@@ -1286,7 +1307,7 @@ class WhileContext(ControlFlowContext):
switch_n = switch(merge_n, self._pivot)
index = math_ops.add(switch_n[1], 1)
- next_n = next_iteration(index)
+ next_n = _NextIteration(index)
merge_n.op._update_input(1, next_n)
total_iterations = exit(switch_n[0], name="f_count")
@@ -1326,7 +1347,7 @@ class WhileContext(ControlFlowContext):
index = math_ops.sub(switch_count[1], one)
self._pivot_for_body = index
- next_count = next_iteration(index)
+ next_count = _NextIteration(index)
merge_count.op._update_input(1, next_count)
self.Exit()
@@ -1366,7 +1387,7 @@ class WhileContext(ControlFlowContext):
switch_acc = switch(merge_acc, self._pivot)
add_acc = math_ops.add(switch_acc[1], value)
- next_acc = next_iteration(add_acc)
+ next_acc = _NextIteration(add_acc)
merge_acc.op._update_input(1, next_acc)
acc_result = exit(switch_acc[0], name="b_acc")
@@ -1385,8 +1406,7 @@ class WhileContext(ControlFlowContext):
real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
with ops.control_dependencies(None):
enter_vars = [_Enter(x, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- use_ref=False)
+ parallel_iterations=self._parallel_iterations)
for x in real_vars]
for x in enter_vars:
x.op._set_control_flow_context(self) # pylint: disable=protected-access
@@ -1408,7 +1428,7 @@ class WhileContext(ControlFlowContext):
if not isinstance(body_result, (list, _basetuple)):
body_result = [body_result]
result = ops.convert_n_to_tensor_or_indexed_slices(body_result)
- next_vars = [next_iteration(x) for x in result]
+ next_vars = [_NextIteration(x) for x in result]
# Add the back edges to complete the loop.
assert len(merge_vars) == len(next_vars)
@@ -1863,6 +1883,8 @@ ops.RegisterShape("Enter")(common_shapes.unchanged_shape)
ops.RegisterShape("Exit")(common_shapes.unknown_shape)
ops.RegisterShape("NextIteration")(common_shapes.unchanged_shape)
ops.RegisterShape("RefEnter")(common_shapes.unchanged_shape)
+ops.RegisterShape("RefExit")(common_shapes.unknown_shape)
+ops.RegisterShape("RefNextIteration")(common_shapes.unchanged_shape)
ops.RegisterShape("ControlTrigger")(common_shapes.no_outputs)
ops.RegisterShape("NoOp")(common_shapes.no_outputs)
@@ -1903,6 +1925,8 @@ def _MergeShape(op):
else:
return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
+ops.RegisterShape("RefMerge")(_MergeShape)
+
@ops.RegisterShape("RefSelect")
def _RefSelectShape(op):
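
All of the new dispatch helpers in this file follow one rule: use the Ref kernel when the input dtype is a ref type (for _Merge, when every input is). A graph-construction sketch of the resulting dtypes (private API, so this mirrors the patch rather than a stable contract; assume x = tf.Variable(0).ref() and i = tf.constant(0)):

    from tensorflow.python.ops import control_flow_ops

    control_flow_ops.exit(x).dtype              # tf.int32_ref (RefExit)
    control_flow_ops.exit(i).dtype              # tf.int32     (Exit)
    control_flow_ops._NextIteration(x).dtype    # tf.int32_ref (RefNextIteration)
    out, idx = control_flow_ops._Merge([x, x])  # all-ref inputs -> RefMerge
    # out.dtype == tf.int32_ref; idx is a regular tf.int32 scalar
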
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 28158d579d..a99a8ea2f5 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -408,6 +408,7 @@ def gradients(ys,
for op in to_ops:
# 'ready' handles the case where one output gradient relies on
# another output's gradient.
+ # pylint: disable=protected-access
ready = (pending_count[op._id] == 0)
if ready and op._id not in to_ops_set:
to_ops_set.add(op._id)
@@ -439,11 +440,13 @@ def gradients(ys,
loop_state.EnterGradWhileContext(op)
out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
grad_fn = None
+
# pylint: disable=protected-access
is_func_call = ops.get_default_graph()._is_function(op.type)
# pylint: enable=protected-access
if not is_func_call and any(out_grads) and op._id not in stop_ops:
+ # pylint: enable=protected-access
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.
try:
@@ -472,8 +475,8 @@ def gradients(ys,
if loop_state:
wrapped_op = loop_state.MakeWrapper(op)
if is_func_call:
- # For function call ops, we add a 'SymbolicGradient' node to the
- # graph to compute gradients.
+ # For function call ops, we add a 'SymbolicGradient'
+ # node to the graph to compute gradients.
f_in = [x for x in op.inputs] + out_grads
f_types = [x.dtype for x in op.inputs]
# pylint: disable=protected-access
@@ -501,6 +504,7 @@ def gradients(ys,
loop_state.ExitGradWhileContext(op)
# update pending count for the inputs of op.
+ # pylint: disable=protected-access
for x in op.inputs:
pending_count[x.op._id] -= 1
ready = (pending_count[x.op._id] == 0)
@@ -513,6 +517,7 @@ def gradients(ys,
pending_count[x._id] -= 1
if pending_count[x._id] is 0:
queue.append(x)
+ # pylint: enable=protected-access
return [_GetGrad(grads, x) for x in xs]