diff options
author | 2016-07-29 08:59:49 -0800 | |
---|---|---|
committer | 2016-07-29 10:03:18 -0700 | |
commit | f419dd1f11548386bacfb789a47d5ff6c13ae197 (patch) | |
tree | c5814169bef26b8bad3ff13da0fb096507d61a0d | |
parent | acd882cd838b59874cf880c718ce32aef33a3044 (diff) |
Fix a bug in handling while loop inside a "with tf.control_dependencies()". Also tidy up some error checking.
Change: 128818017
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 9 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 45 |
2 files changed, 41 insertions, 13 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index a8ae70e49b..00372831df 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -721,6 +721,15 @@ class ControlFlowTest(tf.test.TestCase): c = tf.while_loop(lambda x: x < 10, lambda x: x + 1, [c]) self.assertEqual(10, sess.run(c, {b: True})) + def testWhileWithControl_4(self): + with self.test_session() as sess: + b = tf.placeholder(tf.bool) + c = tf.constant(1) + x0 = tf.constant(0) + with tf.control_dependencies([b]): + r = tf.while_loop(lambda x: x < 10, lambda x: x + tf.identity(c), [x0]) + self.assertEqual(10, sess.run(r, {b: True})) + def testCondWhile_1(self): with self.test_session(): n = tf.convert_to_tensor(0, name="n") diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 18a7a20c11..eee3b3e2d4 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -118,6 +118,8 @@ def _Identity(data, name=None): else: return array_ops.identity(data, name=name) else: + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) values = _Identity(data.values, name=name) indices = array_ops.identity(data.indices, name="indices") if isinstance(data, ops.IndexedSlices): @@ -125,11 +127,9 @@ def _Identity(data, name=None): if dense_shape is not None: dense_shape = array_ops.identity(dense_shape, name="dense_shape") return ops.IndexedSlices(values, indices, dense_shape) - elif isinstance(data, ops.SparseTensor): + else: dense_shape = array_ops.identity(data.shape, name="dense_shape") return ops.SparseTensor(indices, values, dense_shape) - else: - raise TypeError("Type %s not supported" % type(data)) def _NextIteration(data, name=None): @@ -140,6 +140,8 @@ def _NextIteration(data, name=None): else: return next_iteration(data, name=name) else: + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) values = _NextIteration(data.values, name=name) indices = next_iteration(data.indices, name="indices") if isinstance(data, ops.IndexedSlices): @@ -147,11 +149,9 @@ def _NextIteration(data, name=None): if dense_shape is not None: dense_shape = next_iteration(dense_shape, name="dense_shape") return ops.IndexedSlices(values, indices, dense_shape) - elif isinstance(data, ops.SparseTensor): + else: dense_shape = next_iteration(data.shape, name="dense_shape") return ops.SparseTensor(indices, values, dense_shape) - else: - raise TypeError("Type %s not supported" % type(data)) def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, @@ -183,6 +183,8 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, return enter(data, frame_name, is_constant, parallel_iterations, name=name) else: + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) values = _Enter(data.values, frame_name, is_constant, parallel_iterations, name=name) indices = enter(data.indices, frame_name, is_constant, @@ -193,12 +195,10 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, dense_shape = enter(dense_shape, frame_name, is_constant, parallel_iterations, name="dense_shape") return ops.IndexedSlices(values, indices, dense_shape) - elif isinstance(data, ops.SparseTensor): + else: dense_shape = enter(data.shape, frame_name, is_constant, parallel_iterations, name="dense_shape") return ops.SparseTensor(indices, values, dense_shape) - else: - raise TypeError("Type %s not supported" % type(data)) def exit(data, name=None): @@ -220,6 +220,8 @@ def exit(data, name=None): else: return gen_control_flow_ops._exit(data, name) else: + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) values = exit(data.values, name=name) indices = gen_control_flow_ops._exit(data.indices, name="indices") if isinstance(data, ops.IndexedSlices): @@ -227,11 +229,9 @@ def exit(data, name=None): if dense_shape is not None: dense_shape = gen_control_flow_ops._exit(dense_shape, name) return ops.IndexedSlices(values, indices, dense_shape) - elif isinstance(data, ops.SparseTensor): + else: dense_shape = gen_control_flow_ops._exit(data.shape, name) return ops.SparseTensor(indices, values, dense_shape) - else: - raise TypeError("Type %s not supported" % type(data)) def switch(data, pred, dtype=None, name=None): @@ -1493,8 +1493,15 @@ class WhileContext(ControlFlowContext): self._AddOpInternal(op) def _AddOpInternal(self, op): - """Add `op` to the current context.""" + """Add `op` to the current context. + + In the case that op has only external data inputs, we remove all of its + external control inputs so all its inputs are in the same while loop + context. This is valid because op now has an Enter input that has all + the right control dependency. + """ if not op.inputs: + # Remove any external control dependency on this op control_inputs = [x for x in op.control_inputs if x._get_control_flow_context() == self] if len(control_inputs) != len(op.control_inputs): @@ -1508,12 +1515,22 @@ class WhileContext(ControlFlowContext): for x in op.outputs: self._values.add(x.name) else: + has_internal_data_input = False for index in range(len(op.inputs)): x = op.inputs[index] self.AddValue(x) real_x = self._external_values.get(x.name) if real_x is not None: op._update_input(index, real_x) + else: + has_internal_data_input = True + if not has_internal_data_input: + # Remove any external control dependency on this op + control_inputs = [x for x in op.control_inputs + if x._get_control_flow_context() == self] + if len(control_inputs) != len(op.control_inputs): + del op.control_inputs[:] + op._add_control_inputs(control_inputs) # Add a control dependency to prevent loop invariants from # enabling ops that should not be executed. self._MaybeAddControlDependency(op) @@ -1879,6 +1896,8 @@ class WhileContext(ControlFlowContext): if isinstance(e, ops.Tensor): xs = [e] else: + if not isinstance(e, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(e)) xs = [e.values, e.indices] shape = e.dense_shape if isinstance(e, ops.IndexedSlices) else e.shape if shape is not None: |