aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Yuan Yu <yuanbyu@google.com>2016-05-31 18:52:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-31 20:04:08 -0700
commitb85dca5f2bcf3013d9a74c7a87a88ecdccb29b03 (patch)
tree3e9c89d49ae599d049435fc4e7e73bade771868e /tensorflow/python
parent25a5809660fc9edb6b462ea7d2189edfdd323d0c (diff)
Add IndexedSlices and SparseTensor support for control flow ops.
Change: 123710536
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/framework/ops.py8
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py32
-rw-r--r--tensorflow/python/ops/control_flow_grad.py31
-rw-r--r--tensorflow/python/ops/control_flow_ops.py345
4 files changed, 300 insertions, 116 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 0ee4e8cfbf..257fe8cf97 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -939,8 +939,9 @@ class SparseTensor(object):
@@__init__
@@indices
@@values
- @@dtype
@@shape
+ @@dtype
+ @@op
@@graph
"""
@@ -1004,6 +1005,11 @@ class SparseTensor(object):
return self._values
@property
+ def op(self):
+ """The `Operation` that produces `values` as an output."""
+ return self.values.op
+
+ @property
def dtype(self):
"""The `DType` of elements in this tensor."""
return self._values.dtype
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 74768d429d..6921ab2aa6 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1236,6 +1236,38 @@ class ControlFlowTest(tf.test.TestCase):
self.assertEqual(0, value_x)
self.assertEqual(73, value_x_grad)
+ def testWhileGrad_IndexedSlices(self):
+ with self.test_session():
+ values = tf.constant([2.0, 4.0], name="values")
+ indices = tf.constant([0, 3], name="indices")
+ shape = tf.constant([10], name="dense_shape")
+ i = tf.constant(0)
+ x = tf.IndexedSlices(values, indices, dense_shape=shape)
+ def c(i, _):
+ return i < 10
+ def b(i, x):
+ return [i + 1, tf.IndexedSlices(x.values * 2.0, x.indices,
+ x.dense_shape)]
+ _, r = tf.while_loop(c, b, [i, x])
+ r = tf.gradients(r.values, values)[0]
+ self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+
+ def testWhileGrad_SparseTensor(self):
+ with self.test_session():
+ values = tf.constant([2.0, 4.0], name="values")
+ indices = tf.constant([[0], [3]], dtype=tf.int64, name="indices")
+ shape = tf.constant([10], dtype=tf.int64, name="dense_shape")
+ i = tf.constant(0)
+ x = tf.SparseTensor(indices, values, shape=shape)
+ def c(i, _):
+ return i < 10
+ def b(i, x):
+ return [i + 1, tf.SparseTensor(x.indices, x.values * 2.0,
+ x.shape)]
+ _, r = tf.while_loop(c, b, [i, x])
+ r = tf.gradients(r.values, values)[0]
+ self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+
def testOneValueCond(self):
with self.test_session():
c = tf.placeholder(tf.int32, shape=[])
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index 292e49593d..018de6ae11 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -43,25 +43,24 @@ def _SwitchGrad(op, *grad):
grad_ctxt = graph._get_control_flow_context()
# pylint: enable=protected-access
if isinstance(op_ctxt, WhileContext):
- merge_op = grad_ctxt.grad_state.switch_map.get(op)
- if merge_op:
+ merge_grad = grad_ctxt.grad_state.switch_map.get(op)
+ if merge_grad is not None:
# This is the second time this Switch is visited. It comes from
# the non-exit branch of the Switch, so update the second input
# to the Merge.
# TODO: Perform shape inference with this new input.
# pylint: disable=protected-access
- merge_op._update_input(1, control_flow_ops._NextIteration(grad[1]))
+ control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
# pylint: enable=protected-access
return None, None
else:
- # This is the first time this Switch is visited. It always comes
- # from the Exit branch, which is grad[0]. grad[1] is empty at this point.
+ # This is the first time this Switch is visited. It always comes from
+ # the Exit branch, which is grad[0]. grad[1] is empty at this point.
# Use grad[0] for both inputs to merge for now, but update the second
# input of merge when we see this Switch the second time.
- merge_fn = control_flow_ops._Merge # pylint: disable=protected-access
- merge_op = merge_fn([grad[0], grad[0]], name="b_switch")[0]
- grad_ctxt.grad_state.switch_map[op] = merge_op.op
- return merge_op, None
+ merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
+ grad_ctxt.grad_state.switch_map[op] = merge_grad
+ return merge_grad, None
elif isinstance(op_ctxt, CondContext):
good_grad = grad[op_ctxt.branch]
zero_grad = grad[1 - op_ctxt.branch]
@@ -140,7 +139,19 @@ def _ExitGrad(_, grad):
# computation for this loop. If the attribute `back_prop` is false,
# no gradient computation.
return None
- grad_ctxt.AddName(grad.name)
+ if isinstance(grad, ops.Tensor):
+ grad_ctxt.AddName(grad.name)
+ else:
+ if not isinstance(grad, (ops.IndexedSlices, ops.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(grad))
+ grad_ctxt.AddName(grad.values.name)
+ grad_ctxt.AddName(grad.indices.name)
+ if isinstance(grad, ops.IndexedSlices):
+ dense_shape = grad.dense_shape
+ else:
+ dense_shape = grad.shape
+ if dense_shape is not None:
+ grad_ctxt.AddName(dense_shape.name)
enter_fn = control_flow_ops._Enter # pylint: disable=protected-access
grad_ctxt.Enter()
result = enter_fn(grad, grad_ctxt.name, is_constant=False,
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index ba2d31b024..e00fec07c1 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -110,24 +110,47 @@ def _Identity(data, name=None):
Returns:
A Tensor with the same type and value as the input Tensor.
"""
- if not data.dtype.is_ref_dtype:
- return array_ops.identity(data, name=name)
+ data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
+ if isinstance(data, ops.Tensor):
+ if data.dtype.is_ref_dtype:
+ return gen_array_ops._ref_identity(data, name=name)
+ else:
+ return array_ops.identity(data, name=name)
else:
- return gen_array_ops._ref_identity(data, name=name)
+ values = _Identity(data.values, name=name)
+ indices = array_ops.identity(data.indices, name="indices")
+ if isinstance(data, ops.IndexedSlices):
+ dense_shape = data.dense_shape
+ if dense_shape is not None:
+ dense_shape = array_ops.identity(dense_shape, name="dense_shape")
+ return ops.IndexedSlices(values, indices, dense_shape)
+ elif isinstance(data, ops.SparseTensor):
+ dense_shape = array_ops.identity(data.shape, name="dense_shape")
+ return ops.SparseTensor(indices, values, dense_shape)
+ else:
+ raise TypeError("Type %s not supported" % type(data))
def _NextIteration(data, name=None):
- if not data.dtype.is_ref_dtype:
- return next_iteration(data, name=name)
- else:
- return ref_next_iteration(data, name=name)
-
-
-def _Merge(values, name=None):
- if all([v.dtype.is_ref_dtype for v in values]):
- return gen_control_flow_ops._ref_merge(values, name)
+ data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
+ if isinstance(data, ops.Tensor):
+ if data.dtype.is_ref_dtype:
+ return ref_next_iteration(data, name=name)
+ else:
+ return next_iteration(data, name=name)
else:
- return gen_control_flow_ops._merge(values, name)
+ values = _NextIteration(data.values, name=name)
+ indices = next_iteration(data.indices, name="indices")
+ if isinstance(data, ops.IndexedSlices):
+ dense_shape = data.dense_shape
+ if dense_shape is not None:
+ dense_shape = next_iteration(dense_shape, name="dense_shape")
+ return ops.IndexedSlices(values, indices, dense_shape)
+ elif isinstance(data, ops.SparseTensor):
+ dense_shape = next_iteration(data.shape, name="dense_shape")
+ return ops.SparseTensor(indices, values, dense_shape)
+ else:
+ raise TypeError("Type %s not supported" % type(data))
def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
@@ -150,12 +173,31 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
Returns:
The same tensor as `data`.
"""
- if data.dtype.is_ref_dtype and use_ref:
- return ref_enter(data, frame_name, is_constant, parallel_iterations,
- name=name)
+ data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
+ if isinstance(data, ops.Tensor):
+ if data.dtype.is_ref_dtype and use_ref:
+ return ref_enter(data, frame_name, is_constant, parallel_iterations,
+ name=name)
+ else:
+ return enter(data, frame_name, is_constant, parallel_iterations,
+ name=name)
else:
- return enter(data, frame_name, is_constant, parallel_iterations,
- name=name)
+ values = _Enter(data.values, frame_name, is_constant,
+ parallel_iterations, name=name)
+ indices = enter(data.indices, frame_name, is_constant,
+ parallel_iterations, name="indices")
+ if isinstance(data, ops.IndexedSlices):
+ dense_shape = data.dense_shape
+ if dense_shape is not None:
+ dense_shape = enter(dense_shape, frame_name, is_constant,
+ parallel_iterations, name="dense_shape")
+ return ops.IndexedSlices(values, indices, dense_shape)
+ elif isinstance(data, ops.SparseTensor):
+ dense_shape = enter(data.shape, frame_name, is_constant,
+ parallel_iterations, name="dense_shape")
+ return ops.SparseTensor(indices, values, dense_shape)
+ else:
+ raise TypeError("Type %s not supported" % type(data))
def exit(data, name=None):
@@ -170,10 +212,25 @@ def exit(data, name=None):
Returns:
The same tensor as `data`.
"""
- if data.dtype.is_ref_dtype:
- return gen_control_flow_ops._ref_exit(data, name)
+ data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
+ if isinstance(data, ops.Tensor):
+ if data.dtype.is_ref_dtype:
+ return gen_control_flow_ops._ref_exit(data, name)
+ else:
+ return gen_control_flow_ops._exit(data, name)
else:
- return gen_control_flow_ops._exit(data, name)
+ values = exit(data.values, name=name)
+ indices = gen_control_flow_ops._exit(data.indices, name="indices")
+ if isinstance(data, ops.IndexedSlices):
+ dense_shape = data.dense_shape
+ if dense_shape is not None:
+ dense_shape = gen_control_flow_ops._exit(dense_shape, name)
+ return ops.IndexedSlices(values, indices, dense_shape)
+ elif isinstance(data, ops.SparseTensor):
+ dense_shape = gen_control_flow_ops._exit(data.shape, name)
+ return ops.SparseTensor(indices, values, dense_shape)
+ else:
+ raise TypeError("Type %s not supported" % type(data))
def switch(data, pred, dtype=None, name=None):
@@ -192,73 +249,36 @@ def switch(data, pred, dtype=None, name=None):
name: A name for this operation (optional).
Returns:
- `(output_false, output_true)`: If `pred` is true, data will be forwarded to
- `output_true`, otherwise it goes to `output_false`.
+ `(output_false, output_true)`: If `pred` is true, data will be forwarded
+ to `output_true`, otherwise it goes to `output_false`.
"""
with ops.op_scope([data, pred], name, "Switch") as name:
data = ops.convert_to_tensor_or_indexed_slices(data, dtype=dtype,
- name="data")
+ name="data", as_ref=True)
pred = ops.convert_to_tensor(pred, name="pred")
if isinstance(data, ops.Tensor):
return gen_control_flow_ops._switch(data, pred, name=name)
else:
- val, ind, dense_shape = data.values, data.indices, data.dense_shape
+ if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(data))
+ val, ind = data.values, data.indices
val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)
ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")
- if dense_shape is not None:
- dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
- dense_shape, pred, name="dense_shape")
- else:
- dense_shape_f, dense_shape_t = None, None
- return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
- ops.IndexedSlices(val_t, ind_t, dense_shape_t))
-
-
-def merge(inputs, name=None):
- """Returns the value of an available element of `inputs`.
-
- This op tests each of the tensors in `inputs` in turn to determine if any of
- them is available. If it finds an available tensor, it returns it and its
- index in `inputs`.
-
- It is an error if more than one tensor in `inputs` is available. If no tensor
- in `inputs` is available, the returned tensor and index are not set.
-
- This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
- `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
- before merging.
-
- Args:
- inputs: The input tensors, at most one of which is available.
- name: A name for this operation (optional).
-
- Returns:
- A tuple containing the chosen input tensor and its index in `inputs`.
-
- Raises:
- ValueError: If inputs are IndexedSlices and some but not all have a
- dense_shape property.
- """
- with ops.op_scope(inputs, name, "Merge") as name:
- inputs = [ops.convert_to_tensor_or_indexed_slices(inp)
- for inp in inputs]
- if all([isinstance(inp, ops.Tensor) for inp in inputs]):
- return _Merge(inputs, name=name)
- else:
- inputs = math_ops._as_indexed_slices_list(inputs)
- values, _ = _Merge([inp.values for inp in inputs], name=name)
- indices, chosen_index = _Merge(
- [inp.indices for inp in inputs], name="indices")
- if any(inp.dense_shape is not None for inp in inputs):
- if any(inp.dense_shape is None for inp in inputs):
- raise ValueError("Either all merged IndexedSlices must have a "
- "dense_shape, or none must have a dense_shape.")
- dense_shape, _ = _Merge(
- [inp.dense_shape for inp in inputs], name="dense_shape")
+ if isinstance(data, ops.IndexedSlices):
+ dense_shape = data.dense_shape
+ if dense_shape is not None:
+ dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
+ dense_shape, pred, name="dense_shape")
+ else:
+ dense_shape_f, dense_shape_t = None, None
+ return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
+ ops.IndexedSlices(val_t, ind_t, dense_shape_t))
else:
- dense_shape = None
- return ops.IndexedSlices(values, indices, dense_shape), chosen_index
-# pylint: enable=protected-access
+ dense_shape = data.shape
+ dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
+ data.shape, pred, name="dense_shape")
+ return (ops.SparseTensor(ind_f, val_f, dense_shape_f),
+ ops.SparseTensor(ind_t, val_t, dense_shape_t))
def _SwitchRefOrTensor(data, pred, name="Switch"):
@@ -300,12 +320,68 @@ def _SwitchRefOrTensor(data, pred, name="Switch"):
# created within ops.colocate_with(data) to ignore the existing stack.
with ops.colocate_with(data, ignore_existing=True):
if isinstance(data, ops.Tensor):
- if not data.dtype.is_ref_dtype:
- return switch(data, pred, name=name)
- else:
+ if data.dtype.is_ref_dtype:
return ref_switch(data, pred, name=name)
+ return switch(data, pred, name=name)
+
+
+def merge(inputs, name=None):
+ """Returns the value of an available element of `inputs`.
+
+ This op tests each of the tensors in `inputs` in turn to determine if any of
+ them is available. If it finds an available tensor, it returns it and its
+ index in `inputs`.
+
+ It is an error if more than one tensor in `inputs` is available. If no tensor
+ in `inputs` is available, the returned tensor and index are not set.
+
+ This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
+ `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
+ before merging.
+
+ Args:
+ inputs: The input tensors, at most one of which is available.
+ name: A name for this operation (optional).
+
+ Returns:
+ A tuple containing the chosen input tensor and its index in `inputs`.
+
+ Raises:
+ ValueError: If inputs are IndexedSlices and some but not all have a
+ dense_shape property.
+ """
+ with ops.op_scope(inputs, name, "Merge") as name:
+ inputs = [ops.convert_to_tensor_or_indexed_slices(inp, as_ref=True)
+ for inp in inputs]
+ if all([isinstance(v, ops.Tensor) for v in inputs]):
+ if all([v.dtype.is_ref_dtype for v in inputs]):
+ return gen_control_flow_ops._ref_merge(inputs, name)
+ else:
+ return gen_control_flow_ops._merge(inputs, name)
+ elif all([isinstance(v, ops.SparseTensor) for v in inputs]):
+ # Only handle the case when all inputs are SparseTensor.
+ values, _ = merge([inp.values for inp in inputs], name=name)
+ indices, chosen_index = gen_control_flow_ops._merge(
+ [inp.indices for inp in inputs], name="indices")
+ dense_shape, _ = gen_control_flow_ops._merge(
+ [inp.shape for inp in inputs], name="dense_shape")
+ return ops.SparseTensor(indices, values, dense_shape), chosen_index
else:
- return switch(data, pred, name=name)
+ # For now convert all the inputs as IndexedSlices.
+ inputs = math_ops._as_indexed_slices_list(inputs)
+ values, _ = merge([inp.values for inp in inputs], name=name)
+ indices, chosen_index = gen_control_flow_ops._merge(
+ [inp.indices for inp in inputs], name="indices")
+ if any(inp.dense_shape is not None for inp in inputs):
+ if any(inp.dense_shape is None for inp in inputs):
+ raise ValueError("Either all merged IndexedSlices must have a "
+ "dense_shape, or none must have a dense_shape.")
+ dense_shape, _ = gen_control_flow_ops._merge(
+ [inp.dense_shape for inp in inputs], name="dense_shape")
+ else:
+ dense_shape = None
+ return ops.IndexedSlices(values, indices, dense_shape), chosen_index
+# pylint: enable=protected-access
def _convert_tensorarrays_to_flows(tensors_or_tensor_arrays):
@@ -365,6 +441,37 @@ def _ShapeIntersection(shape1, shape2):
return tensor_shape.TensorShape(rdims)
+def _AddNextAndBackEdge(m, v):
+ """Add NextIteration and back edge from v to m."""
+ if isinstance(m, ops.Tensor):
+ v = ops.convert_to_tensor(v)
+ v = _NextIteration(v)
+ m.op._update_input(1, v) # pylint: disable=protected-access
+ elif isinstance(m, ops.IndexedSlices):
+ # pylint: disable=protected-access
+ v = math_ops._as_indexed_slices(v)
+ v = _NextIteration(v)
+ m.values.op._update_input(1, v.values)
+ m.indices.op._update_input(1, v.indices)
+ # pylint: enable=protected-access
+ if m.dense_shape is not None:
+ if v.dense_shape is None:
+ raise ValueError("Must have dense shape: %s" % v.name)
+ m.dense_shape.op._update_input(1, v.dense_shape)
+ elif isinstance(m, ops.SparseTensor):
+ if not isinstance(v, ops.SparseTensor):
+ raise ValueError("Must be a sparse tensor: %s" % v.name)
+ v = _NextIteration(v)
+ # pylint: disable=protected-access
+ m.values.op._update_input(1, v.values)
+ m.indices.op._update_input(1, v.indices)
+ m.shape.op._update_input(1, v.shape)
+ # pylint: enable=protected-access
+ else:
+ raise TypeError("Type %s not supported" % type(m))
+ return v
+
+
class GradLoopState(object):
"""The state used for constructing the gradient graph for a while loop.
@@ -877,12 +984,12 @@ class ControlFlowState(object):
"""
for _, grad_state in self._map.items():
for _, b_merge in grad_state.switch_map.items():
- if b_merge.inputs[0] == b_merge.inputs[1]:
+ if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
# The value of this loop variable at iteration i+1 doesn't
# depend on its value at iteration i. So use zeros as the
# gradients for all iterations > 0.
- dtype = b_merge.inputs[0].dtype
- shape = b_merge.inputs[0].get_shape()
+ dtype = b_merge.op.inputs[0].dtype
+ shape = b_merge.op.inputs[0].get_shape()
if not shape.is_fully_defined():
shape = None
grad_state.grad_context.Enter()
@@ -891,8 +998,8 @@ class ControlFlowState(object):
grad_state.grad_context.Exit()
# pylint: disable=protected-access
if not shape:
- grad_val._shape = b_merge.inputs[0].get_shape()
- b_merge._update_input(1, grad_val)
+ grad_val._shape = b_merge.op.inputs[0].get_shape()
+ b_merge.op._update_input(1, grad_val)
# pylint: enable=protected-access
@@ -1578,6 +1685,23 @@ class WhileContext(ControlFlowContext):
return ops.IndexedSlices(values=acc_result[1], indices=acc_result[0],
dense_shape=self.ExitResult(value.dense_shape))
+ def _InitializeValues(self, values):
+ self._values = set()
+ for x in values:
+ if isinstance(x, ops.Tensor):
+ self._values.add(x.name)
+ else:
+ self._values.add(x.values.name)
+ self._values.add(x.indices.name)
+ if isinstance(x, ops.IndexedSlices):
+ dense_shape = x.dense_shape
+ elif isinstance(x, ops.SparseTensor):
+ dense_shape = x.shape
+ else:
+ raise TypeError("Type %s not supported" % type(x))
+ if dense_shape is not None:
+ self._values.add(dense_shape.name)
+
def BuildLoop(self, pred, body, loop_vars):
"""Add the loop termination condition and body to the graph."""
@@ -1588,7 +1712,7 @@ class WhileContext(ControlFlowContext):
loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
# Let the context know the loop variabes so the loop variables
# would be added in the outer contexts properly.
- self._values = set([x.name for x in loop_vars])
+ self._InitializeValues(loop_vars)
real_vars = loop_vars
if self._outer_context:
real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
@@ -1598,7 +1722,7 @@ class WhileContext(ControlFlowContext):
for x in real_vars]
# Fix the control inputs and control flow context of these enter ops.
self._FixControlInputsAndContext(enter_vars)
- self._values = set([x.name for x in enter_vars])
+ self._InitializeValues(enter_vars)
merge_vars = [merge([x, x])[0] for x in enter_vars]
self._pivot_for_pred = merge_vars[0]
@@ -1626,21 +1750,23 @@ class WhileContext(ControlFlowContext):
# Convert TensorArrays returned by body into their flow variables
result = _convert_tensorarrays_to_flows(body_result)
result = ops.convert_n_to_tensor_or_indexed_slices(result)
- next_vars = [_NextIteration(x) for x in result]
- # Add the back edges to complete the loop.
- if len(merge_vars) != len(next_vars):
+ # Add NextIteration and the back edges to complete the loop.
+ if len(merge_vars) != len(result):
raise ValueError("Number of inputs and outputs of body must match "
- "loop_vars: %d, %d" % (len(merge_vars), len(next_vars)))
- for x in zip(merge_vars, next_vars):
- x[0].op._update_input(1, x[1])
+ "loop_vars: %d, %d" % (len(merge_vars), len(result)))
+ next_vars = []
+ for m, v in zip(merge_vars, result):
+ next_vars.append(_AddNextAndBackEdge(m, v))
# Add the exit ops.
exit_vars = [exit(x[0]) for x in switch_vars]
self._loop_exits = exit_vars
+ # Make sure the shapes of loop outputs are correct.
for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars):
- e_var._shape = _ShapeIntersection(m_var.get_shape(), n_var.get_shape())
+ if isinstance(m_var, ops.Tensor):
+ e_var._shape = _ShapeIntersection(m_var.get_shape(), n_var.get_shape())
# Exit the loop.
self.ExitResult(exit_vars)
@@ -1653,16 +1779,25 @@ class WhileContext(ControlFlowContext):
if len(exit_vars) == 1
else exit_vars_with_tensor_arrays)
- def _FixControlInputsAndContext(self, input_tensors):
- # pylint: disable=protected-access
+ def _FixControlInputsAndContext(self, enters):
graph = ops.get_default_graph()
- control_inputs = graph._control_dependencies_for_inputs(input_tensors)
- control_inputs = [op for op in control_inputs
- if op._get_control_flow_context() != self]
- for x in input_tensors:
- x.op._set_control_flow_context(self)
- x.op._add_control_inputs(control_inputs)
- graph._record_op_seen_by_control_dependencies(x.op)
+ # pylint: disable=protected-access
+ for e in enters:
+ if isinstance(e, ops.Tensor):
+ xs = [e]
+ else:
+ xs = [e.values, e.indices]
+ shape = e.dense_shape if isinstance(e, ops.IndexedSlices) else e.shape
+ if shape is not None:
+ xs.append(shape)
+ for x in xs:
+ inp_op = x.op.inputs[0]
+ control_inputs = graph._control_dependencies_for_inputs([inp_op])
+ control_inputs = [op for op in control_inputs
+ if op._get_control_flow_context() != self]
+ x.op._set_control_flow_context(self)
+ x.op._add_control_inputs(control_inputs)
+ graph._record_op_seen_by_control_dependencies(x.op)
# pylint: enable=protected-access