aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-03 14:25:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 14:28:36 -0700
commitdf955f791714e53530440c6396652d3d97b6cb43 (patch)
treef325e853ad6d4e155dfe1f3986241e66cd22559e
parent998a3e619c5ee6eb3f8dd92e80748b63dc24bfef (diff)
Parallel-for: Change while_loop conversion so that any resource/variant tensors
directly entering the body of the original while_loop also enters directly in the converted while_loop. PiperOrigin-RevId: 203183248
-rw-r--r--tensorflow/python/ops/parallel_for/pfor.py28
1 files changed, 24 insertions, 4 deletions
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index 1b709535f6..ec4ef0f1ab 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -155,8 +155,14 @@ class WhileOp(object):
# `loop_vars` and are not returned from the `while_loop`. In Python code,
# they are usually captured by the body lambda. We collect them below by
# iterating over all the ops in the graph. They are appended to the end of
- # self._enters and don't correspond to any outputs in self._outputs.
+ # self._enters or self._direct_enters, and don't correspond to any outputs
+ # in self._outputs. Note that we keep the resource/variant Enter nodes in
+ # self._direct_enters and the constructed while_loop's body uses them
+ # directly as opposed to passing them as loop variables. This is done
+ # because the while_body cannot partition the resource/variant Tensors, so
+ # it has to leave them unchanged.
self._enters = []
+ self._direct_enters = []
for e in self._while_context.loop_exits:
self._outputs.append(e.op.outputs[0])
@@ -188,7 +194,11 @@ class WhileOp(object):
if op.type == "Enter":
output = op.outputs[0]
if output not in self._enters:
- self._enters.append(output)
+ if output.dtype in (dtypes.resource, dtypes.variant):
+ if output not in self._direct_enters:
+ self._direct_enters.append(output)
+ else:
+ self._enters.append(output)
def __str__(self):
"""String representation."""
@@ -197,13 +207,13 @@ class WhileOp(object):
@property
def inputs(self):
"""Input to all the Enter nodes."""
- return [x.op.inputs[0] for x in self._enters]
+ return [x.op.inputs[0] for x in self._enters + self._direct_enters]
@property
def control_inputs(self):
"""Control input to all the Enter nodes."""
control_inputs = []
- for x in self._enters:
+ for x in self._enters + self._direct_enters:
control_inputs.extend(x.op.control_inputs)
return control_inputs
@@ -271,6 +281,16 @@ class WhileOp(object):
pfor_ops=self._pfor_ops,
all_indices=indices,
all_indices_partitioned=cond_stacked)
+ # Map all inputs of Enter nodes in self._direct_enters to their converted
+ # values.
+ for enter in self._direct_enters:
+ enter_input = enter.op.inputs[0]
+ converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
+ enter_input)
+ # Since these are resources / variants, they should be unstacked.
+ assert not stacked and not is_sparse_stacked, (enter, converted_enter)
+ pfor._add_conversion(enter, wrap(converted_enter, False))
+
# Map all Enter nodes to the inputs.
for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
pfor._add_conversion(enter, wrap(inp, stacked))