diff options
author | Yuan Yu <yuanbyu@google.com> | 2016-05-31 18:52:22 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-05-31 20:04:08 -0700 |
commit | b85dca5f2bcf3013d9a74c7a87a88ecdccb29b03 (patch) | |
tree | 3e9c89d49ae599d049435fc4e7e73bade771868e /tensorflow | |
parent | 25a5809660fc9edb6b462ea7d2189edfdd323d0c (diff) |
Add IndexedSlices and SparseTensor support for control flow ops.
Change: 123710536
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/framework/ops.py | 8 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 32 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_grad.py | 31 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 345 |
4 files changed, 300 insertions, 116 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 0ee4e8cfbf..257fe8cf97 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -939,8 +939,9 @@ class SparseTensor(object): @@__init__ @@indices @@values - @@dtype @@shape + @@dtype + @@op @@graph """ @@ -1004,6 +1005,11 @@ class SparseTensor(object): return self._values @property + def op(self): + """The `Operation` that produces `values` as an output.""" + return self.values.op + + @property def dtype(self): """The `DType` of elements in this tensor.""" return self._values.dtype diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 74768d429d..6921ab2aa6 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1236,6 +1236,38 @@ class ControlFlowTest(tf.test.TestCase): self.assertEqual(0, value_x) self.assertEqual(73, value_x_grad) + def testWhileGrad_IndexedSlices(self): + with self.test_session(): + values = tf.constant([2.0, 4.0], name="values") + indices = tf.constant([0, 3], name="indices") + shape = tf.constant([10], name="dense_shape") + i = tf.constant(0) + x = tf.IndexedSlices(values, indices, dense_shape=shape) + def c(i, _): + return i < 10 + def b(i, x): + return [i + 1, tf.IndexedSlices(x.values * 2.0, x.indices, + x.dense_shape)] + _, r = tf.while_loop(c, b, [i, x]) + r = tf.gradients(r.values, values)[0] + self.assertAllClose(np.array([1024.0, 1024.0]), r.eval()) + + def testWhileGrad_SparseTensor(self): + with self.test_session(): + values = tf.constant([2.0, 4.0], name="values") + indices = tf.constant([[0], [3]], dtype=tf.int64, name="indices") + shape = tf.constant([10], dtype=tf.int64, name="dense_shape") + i = tf.constant(0) + x = tf.SparseTensor(indices, values, shape=shape) + def c(i, _): + return i < 10 + def b(i, x): + return [i + 1, tf.SparseTensor(x.indices, x.values * 2.0, + x.shape)] + _, r = tf.while_loop(c, b, [i, x]) + r = tf.gradients(r.values, values)[0] + self.assertAllClose(np.array([1024.0, 1024.0]), r.eval()) + def testOneValueCond(self): with self.test_session(): c = tf.placeholder(tf.int32, shape=[]) diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py index 292e49593d..018de6ae11 100644 --- a/tensorflow/python/ops/control_flow_grad.py +++ b/tensorflow/python/ops/control_flow_grad.py @@ -43,25 +43,24 @@ def _SwitchGrad(op, *grad): grad_ctxt = graph._get_control_flow_context() # pylint: enable=protected-access if isinstance(op_ctxt, WhileContext): - merge_op = grad_ctxt.grad_state.switch_map.get(op) - if merge_op: + merge_grad = grad_ctxt.grad_state.switch_map.get(op) + if merge_grad is not None: # This is the second time this Switch is visited. It comes from # the non-exit branch of the Switch, so update the second input # to the Merge. # TODO: Perform shape inference with this new input. # pylint: disable=protected-access - merge_op._update_input(1, control_flow_ops._NextIteration(grad[1])) + control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1]) # pylint: enable=protected-access return None, None else: - # This is the first time this Switch is visited. It always comes - # from the Exit branch, which is grad[0]. grad[1] is empty at this point. + # This is the first time this Switch is visited. It always comes from + # the Exit branch, which is grad[0]. grad[1] is empty at this point. # Use grad[0] for both inputs to merge for now, but update the second # input of merge when we see this Switch the second time. - merge_fn = control_flow_ops._Merge # pylint: disable=protected-access - merge_op = merge_fn([grad[0], grad[0]], name="b_switch")[0] - grad_ctxt.grad_state.switch_map[op] = merge_op.op - return merge_op, None + merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] + grad_ctxt.grad_state.switch_map[op] = merge_grad + return merge_grad, None elif isinstance(op_ctxt, CondContext): good_grad = grad[op_ctxt.branch] zero_grad = grad[1 - op_ctxt.branch] @@ -140,7 +139,19 @@ def _ExitGrad(_, grad): # computation for this loop. If the attribute `back_prop` is false, # no gradient computation. return None - grad_ctxt.AddName(grad.name) + if isinstance(grad, ops.Tensor): + grad_ctxt.AddName(grad.name) + else: + if not isinstance(grad, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(grad)) + grad_ctxt.AddName(grad.values.name) + grad_ctxt.AddName(grad.indices.name) + if isinstance(grad, ops.IndexedSlices): + dense_shape = grad.dense_shape + else: + dense_shape = grad.shape + if dense_shape is not None: + grad_ctxt.AddName(dense_shape.name) enter_fn = control_flow_ops._Enter # pylint: disable=protected-access grad_ctxt.Enter() result = enter_fn(grad, grad_ctxt.name, is_constant=False, diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index ba2d31b024..e00fec07c1 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -110,24 +110,47 @@ def _Identity(data, name=None): Returns: A Tensor with the same type and value as the input Tensor. """ - if not data.dtype.is_ref_dtype: - return array_ops.identity(data, name=name) + data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True) + if isinstance(data, ops.Tensor): + if data.dtype.is_ref_dtype: + return gen_array_ops._ref_identity(data, name=name) + else: + return array_ops.identity(data, name=name) else: - return gen_array_ops._ref_identity(data, name=name) + values = _Identity(data.values, name=name) + indices = array_ops.identity(data.indices, name="indices") + if isinstance(data, ops.IndexedSlices): + dense_shape = data.dense_shape + if dense_shape is not None: + dense_shape = array_ops.identity(dense_shape, name="dense_shape") + return ops.IndexedSlices(values, indices, dense_shape) + elif isinstance(data, ops.SparseTensor): + dense_shape = array_ops.identity(data.shape, name="dense_shape") + return ops.SparseTensor(indices, values, dense_shape) + else: + raise TypeError("Type %s not supported" % type(data)) def _NextIteration(data, name=None): - if not data.dtype.is_ref_dtype: - return next_iteration(data, name=name) - else: - return ref_next_iteration(data, name=name) - - -def _Merge(values, name=None): - if all([v.dtype.is_ref_dtype for v in values]): - return gen_control_flow_ops._ref_merge(values, name) + data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True) + if isinstance(data, ops.Tensor): + if data.dtype.is_ref_dtype: + return ref_next_iteration(data, name=name) + else: + return next_iteration(data, name=name) else: - return gen_control_flow_ops._merge(values, name) + values = _NextIteration(data.values, name=name) + indices = next_iteration(data.indices, name="indices") + if isinstance(data, ops.IndexedSlices): + dense_shape = data.dense_shape + if dense_shape is not None: + dense_shape = next_iteration(dense_shape, name="dense_shape") + return ops.IndexedSlices(values, indices, dense_shape) + elif isinstance(data, ops.SparseTensor): + dense_shape = next_iteration(data.shape, name="dense_shape") + return ops.SparseTensor(indices, values, dense_shape) + else: + raise TypeError("Type %s not supported" % type(data)) def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, @@ -150,12 +173,31 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, Returns: The same tensor as `data`. """ - if data.dtype.is_ref_dtype and use_ref: - return ref_enter(data, frame_name, is_constant, parallel_iterations, - name=name) + data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True) + if isinstance(data, ops.Tensor): + if data.dtype.is_ref_dtype and use_ref: + return ref_enter(data, frame_name, is_constant, parallel_iterations, + name=name) + else: + return enter(data, frame_name, is_constant, parallel_iterations, + name=name) else: - return enter(data, frame_name, is_constant, parallel_iterations, - name=name) + values = _Enter(data.values, frame_name, is_constant, + parallel_iterations, name=name) + indices = enter(data.indices, frame_name, is_constant, + parallel_iterations, name="indices") + if isinstance(data, ops.IndexedSlices): + dense_shape = data.dense_shape + if dense_shape is not None: + dense_shape = enter(dense_shape, frame_name, is_constant, + parallel_iterations, name="dense_shape") + return ops.IndexedSlices(values, indices, dense_shape) + elif isinstance(data, ops.SparseTensor): + dense_shape = enter(data.shape, frame_name, is_constant, + parallel_iterations, name="dense_shape") + return ops.SparseTensor(indices, values, dense_shape) + else: + raise TypeError("Type %s not supported" % type(data)) def exit(data, name=None): @@ -170,10 +212,25 @@ def exit(data, name=None): Returns: The same tensor as `data`. """ - if data.dtype.is_ref_dtype: - return gen_control_flow_ops._ref_exit(data, name) + data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True) + if isinstance(data, ops.Tensor): + if data.dtype.is_ref_dtype: + return gen_control_flow_ops._ref_exit(data, name) + else: + return gen_control_flow_ops._exit(data, name) else: - return gen_control_flow_ops._exit(data, name) + values = exit(data.values, name=name) + indices = gen_control_flow_ops._exit(data.indices, name="indices") + if isinstance(data, ops.IndexedSlices): + dense_shape = data.dense_shape + if dense_shape is not None: + dense_shape = gen_control_flow_ops._exit(dense_shape, name) + return ops.IndexedSlices(values, indices, dense_shape) + elif isinstance(data, ops.SparseTensor): + dense_shape = gen_control_flow_ops._exit(data.shape, name) + return ops.SparseTensor(indices, values, dense_shape) + else: + raise TypeError("Type %s not supported" % type(data)) def switch(data, pred, dtype=None, name=None): @@ -192,73 +249,36 @@ def switch(data, pred, dtype=None, name=None): name: A name for this operation (optional). Returns: - `(output_false, output_true)`: If `pred` is true, data will be forwarded to - `output_true`, otherwise it goes to `output_false`. + `(output_false, output_true)`: If `pred` is true, data will be forwarded + to `output_true`, otherwise it goes to `output_false`. """ with ops.op_scope([data, pred], name, "Switch") as name: data = ops.convert_to_tensor_or_indexed_slices(data, dtype=dtype, - name="data") + name="data", as_ref=True) pred = ops.convert_to_tensor(pred, name="pred") if isinstance(data, ops.Tensor): return gen_control_flow_ops._switch(data, pred, name=name) else: - val, ind, dense_shape = data.values, data.indices, data.dense_shape + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) + val, ind = data.values, data.indices val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name) ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices") - if dense_shape is not None: - dense_shape_f, dense_shape_t = gen_control_flow_ops._switch( - dense_shape, pred, name="dense_shape") - else: - dense_shape_f, dense_shape_t = None, None - return (ops.IndexedSlices(val_f, ind_f, dense_shape_f), - ops.IndexedSlices(val_t, ind_t, dense_shape_t)) - - -def merge(inputs, name=None): - """Returns the value of an available element of `inputs`. - - This op tests each of the tensors in `inputs` in turn to determine if any of - them is available. If it finds an available tensor, it returns it and its - index in `inputs`. - - It is an error if more than one tensor in `inputs` is available. If no tensor - in `inputs` is available, the returned tensor and index are not set. - - This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of - `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices - before merging. - - Args: - inputs: The input tensors, at most one of which is available. - name: A name for this operation (optional). - - Returns: - A tuple containing the chosen input tensor and its index in `inputs`. - - Raises: - ValueError: If inputs are IndexedSlices and some but not all have a - dense_shape property. - """ - with ops.op_scope(inputs, name, "Merge") as name: - inputs = [ops.convert_to_tensor_or_indexed_slices(inp) - for inp in inputs] - if all([isinstance(inp, ops.Tensor) for inp in inputs]): - return _Merge(inputs, name=name) - else: - inputs = math_ops._as_indexed_slices_list(inputs) - values, _ = _Merge([inp.values for inp in inputs], name=name) - indices, chosen_index = _Merge( - [inp.indices for inp in inputs], name="indices") - if any(inp.dense_shape is not None for inp in inputs): - if any(inp.dense_shape is None for inp in inputs): - raise ValueError("Either all merged IndexedSlices must have a " - "dense_shape, or none must have a dense_shape.") - dense_shape, _ = _Merge( - [inp.dense_shape for inp in inputs], name="dense_shape") + if isinstance(data, ops.IndexedSlices): + dense_shape = data.dense_shape + if dense_shape is not None: + dense_shape_f, dense_shape_t = gen_control_flow_ops._switch( + dense_shape, pred, name="dense_shape") + else: + dense_shape_f, dense_shape_t = None, None + return (ops.IndexedSlices(val_f, ind_f, dense_shape_f), + ops.IndexedSlices(val_t, ind_t, dense_shape_t)) else: - dense_shape = None - return ops.IndexedSlices(values, indices, dense_shape), chosen_index -# pylint: enable=protected-access + dense_shape = data.shape + dense_shape_f, dense_shape_t = gen_control_flow_ops._switch( + data.shape, pred, name="dense_shape") + return (ops.SparseTensor(ind_f, val_f, dense_shape_f), + ops.SparseTensor(ind_t, val_t, dense_shape_t)) def _SwitchRefOrTensor(data, pred, name="Switch"): @@ -300,12 +320,68 @@ def _SwitchRefOrTensor(data, pred, name="Switch"): # created within ops.colocate_with(data) to ignore the existing stack. with ops.colocate_with(data, ignore_existing=True): if isinstance(data, ops.Tensor): - if not data.dtype.is_ref_dtype: - return switch(data, pred, name=name) - else: + if data.dtype.is_ref_dtype: return ref_switch(data, pred, name=name) + return switch(data, pred, name=name) + + +def merge(inputs, name=None): + """Returns the value of an available element of `inputs`. + + This op tests each of the tensors in `inputs` in turn to determine if any of + them is available. If it finds an available tensor, it returns it and its + index in `inputs`. + + It is an error if more than one tensor in `inputs` is available. If no tensor + in `inputs` is available, the returned tensor and index are not set. + + This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of + `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices + before merging. + + Args: + inputs: The input tensors, at most one of which is available. + name: A name for this operation (optional). + + Returns: + A tuple containing the chosen input tensor and its index in `inputs`. + + Raises: + ValueError: If inputs are IndexedSlices and some but not all have a + dense_shape property. + """ + with ops.op_scope(inputs, name, "Merge") as name: + inputs = [ops.convert_to_tensor_or_indexed_slices(inp, as_ref=True) + for inp in inputs] + if all([isinstance(v, ops.Tensor) for v in inputs]): + if all([v.dtype.is_ref_dtype for v in inputs]): + return gen_control_flow_ops._ref_merge(inputs, name) + else: + return gen_control_flow_ops._merge(inputs, name) + elif all([isinstance(v, ops.SparseTensor) for v in inputs]): + # Only handle the case when all inputs are SparseTensor. + values, _ = merge([inp.values for inp in inputs], name=name) + indices, chosen_index = gen_control_flow_ops._merge( + [inp.indices for inp in inputs], name="indices") + dense_shape, _ = gen_control_flow_ops._merge( + [inp.shape for inp in inputs], name="dense_shape") + return ops.SparseTensor(indices, values, dense_shape), chosen_index else: - return switch(data, pred, name=name) + # For now convert all the inputs as IndexedSlices. + inputs = math_ops._as_indexed_slices_list(inputs) + values, _ = merge([inp.values for inp in inputs], name=name) + indices, chosen_index = gen_control_flow_ops._merge( + [inp.indices for inp in inputs], name="indices") + if any(inp.dense_shape is not None for inp in inputs): + if any(inp.dense_shape is None for inp in inputs): + raise ValueError("Either all merged IndexedSlices must have a " + "dense_shape, or none must have a dense_shape.") + dense_shape, _ = gen_control_flow_ops._merge( + [inp.dense_shape for inp in inputs], name="dense_shape") + else: + dense_shape = None + return ops.IndexedSlices(values, indices, dense_shape), chosen_index +# pylint: enable=protected-access def _convert_tensorarrays_to_flows(tensors_or_tensor_arrays): @@ -365,6 +441,37 @@ def _ShapeIntersection(shape1, shape2): return tensor_shape.TensorShape(rdims) +def _AddNextAndBackEdge(m, v): + """Add NextIteration and back edge from v to m.""" + if isinstance(m, ops.Tensor): + v = ops.convert_to_tensor(v) + v = _NextIteration(v) + m.op._update_input(1, v) # pylint: disable=protected-access + elif isinstance(m, ops.IndexedSlices): + # pylint: disable=protected-access + v = math_ops._as_indexed_slices(v) + v = _NextIteration(v) + m.values.op._update_input(1, v.values) + m.indices.op._update_input(1, v.indices) + # pylint: enable=protected-access + if m.dense_shape is not None: + if v.dense_shape is None: + raise ValueError("Must have dense shape: %s" % v.name) + m.dense_shape.op._update_input(1, v.dense_shape) + elif isinstance(m, ops.SparseTensor): + if not isinstance(v, ops.SparseTensor): + raise ValueError("Must be a sparse tensor: %s" % v.name) + v = _NextIteration(v) + # pylint: disable=protected-access + m.values.op._update_input(1, v.values) + m.indices.op._update_input(1, v.indices) + m.shape.op._update_input(1, v.shape) + # pylint: enable=protected-access + else: + raise TypeError("Type %s not supported" % type(m)) + return v + + class GradLoopState(object): """The state used for constructing the gradient graph for a while loop. @@ -877,12 +984,12 @@ class ControlFlowState(object): """ for _, grad_state in self._map.items(): for _, b_merge in grad_state.switch_map.items(): - if b_merge.inputs[0] == b_merge.inputs[1]: + if b_merge.op.inputs[0] == b_merge.op.inputs[1]: # The value of this loop variable at iteration i+1 doesn't # depend on its value at iteration i. So use zeros as the # gradients for all iterations > 0. - dtype = b_merge.inputs[0].dtype - shape = b_merge.inputs[0].get_shape() + dtype = b_merge.op.inputs[0].dtype + shape = b_merge.op.inputs[0].get_shape() if not shape.is_fully_defined(): shape = None grad_state.grad_context.Enter() @@ -891,8 +998,8 @@ class ControlFlowState(object): grad_state.grad_context.Exit() # pylint: disable=protected-access if not shape: - grad_val._shape = b_merge.inputs[0].get_shape() - b_merge._update_input(1, grad_val) + grad_val._shape = b_merge.op.inputs[0].get_shape() + b_merge.op._update_input(1, grad_val) # pylint: enable=protected-access @@ -1578,6 +1685,23 @@ class WhileContext(ControlFlowContext): return ops.IndexedSlices(values=acc_result[1], indices=acc_result[0], dense_shape=self.ExitResult(value.dense_shape)) + def _InitializeValues(self, values): + self._values = set() + for x in values: + if isinstance(x, ops.Tensor): + self._values.add(x.name) + else: + self._values.add(x.values.name) + self._values.add(x.indices.name) + if isinstance(x, ops.IndexedSlices): + dense_shape = x.dense_shape + elif isinstance(x, ops.SparseTensor): + dense_shape = x.shape + else: + raise TypeError("Type %s not supported" % type(x)) + if dense_shape is not None: + self._values.add(dense_shape.name) + def BuildLoop(self, pred, body, loop_vars): """Add the loop termination condition and body to the graph.""" @@ -1588,7 +1712,7 @@ class WhileContext(ControlFlowContext): loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars) # Let the context know the loop variabes so the loop variables # would be added in the outer contexts properly. - self._values = set([x.name for x in loop_vars]) + self._InitializeValues(loop_vars) real_vars = loop_vars if self._outer_context: real_vars = [self._outer_context.AddValue(x) for x in loop_vars] @@ -1598,7 +1722,7 @@ class WhileContext(ControlFlowContext): for x in real_vars] # Fix the control inputs and control flow context of these enter ops. self._FixControlInputsAndContext(enter_vars) - self._values = set([x.name for x in enter_vars]) + self._InitializeValues(enter_vars) merge_vars = [merge([x, x])[0] for x in enter_vars] self._pivot_for_pred = merge_vars[0] @@ -1626,21 +1750,23 @@ class WhileContext(ControlFlowContext): # Convert TensorArrays returned by body into their flow variables result = _convert_tensorarrays_to_flows(body_result) result = ops.convert_n_to_tensor_or_indexed_slices(result) - next_vars = [_NextIteration(x) for x in result] - # Add the back edges to complete the loop. - if len(merge_vars) != len(next_vars): + # Add NextIteration and the back edges to complete the loop. + if len(merge_vars) != len(result): raise ValueError("Number of inputs and outputs of body must match " - "loop_vars: %d, %d" % (len(merge_vars), len(next_vars))) - for x in zip(merge_vars, next_vars): - x[0].op._update_input(1, x[1]) + "loop_vars: %d, %d" % (len(merge_vars), len(result))) + next_vars = [] + for m, v in zip(merge_vars, result): + next_vars.append(_AddNextAndBackEdge(m, v)) # Add the exit ops. exit_vars = [exit(x[0]) for x in switch_vars] self._loop_exits = exit_vars + # Make sure the shapes of loop outputs are correct. for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars): - e_var._shape = _ShapeIntersection(m_var.get_shape(), n_var.get_shape()) + if isinstance(m_var, ops.Tensor): + e_var._shape = _ShapeIntersection(m_var.get_shape(), n_var.get_shape()) # Exit the loop. self.ExitResult(exit_vars) @@ -1653,16 +1779,25 @@ class WhileContext(ControlFlowContext): if len(exit_vars) == 1 else exit_vars_with_tensor_arrays) - def _FixControlInputsAndContext(self, input_tensors): - # pylint: disable=protected-access + def _FixControlInputsAndContext(self, enters): graph = ops.get_default_graph() - control_inputs = graph._control_dependencies_for_inputs(input_tensors) - control_inputs = [op for op in control_inputs - if op._get_control_flow_context() != self] - for x in input_tensors: - x.op._set_control_flow_context(self) - x.op._add_control_inputs(control_inputs) - graph._record_op_seen_by_control_dependencies(x.op) + # pylint: disable=protected-access + for e in enters: + if isinstance(e, ops.Tensor): + xs = [e] + else: + xs = [e.values, e.indices] + shape = e.dense_shape if isinstance(e, ops.IndexedSlices) else e.shape + if shape is not None: + xs.append(shape) + for x in xs: + inp_op = x.op.inputs[0] + control_inputs = graph._control_dependencies_for_inputs([inp_op]) + control_inputs = [op for op in control_inputs + if op._get_control_flow_context() != self] + x.op._set_control_flow_context(self) + x.op._add_control_inputs(control_inputs) + graph._record_op_seen_by_control_dependencies(x.op) # pylint: enable=protected-access |