diff options
author | 2018-07-03 14:25:26 -0700 | |
---|---|---|
committer | 2018-07-03 14:28:36 -0700 | |
commit | df955f791714e53530440c6396652d3d97b6cb43 (patch) | |
tree | f325e853ad6d4e155dfe1f3986241e66cd22559e | |
parent | 998a3e619c5ee6eb3f8dd92e80748b63dc24bfef (diff) |
Parallel-for: Change while_loop conversion so that any resource/variant tensors
directly entering the body of the original while_loop also enters directly
in the converted while_loop.
PiperOrigin-RevId: 203183248
-rw-r--r-- | tensorflow/python/ops/parallel_for/pfor.py | 28 |
1 files changed, 24 insertions, 4 deletions
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 1b709535f6..ec4ef0f1ab 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -155,8 +155,14 @@ class WhileOp(object): # `loop_vars` and are not returned from the `while_loop`. In Python code, # they are usually captured by the body lambda. We collect them below by # iterating over all the ops in the graph. They are appended to the end of - # self._enters and don't correspond to any outputs in self._outputs. + # self._enters or self._direct_enters, and don't correspond to any outputs + # in self._outputs. Note that we keep the resource/variant Enter nodes in + # self._direct_enters and the constructed while_loop's body uses them + # directly as opposed to passing them as loop variables. This is done + # because the while_body cannot partition the resource/variant Tensors, so + # it has to leave them unchanged. self._enters = [] + self._direct_enters = [] for e in self._while_context.loop_exits: self._outputs.append(e.op.outputs[0]) @@ -188,7 +194,11 @@ class WhileOp(object): if op.type == "Enter": output = op.outputs[0] if output not in self._enters: - self._enters.append(output) + if output.dtype in (dtypes.resource, dtypes.variant): + if output not in self._direct_enters: + self._direct_enters.append(output) + else: + self._enters.append(output) def __str__(self): """String representation.""" @@ -197,13 +207,13 @@ class WhileOp(object): @property def inputs(self): """Input to all the Enter nodes.""" - return [x.op.inputs[0] for x in self._enters] + return [x.op.inputs[0] for x in self._enters + self._direct_enters] @property def control_inputs(self): """Control input to all the Enter nodes.""" control_inputs = [] - for x in self._enters: + for x in self._enters + self._direct_enters: control_inputs.extend(x.op.control_inputs) return control_inputs @@ -271,6 +281,16 @@ class WhileOp(object): pfor_ops=self._pfor_ops, all_indices=indices, all_indices_partitioned=cond_stacked) + # Map all inputs of Enter nodes in self._direct_enters to their converted + # values. + for enter in self._direct_enters: + enter_input = enter.op.inputs[0] + converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper( + enter_input) + # Since these are resources / variants, they should be unstacked. + assert not stacked and not is_sparse_stacked, (enter, converted_enter) + pfor._add_conversion(enter, wrap(converted_enter, False)) + # Map all Enter nodes to the inputs. for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked): pfor._add_conversion(enter, wrap(inp, stacked)) |