aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yuan Yu <yuanbyu@google.com>2016-07-29 08:59:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-29 10:03:18 -0700
commitf419dd1f11548386bacfb789a47d5ff6c13ae197 (patch)
treec5814169bef26b8bad3ff13da0fb096507d61a0d
parentacd882cd838b59874cf880c718ce32aef33a3044 (diff)
Fix a bug in handling while loop inside a "with tf.control_dependencies()". Also tidy up some error checking.
Change: 128818017
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py9
-rw-r--r--tensorflow/python/ops/control_flow_ops.py45
2 files changed, 41 insertions, 13 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index a8ae70e49b..00372831df 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -721,6 +721,15 @@ class ControlFlowTest(tf.test.TestCase):
c = tf.while_loop(lambda x: x < 10, lambda x: x + 1, [c])
self.assertEqual(10, sess.run(c, {b: True}))
+ def testWhileWithControl_4(self):
+ with self.test_session() as sess:
+ b = tf.placeholder(tf.bool)
+ c = tf.constant(1)
+ x0 = tf.constant(0)
+ with tf.control_dependencies([b]):
+ r = tf.while_loop(lambda x: x < 10, lambda x: x + tf.identity(c), [x0])
+ self.assertEqual(10, sess.run(r, {b: True}))
+
def testCondWhile_1(self):
with self.test_session():
n = tf.convert_to_tensor(0, name="n")
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 18a7a20c11..eee3b3e2d4 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -118,6 +118,8 @@ def _Identity(data, name=None):
else:
return array_ops.identity(data, name=name)
else:
+ if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(data))
values = _Identity(data.values, name=name)
indices = array_ops.identity(data.indices, name="indices")
if isinstance(data, ops.IndexedSlices):
@@ -125,11 +127,9 @@ def _Identity(data, name=None):
if dense_shape is not None:
dense_shape = array_ops.identity(dense_shape, name="dense_shape")
return ops.IndexedSlices(values, indices, dense_shape)
- elif isinstance(data, ops.SparseTensor):
+ else:
dense_shape = array_ops.identity(data.shape, name="dense_shape")
return ops.SparseTensor(indices, values, dense_shape)
- else:
- raise TypeError("Type %s not supported" % type(data))
def _NextIteration(data, name=None):
@@ -140,6 +140,8 @@ def _NextIteration(data, name=None):
else:
return next_iteration(data, name=name)
else:
+ if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(data))
values = _NextIteration(data.values, name=name)
indices = next_iteration(data.indices, name="indices")
if isinstance(data, ops.IndexedSlices):
@@ -147,11 +149,9 @@ def _NextIteration(data, name=None):
if dense_shape is not None:
dense_shape = next_iteration(dense_shape, name="dense_shape")
return ops.IndexedSlices(values, indices, dense_shape)
- elif isinstance(data, ops.SparseTensor):
+ else:
dense_shape = next_iteration(data.shape, name="dense_shape")
return ops.SparseTensor(indices, values, dense_shape)
- else:
- raise TypeError("Type %s not supported" % type(data))
def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
@@ -183,6 +183,8 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
return enter(data, frame_name, is_constant, parallel_iterations,
name=name)
else:
+ if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(data))
values = _Enter(data.values, frame_name, is_constant,
parallel_iterations, name=name)
indices = enter(data.indices, frame_name, is_constant,
@@ -193,12 +195,10 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
dense_shape = enter(dense_shape, frame_name, is_constant,
parallel_iterations, name="dense_shape")
return ops.IndexedSlices(values, indices, dense_shape)
- elif isinstance(data, ops.SparseTensor):
+ else:
dense_shape = enter(data.shape, frame_name, is_constant,
parallel_iterations, name="dense_shape")
return ops.SparseTensor(indices, values, dense_shape)
- else:
- raise TypeError("Type %s not supported" % type(data))
def exit(data, name=None):
@@ -220,6 +220,8 @@ def exit(data, name=None):
else:
return gen_control_flow_ops._exit(data, name)
else:
+ if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(data))
values = exit(data.values, name=name)
indices = gen_control_flow_ops._exit(data.indices, name="indices")
if isinstance(data, ops.IndexedSlices):
@@ -227,11 +229,9 @@ def exit(data, name=None):
if dense_shape is not None:
dense_shape = gen_control_flow_ops._exit(dense_shape, name)
return ops.IndexedSlices(values, indices, dense_shape)
- elif isinstance(data, ops.SparseTensor):
+ else:
dense_shape = gen_control_flow_ops._exit(data.shape, name)
return ops.SparseTensor(indices, values, dense_shape)
- else:
- raise TypeError("Type %s not supported" % type(data))
def switch(data, pred, dtype=None, name=None):
@@ -1493,8 +1493,15 @@ class WhileContext(ControlFlowContext):
self._AddOpInternal(op)
def _AddOpInternal(self, op):
- """Add `op` to the current context."""
+ """Add `op` to the current context.
+
+ In the case that op has only external data inputs, we remove all of its
+ external control inputs so all its inputs are in the same while loop
+ context. This is valid because op now has an Enter input that has all
+ the right control dependency.
+ """
if not op.inputs:
+ # Remove any external control dependency on this op
control_inputs = [x for x in op.control_inputs
if x._get_control_flow_context() == self]
if len(control_inputs) != len(op.control_inputs):
@@ -1508,12 +1515,22 @@ class WhileContext(ControlFlowContext):
for x in op.outputs:
self._values.add(x.name)
else:
+ has_internal_data_input = False
for index in range(len(op.inputs)):
x = op.inputs[index]
self.AddValue(x)
real_x = self._external_values.get(x.name)
if real_x is not None:
op._update_input(index, real_x)
+ else:
+ has_internal_data_input = True
+ if not has_internal_data_input:
+ # Remove any external control dependency on this op
+ control_inputs = [x for x in op.control_inputs
+ if x._get_control_flow_context() == self]
+ if len(control_inputs) != len(op.control_inputs):
+ del op.control_inputs[:]
+ op._add_control_inputs(control_inputs)
# Add a control dependency to prevent loop invariants from
# enabling ops that should not be executed.
self._MaybeAddControlDependency(op)
@@ -1879,6 +1896,8 @@ class WhileContext(ControlFlowContext):
if isinstance(e, ops.Tensor):
xs = [e]
else:
+ if not isinstance(e, (ops.IndexedSlices, ops.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(e))
xs = [e.values, e.indices]
shape = e.dense_shape if isinstance(e, ops.IndexedSlices) else e.shape
if shape is not None: