Diffstat (limited to 'tensorflow/python/ops/parallel_for/pfor.py')
-rw-r--r--  tensorflow/python/ops/parallel_for/pfor.py  2552
1 file changed, 2552 insertions, 0 deletions
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
new file mode 100644
index 0000000000..77ec3bc0d4
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -0,0 +1,2552 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Compiled parallel-for loop."""
+# pylint: disable=missing-docstring
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from absl import flags
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import gen_sparse_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
+
+flags.DEFINE_bool(
+ "op_conversion_fallback_to_while_loop", False,
+ "If true, falls back to using a while loop for ops for "
+ "which a converter is not defined.")
+
+
+def _stack(t, length):
+ """stacks `t` `length` times."""
+ ones = array_ops.ones_like(array_ops.shape(t))
+ multiples = array_ops.concat([length, ones], 0)
+ t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
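+ # For example, if `t` has shape [3, 4] and `length` is [5], `t` now has shape
+ # [5, 3, 4], i.e. 5 copies of the original value stacked along axis 0.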
+ return wrap(t, True)
+
+
+# The following stateful ops can be safely called once, and with the same
+# signature as the unconverted version, if their inputs are loop invariant.
+# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
+# plan is to map each read/write in the loop_fn to a corresponding merged
+# read/write in the converted graph. Writes need to be mergeable (e.g.
+# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
+# loop_fn, doing a one-to-one conversion will simulate executing such
+# instructions in lock-step across all iterations.
+passthrough_stateful_ops = set([
+ "VariableV2",
+ "VarHandleOp",
+ "ReadVariableOp",
+ "StackV2",
+ "TensorArrayWriteV3",
+ "TensorArrayReadV3",
+ "TensorArraySizeV3",
+])
+
+
+def _is_stateful_pfor_op(op):
+ if isinstance(op, WhileOp):
+ return op.is_stateful
+ if op.type == "Const":
+ # Const doesn't have an op_def.
+ return False
+ if op.type in passthrough_stateful_ops:
+ return False
+ assert hasattr(op, "op_def") and op.op_def is not None, op
+ return op.op_def.is_stateful
+
+
+# pylint: disable=protected-access
+class WhileOp(object):
+ """Object for storing state for converting the outputs of a while_loop."""
+
+ def __init__(self, exit_node, pfor_ops):
+ """Initializer.
+
+ Args:
+ exit_node: A tensor output from the while_loop.
+ pfor_ops: list of ops inside the current pfor loop.
+ """
+ self._pfor_ops = set(pfor_ops)
+ self._pfor_op_ids = set([x._id for x in pfor_ops])
+ assert isinstance(exit_node, ops.Tensor)
+ self._while_context = exit_node.op._get_control_flow_context()
+ assert isinstance(self._while_context, control_flow_ops.WhileContext)
+ self._context_name = self._while_context.name
+ self._condition = self._while_context.pivot.op.inputs[0]
+ # Parts of an external while_loop could be created inside a pfor loop.
+ # However for the purpose here, we declare such loops to be external. Also
+ # note that we check if the condition was created inside or outside to
+ # determine if the while_loop was first created inside or outside.
+ # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
+ self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
+ if self._is_inside_loop:
+ for e in self._while_context.loop_exits:
+ assert self.op_is_inside_loop(e.op)
+
+ # Note the code below tries to reverse engineer an existing while_loop graph
+ # by assuming the following pattern of nodes.
+ #
+ # NextIteration <---- Body <--- Enter
+ # | ^
+ # V ___| Y
+ # Enter -> Merge -> Switch___
+ # ^ | N
+ # | V
+ # LoopCond Exit
+
+ # Note that elements in the lists below correspond one-to-one with each
+ # other, i.e. these lists are the same size, and the i_th entry corresponds
+ # to different Operations/Tensors of a single cycle as illustrated above.
+ # List of Switch ops (ops.Operation) that feed into an Exit Node.
+ self._exit_switches = []
+ # List of inputs (ops.Tensor) to NextIteration.
+ self._body_outputs = []
+ # List of list of control inputs of the NextIteration nodes.
+ self._next_iter_control_inputs = []
+ # List of Merge ops (ops.Operation).
+ self._enter_merges = []
+ # List of output (ops.Tensor) of Exit nodes.
+ self._outputs = []
+
+ # List of Enter Tensors.
+ # There are two types of Enter nodes:
+ # - The Enter nodes that are used in the `loop_vars` argument to
+ # `while_loop` (see
+ # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
+ # these Enter nodes immediately below by tracing backwards from the Exit
+ # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
+ # diagram above. This allows us to have a 1:1 correspondence between the
+ # self._outputs and the first elements in self._enters.
+ # - The Enter nodes that are used only by the body. They don't appear in the
+ # `loop_vars` and are not returned from the `while_loop`. In Python code,
+ # they are usually captured by the body lambda. We collect them below by
+ # iterating over all the ops in the graph. They are appended to the end of
+ # self._enters or self._direct_enters, and don't correspond to any outputs
+ # in self._outputs. Note that we keep the resource/variant Enter nodes in
+ # self._direct_enters and the constructed while_loop's body uses them
+ # directly as opposed to passing them as loop variables. This is done
+ # because the while_body cannot partition the resource/variant Tensors, so
+ # it has to leave them unchanged.
+ self._enters = []
+ self._direct_enters = []
+
+ for e in self._while_context.loop_exits:
+ self._outputs.append(e.op.outputs[0])
+ switch = e.op.inputs[0].op
+ assert switch.type == "Switch", switch
+ self._exit_switches.append(switch)
+ merge = switch.inputs[0].op
+ assert merge.type == "Merge", merge
+ self._enter_merges.append(merge)
+ enter = merge.inputs[0].op
+ assert enter.type == "Enter", enter
+ self._enters.append(enter.outputs[0])
+ next_iter = merge.inputs[1].op
+ assert next_iter.type == "NextIteration", next_iter
+ self._body_outputs.append(next_iter.inputs[0])
+ self._next_iter_control_inputs.append(next_iter.control_inputs)
+
+ # Collect all the Enter nodes that are not part of `loop_vars`, the second
+ # category described above.
+ # Also track whether the loop body has any stateful ops.
+ self._is_stateful = False
+ for op in ops.get_default_graph().get_operations():
+ # TODO(agarwal): make sure this works with nested case.
+ control_flow_context = op._get_control_flow_context()
+ if control_flow_context is None:
+ continue
+ if control_flow_context.name == self._context_name:
+ self._is_stateful |= _is_stateful_pfor_op(op)
+ if op.type == "Enter":
+ output = op.outputs[0]
+ if output not in self._enters:
+ if output.dtype in (dtypes.resource, dtypes.variant):
+ if output not in self._direct_enters:
+ self._direct_enters.append(output)
+ else:
+ self._enters.append(output)
+
+ def __str__(self):
+ """String representation."""
+ return "while_loop(%s)" % self.name
+
+ @property
+ def inputs(self):
+ """Input to all the Enter nodes."""
+ return [x.op.inputs[0] for x in self._enters + self._direct_enters]
+
+ @property
+ def control_inputs(self):
+ """Control input to all the Enter nodes."""
+ control_inputs = []
+ for x in self._enters + self._direct_enters:
+ control_inputs.extend(x.op.control_inputs)
+ return control_inputs
+
+ @property
+ def outputs(self):
+ """Outputs of all the Exit nodes."""
+ return self._outputs
+
+ @property
+ def name(self):
+ """Context name for the while loop."""
+ return self._context_name
+
+ @property
+ def is_inside_loop(self):
+ """Returns true if the while_loop was created inside the pfor."""
+ return self._is_inside_loop
+
+ def op_is_inside_loop(self, op):
+ """True if op was created inside the pfor loop body."""
+ assert isinstance(op, ops.Operation)
+ # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
+ # since it appears the TensorFlow API could return different Python
+ # objects representing the same Operation node.
+ return op._id in self._pfor_op_ids
+
+ @property
+ def is_stateful(self):
+ return self._is_stateful
+
+ @property
+ def pfor_converter(self):
+ """Return a converter for the while loop."""
+ return self
+
+ def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
+ inputs_stacked):
+ """Create a PFor object for converting parts of the while_loop.
+
+ Args:
+ parent_pfor: PFor object being used for converting the while_loop.
+ indices: int32 Tensor of ids for the iterations that are still active
+ (i.e. did not exit the while_loop).
+ cond_stacked: True if the while_loop condition is stacked.
+ inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
+ that these Tensors are a subset of the loop variables for the generated
+ while_loop.
+ inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
+ indicating if the value is stacked or not.
+
+ Returns:
+ A PFor instance. The instance is initialized by adding conversion mappings
+ of nodes that will be external to the conversion that the returned
+ instance will be used for. e.g. Enter nodes as well as Merge and Switch
+ outputs are mapped to converted values.
+ """
+ num_outputs = len(self._outputs)
+ assert len(inputs) == len(self._enters)
+ assert len(inputs_stacked) == len(self._enters)
+ loop_var = parent_pfor.loop_var
+ loop_len = array_ops.size(indices)
+ pfor = PFor(
+ loop_var,
+ loop_len,
+ pfor_ops=self._pfor_ops,
+ all_indices=indices,
+ all_indices_partitioned=cond_stacked)
+ # Map all inputs of Enter nodes in self._direct_enters to their converted
+ # values.
+ for enter in self._direct_enters:
+ enter_input = enter.op.inputs[0]
+ converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
+ enter_input)
+ # Since these are resources / variants, they should be unstacked.
+ assert not stacked and not is_sparse_stacked, (enter, converted_enter)
+ pfor._add_conversion(enter, wrap(converted_enter, False))
+
+ # Map all Enter nodes to the inputs.
+ for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
+ pfor._add_conversion(enter, wrap(inp, stacked))
+ # Map outputs of Switch and Merge.
+ for i in range(num_outputs):
+ wrapped_inp = wrap(inputs[i], inputs_stacked[i])
+ merge = self._enter_merges[i]
+ pfor._add_conversion(merge.outputs[0], wrapped_inp)
+ # Note that second output of Merge is typically not used, except possibly
+ # as a control dependency. To avoid trying to output the correct value, we
+ # employ a hack here. We output a dummy invalid value with an incorrect
+ # dtype. This will allow control dependency to work but if using it as an
+ # input, it should typically lead to errors during graph construction due
+ # to dtype mismatch.
+ # TODO(agarwal): Check in the original graph to see if there are any
+ # consumers of this Tensor that use it as an input.
+ pfor._add_conversion(merge.outputs[1],
+ wrap(constant_op.constant(-1.0), False))
+ switch = self._exit_switches[i]
+ # Don't need to worry about switch.output[0] which will feed to Exit node.
+ pfor._add_conversion(switch.outputs[1], wrapped_inp)
+ return pfor
+
+ def _convert_enter(self, parent_pfor, enter):
+ """Converts an Enter node."""
+ inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
+ control_inputs = [
+ parent_pfor._convert_helper(x).t for x in enter.op.control_inputs
+ ]
+ if control_inputs:
+ with ops.control_dependencies(control_inputs):
+ inp = array_ops.identity(inp)
+ return inp, stacked
+
+ def _maybe_stacked(self, cache, inp):
+ """Heuristic to figue out if the coverting inp leads to a stacked value.
+
+
+ Args:
+ cache: map from Tensor to boolean indicating stacked/unstacked.
+ inp: input Tensor.
+
+ Returns:
+ True if `inp` could get stacked. If the function returns False, the
+ converted value should be guaranteed to be unstacked. If returning True,
+ it may or may not be stacked.
+ """
+ if inp in cache:
+ return cache[inp]
+ if not self.op_is_inside_loop(inp.op):
+ return False
+ op = inp.op
+ output = False
+ if op.type in [
+ "Shape",
+ "Rank"
+ "ShapeN",
+ "ZerosLike",
+ "TensorArrayV3",
+ "TensorArraySizeV3",
+ ]:
+ output = False
+ elif _is_stateful_pfor_op(op):
+ # This may be fairly aggressive.
+ output = True
+ elif op.type == "Exit":
+ # This may be fairly aggressive.
+ output = True
+ else:
+ for t in op.inputs:
+ if self._maybe_stacked(cache, t):
+ output = True
+ break
+ cache[inp] = output
+ return output
+
+ def _create_init_values(self, pfor_input):
+ """Create arguments passed to converted while_loop."""
+ with ops.name_scope("while_init"):
+ loop_len_vector = pfor_input.pfor.loop_len_vector
+ loop_len = loop_len_vector[0]
+ num_outputs = len(self._outputs)
+
+ inputs = []
+ maybe_stacked_cache = {}
+ # Convert all the Enters. Need to do this before checking for stacking
+ # below.
+ for i, enter in enumerate(self._enters):
+ inp, stacked = self._convert_enter(pfor_input.pfor, enter)
+ inputs.append(inp)
+ maybe_stacked_cache[enter] = stacked
+ # Since this enter node is part of the `loop_vars`, it corresponds to an
+ # output and its preceding switch. We mark this switch's output with the
+ # same stackedness, to act as the base case for the logic below. Below, we
+ # go through the body to figure out which inputs might need to be stacked
+ # and which can safely remain unstacked.
+ if i < num_outputs:
+ maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked
+
+ # Shape invariants for init_values corresponding to self._enters.
+ input_shape_invariants = []
+ # TensorArrays for outputs of converted while loop
+ output_tas = []
+ # Shape invariants for output TensorArrays.
+ ta_shape_invariants = []
+ # List of booleans indicating stackness of inputs, i.e. tensors
+ # corresponding to self._enters.
+ inputs_stacked = []
+ for i, inp in enumerate(inputs):
+ enter = self._enters[i]
+ inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
+ # Note that even when an input is unstacked, the body could make it
+ # stacked. We use a heuristic below to figure out if the body may be
+ # making it stacked.
+ if i < num_outputs:
+ body_output = self._body_outputs[i]
+ if enter.op in self._pfor_ops:
+ body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
+ body_output)
+ else:
+ # If constructed outside of pfor loop, then the output would not be
+ # stacked.
+ body_output_stacked = False
+ if body_output_stacked and not inp_stacked:
+ inp = _stack(inp, loop_len_vector).t
+ inputs[i] = inp
+ inp_stacked = True
+ # TODO(agarwal): other attributes for the TensorArray ?
+ output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
+ ta_shape_invariants.append(tensor_shape.TensorShape(None))
+
+ inputs_stacked.append(inp_stacked)
+ input_shape_invariants.append(tensor_shape.TensorShape(None))
+
+ # See documentation for __call__ for the structure of init_values.
+ init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
+ # TODO(agarwal): try stricter shape invariants
+ shape_invariants = (
+ [tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None)
+ ] + input_shape_invariants + ta_shape_invariants)
+
+ return init_values, inputs_stacked, shape_invariants
+
+ def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
+ """Handles case when condition is unstacked.
+
+ Note that all iterations end together. So we don't need to partition the
+ inputs. When all iterations are done, we write the inputs to the
+ TensorArrays. Note that we only write to index 0 of output_tas. Since all
+ iterations end together, they can all be output together.
+ """
+ not_all_done = array_ops.reshape(conditions, [])
+ new_output_tas = []
+ # pylint: disable=cell-var-from-loop
+ for i, out_ta in enumerate(output_tas):
+ inp = inputs[i]
+ new_output_tas.append(
+ control_flow_ops.cond(not_all_done,
+ lambda: out_ta,
+ lambda: out_ta.write(0, inp)))
+ # pylint: enable=cell-var-from-loop
+ return not_all_done, indices, inputs, new_output_tas
+
+ def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
+ output_tas):
+ num_outputs = len(self._outputs)
+ # Compute if all iterations are done.
+ not_all_done = math_ops.reduce_any(conditions)
+ conditions_int = math_ops.cast(conditions, dtypes.int32)
+ # Partition the indices.
+ done_indices, new_indices = data_flow_ops.dynamic_partition(
+ indices, conditions_int, 2)
+
+ new_inputs = []
+ new_output_tas = []
+ for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
+ # Partition the inputs.
+ if stacked:
+ done_inp, new_inp = data_flow_ops.dynamic_partition(
+ inp, conditions_int, 2)
+ else:
+ # TODO(agarwal): avoid this stacking. See TODO earlier in
+ # _process_cond_unstacked.
+ done_inp = _stack(inp, [array_ops.size(done_indices)]).t
+ new_inp = inp
+ new_inputs.append(new_inp)
+ # For iterations that are done, write them to TensorArrays.
+ if i < num_outputs:
+ out_ta = output_tas[i]
+ # Note that done_indices can be empty. done_inp should also be empty in
+ # that case.
+ new_output_tas.append(out_ta.scatter(done_indices, done_inp))
+ return not_all_done, new_indices, new_inputs, new_output_tas
+
+ def _process_body(self, pfor_input, inputs_stacked,
+ new_indices, cond_stacked, new_inputs,
+ not_all_done):
+ """Convert the body function."""
+
+ def true_fn(control_inputs, body_pfor, body_output, stacked):
+ """Converts the body function for all but last iteration.
+
+ This essentially converts body_output. Additionally, it needs to handle
+ any control dependencies on the NextIteration node. So it creates another
+ Identity node with the converted dependencies.
+ """
+ converted_control_inp = []
+ for x in control_inputs:
+ for t in x.outputs:
+ converted_control_inp.append(body_pfor._convert_helper(t).t)
+ if stacked:
+ # Note convert always does the stacking.
+ output = body_pfor.convert(body_output)
+ else:
+ output, convert_stacked, _ = body_pfor._convert_helper(body_output)
+ assert convert_stacked == stacked, body_output
+ with ops.control_dependencies(converted_control_inp):
+ return array_ops.identity(output)
+
+ body_pfor = self._init_pfor(pfor_input.pfor, new_indices,
+ cond_stacked, new_inputs,
+ inputs_stacked)
+ new_outputs = []
+
+ for i, (body_output, stacked) in enumerate(
+ zip(self._body_outputs, inputs_stacked)):
+ control_inp = self._next_iter_control_inputs[i]
+ out_dtype = body_output.dtype
+ # Note that we want to run the body only if not all pfor iterations are
+ # done. If all are done, we return empty tensors since these values will
+ # not be used. Notice that the value returned by the loop is based on
+ # TensorArrays and not directly on these returned values.
+ # pylint: disable=cell-var-from-loop
+ new_output = control_flow_ops.cond(
+ not_all_done,
+ lambda: true_fn(control_inp, body_pfor, body_output, stacked),
+ lambda: constant_op.constant([], dtype=out_dtype))
+ # pylint: enable=cell-var-from-loop
+ new_outputs.append(new_output)
+ return new_outputs
+
+ def __call__(self, pfor_input):
+ """Converter for the while_loop.
+
+ The conversion of a while_loop is another while_loop.
+
+ The arguments to this converted while_loop are as follows:
+ not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
+ are done.
+ indices: int32 1-D Tensor storing the id of the iterations that are not
+ done.
+ args: Remaining arguments. These can be divided into 3 categories:
+ - First set of arguments are the tensors that correspond to the initial
+ elements of self._enters. The elements that appear in original while
+ loop's `loop_vars`.
+ - The second set of arguments are the tensors that correspond to the
+ remaining elements of self._enters. These are the tensors that directly
+ enter the original while loop body.
+ - Finally, the last set of arguments are TensorArrays. These TensorArrays
+ correspond to the outputs of the original while_loop, i.e. to the
+ elements in self._outputs. Each TensorArray has `PFor.loop_len`
+ elements, i.e. the number of pfor iterations. At the end, the i'th
+ element of each TensorArray will contain the output computed by the
+ i'th iteration of pfor. Note that elements can be written into these
+ tensors arrays in any order, depending on when the corresponding pfor
+ iteration is done.
+ If the original while_loop had `k` tensors in its `loop_vars` and its body
+ directly captured `m` tensors, the `args` will contain `2 * k + m` values.
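+ For example, with `k = 2` and `m = 1`, `args` would contain 5 entries: 2
+ tensors corresponding to the loop variables, 1 directly captured tensor, and
+ 2 output TensorArrays.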
+
+ In each iteration, the while_loop body recomputes the condition for all
+ active pfor iterations to see which of them are now done. It then partitions
+ all the inputs and passes them along to the converted body. Values for all
+ the iterations that are done are written to TensorArrays indexed by the pfor
+ iteration number. When all iterations are done, the TensorArrays are stacked
+ to get the final value.
+
+ Args:
+ pfor_input: A _PforInput object corresponding to the output of any Exit
+ node from this while loop.
+
+ Returns:
+ List of converted outputs.
+ """
+ # Create init_values that will be passed to the while_loop.
+ init_values, inputs_stacked, shape_invariants = self._create_init_values(
+ pfor_input)
+ # Note that we use a list as a hack since we need the nested function body
+ # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
+ # variables.
+ cond_is_stacked = [None]
+
+ def cond(not_all_done, *_):
+ return not_all_done
+
+ def body(not_all_done, indices, *args):
+ # See documentation for __call__ for the structure of *args.
+ num_enters = len(self._enters)
+ inputs = args[:num_enters]
+ output_tas = args[num_enters:]
+ # TODO(agarwal): see which outputs have consumers and only populate the
+ # TensorArrays corresponding to those. Or do those paths get trimmed out
+ # from inside the while_loop body?
+ assert len(inputs) >= len(output_tas)
+ assert len(inputs) == len(inputs_stacked)
+
+ # Convert condition
+ with ops.name_scope("while_cond"):
+ # Note that we set cond_stacked to True here. At this point we don't
+ # know if it could be loop invariant, hence the conservative value is
+ # to assume stacked.
+ cond_pfor = self._init_pfor(pfor_input.pfor, indices,
+ cond_stacked=True,
+ inputs=inputs,
+ inputs_stacked=inputs_stacked)
+ conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
+ cond_is_stacked[0] = cond_stacked
+
+ # Recompute the new condition, write outputs of done iterations, and
+ # partition the inputs if needed.
+ if not cond_stacked:
+ (not_all_done, new_indices,
+ new_inputs, new_output_tas) = self._process_cond_unstacked(
+ conditions, indices, inputs, output_tas)
+ else:
+ (not_all_done, new_indices,
+ new_inputs, new_output_tas) = self._process_cond_stacked(
+ conditions, indices, inputs, inputs_stacked, output_tas)
+
+ # Convert body
+ with ops.name_scope("while_body"):
+ # Compute the outputs from the body.
+ new_outputs = self._process_body(pfor_input, inputs_stacked,
+ new_indices, cond_stacked, new_inputs,
+ not_all_done)
+
+ # Note that the first num_outputs new values of inputs are computed using
+ # the body. Rest of them were direct Enters into the condition/body and
+ # the partitioning done earlier is sufficient to give the new value.
+ num_outputs = len(self._outputs)
+ new_args = ([not_all_done, new_indices] + new_outputs + list(
+ new_inputs[num_outputs:]) + new_output_tas)
+ return tuple(new_args)
+
+ while_outputs = control_flow_ops.while_loop(
+ cond, body, init_values, shape_invariants=shape_invariants)
+ output_tas = while_outputs[-len(self._outputs):]
+ outputs = []
+ assert cond_is_stacked[0] is not None
+ for inp_stacked, ta in zip(inputs_stacked, output_tas):
+ if cond_is_stacked[0]:
+ outputs.append(wrap(ta.stack(), True))
+ else:
+ # Note that if while_loop condition is unstacked, all iterations exit at
+ # the same time and we wrote those outputs in index 0 of the tensor
+ # array.
+ outputs.append(wrap(ta.read(0), inp_stacked))
+ return outputs
+
+
+class _PforInput(object):
+ """Input object passed to registered pfor converters."""
+
+ def __init__(self, pfor, op, inputs):
+ """Creates a _PforInput object.
+
+ Args:
+ pfor: PFor converter object.
+ op: the Operation object that is being converted.
+ inputs: list of WrappedTensor objects representing converted values of the
+ inputs of `op`.
+ """
+ self.pfor = pfor
+ self._op = op
+ self._inputs = inputs
+
+ def stack_inputs(self, stack_indices=None):
+ """Stacks unstacked inputs at `stack_indices`.
+
+ Args:
+ stack_indices: indices of inputs at which stacking is done. If None,
+ stacking is done at all indices.
+ """
+ if stack_indices is None:
+ stack_indices = range(len(self._inputs))
+ length = self.pfor.loop_len_vector
+ for i in stack_indices:
+ inp = self._inputs[i]
+ if not inp.is_stacked:
+ self._inputs[i] = _stack(inp.t, length)
+
+ def expanddim_inputs_for_broadcast(self):
+ """Reshapes stacked inputs to prepare them for broadcast.
+
+ Since stacked inputs have an extra leading dimension, automatic broadcasting
+ rules could incorrectly try to expand dimensions before that leading
+ dimension. To avoid that, we reshape these stacked inputs to the maximum
+ rank they will need to be broadcasted to.
+ """
+ if not self._inputs:
+ return
+
+ # Find max rank
+ def _get_rank(x):
+ rank = array_ops.rank(x.t)
+ if not x.is_stacked:
+ rank += 1
+ return rank
+
+ ranks = [_get_rank(x) for x in self._inputs]
+ max_rank = ranks[0]
+ for rank in ranks[1:]:
+ max_rank = math_ops.maximum(rank, max_rank)
+
+ for i, inp in enumerate(self._inputs):
+ if inp.is_stacked:
+ shape = array_ops.shape(inp.t)
+ rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
+ ones = array_ops.tile([1], rank_diff)
+ new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
+ self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)
+
+ @property
+ def inputs(self):
+ return self._inputs
+
+ @property
+ def num_inputs(self):
+ return len(self._inputs)
+
+ def input(self, index):
+ assert len(self._inputs) > index, (index, self._inputs)
+ return self._inputs[index]
+
+ def stacked_input(self, index):
+ t, is_stacked, _ = self.input(index)
+ if not is_stacked:
+ op_type = self.op_type
+ op_def = getattr(self._op, "op_def", None)
+ if op_def is None:
+ input_name = "at index %d" % index
+ else:
+ input_name = "\"%s\"" % op_def.input_arg[index].name
+ raise ValueError("Input %s of op \"%s\" expected to be not loop invariant"
+ ".\nError while converting op %s"
+ "with converted inputs\n%s" % (input_name, op_type,
+ self._op, self.inputs))
+ return t
+
+ def unstacked_input(self, index):
+ t, is_stacked, _ = self.input(index)
+ if is_stacked:
+ op_type = self.op_type
+ op_def = getattr(self._op, "op_def", None)
+ if op_def is None:
+ input_name = "at index %d" % index
+ else:
+ input_name = "\"%s\"" % op_def.input_arg[index].name
+ raise ValueError("Input %s of op \"%s\" expected to be loop invariant"
+ ".\nError while converting op %s"
+ "with converted inputs\n%s" % (input_name, op_type,
+ self._op, self.inputs))
+ return t
+
+ @property
+ def op(self):
+ return self._op
+
+ @property
+ def op_type(self):
+ return self._op.type
+
+ def get_attr(self, attr):
+ return self._op.get_attr(attr)
+
+ @property
+ def outputs(self):
+ return self._op.outputs
+
+ def output(self, index):
+ assert index < len(self._op.outputs)
+ return self._op.outputs[index]
+
+
+_pfor_converter_registry = {}
+
+
+class RegisterPFor(object):
+ """Utility to register converters for pfor.
+
+ Usage:
+ @RegisterPFor(foo_op_type)
+ def _foo_converter(pfor_input):
+ ...
+
+ The above will register conversion function `_foo_converter` for handling
+ conversion of `foo_op_type`. During conversion, the registered function will
+ be called with a single argument of type `_PforInput`, which will contain the
+ state needed for the conversion. The registered function should return a list
+ of WrappedTensor objects with the same length as the number of outputs of the
+ op being converted. If the op has zero outputs, then it should return an
+ ops.Operation object.
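+
+ For example (an illustrative sketch; "Relu" is used here merely as a sample
+ elementwise op), a converter could look like:
+
+ @RegisterPFor("Relu")
+ def _relu_converter(pfor_input):
+ # The stacked input has an extra leading loop dimension; since relu is
+ # elementwise, applying it directly yields the stacked result.
+ return wrap(nn_ops.relu(pfor_input.stacked_input(0)), True)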
+ """
+
+ def __init__(self, op_type):
+ """Creates an object to register a converter for op with type `op_type`."""
+ self.op_type = op_type
+
+ def __call__(self, converter):
+ name = self.op_type
+ assert name not in _pfor_converter_registry, "Re-registering %s " % name
+ _pfor_converter_registry[name] = converter
+ return converter
+
+
+class RegisterPForWithArgs(RegisterPFor):
+ """Utility to register converters for pfor.
+
+ Usage:
+ @RegisterPForWithArgs(foo_op_type, foo=value, ...)
+ def _foo_converter(pfor_input, op_type, foo=None, ...):
+ ...
+
+ See RegisterPFor for details on the conversion function.
+ `RegisterPForWithArgs` allows binding extra arguments to the
+ conversion function at registration time.
+ """
+
+ def __init__(self, op_type, *args, **kw_args):
+ super(RegisterPForWithArgs, self).__init__(op_type)
+ self._args = args
+ self._kw_args = kw_args
+
+ def __call__(self, converter):
+
+ def _f(pfor_input):
+ return converter(pfor_input, self.op_type, *self._args, **self._kw_args)
+
+ super(RegisterPForWithArgs, self).__call__(_f)
+ return converter
+
+
+def _create_op(op_type, inputs, op_dtypes, attrs=None):
+ """Utility to create an op."""
+ return ops.get_default_graph().create_op(
+ op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
+
+
+WrappedTensor = collections.namedtuple("WrappedTensor",
+ ["t", "is_stacked", "is_sparse_stacked"])
+"""Wrapper around the result of a Tensor conversion.
+
+The additional fields are useful for keeping track of the conversion state as
+data flows through the ops in the loop body. For every op whose output is a
+Tensor, its converter should return either a WrappedTensor or a list of
+WrappedTensors.
+
+Args:
+ t: The converted tensor
+ is_stacked: True if the tensor is stacked, i.e. represents the results of all
+ the iterations of the loop, where each row i of the tensor corresponds to
+ that op's output on iteration i of the loop. False if the tensor is not
+ stacked, i.e. represents the result of the op for a single iteration of
+ the loop, where the result does not vary between iterations.
+ is_sparse_stacked: True if the tensor corresponds to a component tensor
+ (indices, values, or dense_shape) of a sparse tensor, and has been logically
+ stacked via a sparse conversion.
+"""
+
+
+def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
+ """Helper to create a WrappedTensor object."""
+ assert isinstance(is_stacked, bool)
+ assert isinstance(is_sparse_stacked, bool)
+ assert isinstance(tensor, ops.Tensor)
+ assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
+ "stacked via a sparse "
+ "conversion, it must also be "
+ "stacked.")
+ return WrappedTensor(tensor, is_stacked, is_sparse_stacked)
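+
+# For example, a value that varies across pfor iterations is wrapped as
+# wrap(stacked_t, True), where row i of `stacked_t` holds the result of
+# iteration i, while a loop-invariant value is wrapped as wrap(t, False).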
+
+
+def _fallback_converter(pfor_input):
+ logging.warn("Using a while_loop for converting %s", pfor_input.op_type)
+ output_dtypes = [x.dtype for x in pfor_input.outputs]
+ iters = pfor_input.pfor.loop_len_vector[0]
+
+ def while_body(i, *ta_list):
+ """Body of while loop."""
+ inputs = [
+ x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
+ ]
+ op_outputs = _create_op(
+ pfor_input.op_type,
+ inputs,
+ output_dtypes,
+ attrs=pfor_input.op.node_def.attr).outputs
+
+ outputs = []
+ for out, ta in zip(op_outputs, ta_list):
+ assert isinstance(out, ops.Tensor)
+ outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
+ return tuple([i + 1] + outputs)
+
+ ta_list = control_flow_ops.while_loop(
+ lambda i, *ta: i < iters, while_body, [0] + [
+ tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
+ ])[1:]
+ return tuple([wrap(ta.concat(), True) for ta in ta_list])
+
+
+class PFor(object):
+ """Implementation of rewrite of parallel-for loops.
+
+ This class takes a DAG or a set of DAGs representing the body of a
+ parallel-for loop, and adds new operations to the graph that implement
+ functionality equivalent to running that loop body for a specified number of
+ iterations. This new set of nodes may or may not use a TensorFlow loop
+ construct.
+
+ The process of conversion does not delete or change any existing operations.
+ It only adds operations that efficiently implement the equivalent
+ functionality. We refer to the added ops as "converted ops".
+
+ The conversion process uses a simple greedy heuristic. It walks the loop body
+ and tries to express the functionality of running each node in a loop with a
+ new set of nodes. When converting an op several cases are possible:
+ - The op is not inside the loop body. Hence it can be used as is.
+ - The op does not depend on the iteration number and is stateless. In this
+ case, it can be used as is.
+ - The op is not stateful, and depends on the iteration number only through
+ control dependencies. In this case, we can create a single op with the same
+ inputs and attributes, but with "converted" control dependencies.
+ - The op is not stateful, and all its inputs are loop invariant. In this
+ case, similar to above, we can create a single op with the same inputs and
+ attributes, but with "converted" control dependencies.
+ - The op is stateful or at least one of the inputs is not loop invariant. In
+ this case, we run the registered converter for that op to create a set of
+ converted ops. All nodes in the set will have converted control dependencies
+ corresponding to control dependencies of the original op. If the op returned
+ multiple outputs, "converted outputs" could be produced by different ops in
+ this set.
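+
+ Example usage (an illustrative sketch; `loop_iter_placeholder`, `traced_ops`
+ and `loop_fn_output` are placeholder names for the iteration-number
+ placeholder, the ops traced from the loop body, and the tensor it returns):
+ converter = PFor(loop_var=loop_iter_placeholder, loop_len=8,
+ pfor_ops=traced_ops)
+ stacked_output = converter.convert(loop_fn_output)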
+ """
+
+ def __init__(self,
+ loop_var,
+ loop_len,
+ pfor_ops,
+ all_indices=None,
+ all_indices_partitioned=False):
+ """Creates an object to rewrite a parallel-for loop.
+
+ Args:
+ loop_var: ops.Tensor output of a Placeholder operation. The value should
+ be an int32 scalar representing the loop iteration number.
+ loop_len: A scalar or scalar Tensor representing the number of iterations
+ the loop is run for.
+ pfor_ops: List of all ops inside the loop body.
+ all_indices: If not None, an int32 vector with size `loop_len`
+ representing the iteration ids that are still active. These values
+ should be unique and sorted. However they may not be contiguous. This is
+ typically the case when inside a control flow construct which has
+ partitioned the indices of the iterations that are being converted.
+ all_indices_partitioned: If True, this object is being constructed from a
+ control flow construct where not all the pfor iterations are guaranteed
+ to be active.
+ """
+ assert isinstance(loop_var, ops.Tensor)
+ assert loop_var.op.type == "Placeholder"
+ self._loop_var = loop_var
+ loop_len_value = tensor_util.constant_value(loop_len)
+ if loop_len_value is not None:
+ loop_len = loop_len_value
+ self._loop_len_vector = array_ops.reshape(loop_len, [1])
+ self._all_indices_partitioned = all_indices_partitioned
+ if all_indices_partitioned:
+ assert all_indices is not None
+ self.all_indices = (
+ math_ops.range(loop_len) if all_indices is None else all_indices)
+
+ self._conversion_map = {}
+ self._conversion_map[loop_var] = wrap(self.all_indices, True)
+ self._pfor_ops = set(pfor_ops)
+ self._pfor_op_ids = set([x._id for x in pfor_ops])
+
+ def op_is_inside_loop(self, op):
+ """True if op was created inside the pfor loop body."""
+ assert isinstance(op, ops.Operation)
+ # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
+ # since it appears the TensorFlow API could return different Python
+ # objects representing the same Operation node.
+ return op._id in self._pfor_op_ids
+
+ def _convert_sparse(self, y):
+ """Returns the converted value corresponding to SparseTensor y.
+
+ For SparseTensors, instead of stacking the component tensors separately,
+ resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
+ rank) respectively for indices, values, and dense_shape (where N is the loop
+ length and m is the number of sparse tensor values per loop iter), we want
+ to logically stack the SparseTensors, to create a SparseTensor whose
+ components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
+ respectively.
+
+ Here, we try to get the conversion of each component tensor.
+ If the tensors are stacked via a sparse conversion, return the resulting
+ SparseTensor composed of the converted components. Otherwise, the component
+ tensors are either unstacked or stacked naively. In the latter case, we
+ unstack the component tensors to reform loop_len SparseTensor elements,
+ then correctly batch them.
+
+ The unstacked tensors must have the same rank. Each dimension of each
+ SparseTensor will expand to be the largest among all SparseTensor elements
+ for that dimension. For example, if there are N SparseTensors of rank 3
+ being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
+ the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).
+
+ Args:
+ y: A tf.SparseTensor.
+
+ Returns:
+ A tf.SparseTensor that is the converted value corresponding to y.
+ """
+ outputs = [
+ self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
+ ]
+ assert all(isinstance(o, WrappedTensor) for o in outputs)
+
+ if all(w.is_sparse_stacked for w in outputs):
+ return sparse_tensor.SparseTensor(*[w.t for w in outputs])
+
+ assert not any(w.is_sparse_stacked for w in outputs), (
+ "Error converting SparseTensor. All components should be logically "
+ "stacked, or none.")
+
+ # If component tensors were not sparsely stacked, they are either unstacked
+ # or stacked without knowledge that they are components of sparse tensors.
+ # In this case, we have to restack them.
+ return self._restack_sparse_tensor_logically(
+ *[self._unwrap_or_tile(w) for w in outputs])
+
+ def _restack_sparse_tensor_logically(self, indices, values, shape):
+ sparse_tensor_rank = indices.get_shape()[-1].value
+ if sparse_tensor_rank is not None:
+ sparse_tensor_rank += 1
+
+ def map_fn(args):
+ res = gen_sparse_ops.serialize_sparse(
+ args[0], args[1], args[2], out_type=dtypes.variant)
+ return res
+
+ # Applies a map function to the component tensors to serialize each
+ # sparse tensor element and batch them all, then deserializes the batch.
+ # TODO(rachelim): Try to do this without map_fn -- add the right offsets
+ # to shape and indices tensors instead.
+ result = functional_ops.map_fn(
+ map_fn, [indices, values, shape], dtype=dtypes.variant)
+ return sparse_ops.deserialize_sparse(
+ result, dtype=values.dtype, rank=sparse_tensor_rank)
+
+ def _unwrap_or_tile(self, wrapped_tensor):
+ """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it."""
+ output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
+ if is_stacked:
+ return output
+ else:
+ return _stack(output, self._loop_len_vector).t
+
+ def convert(self, y):
+ """Returns the converted value corresponding to y.
+
+ Args:
+ y: A ops.Tensor or a ops.Operation object. If latter, y should not have
+ any outputs.
+
+ Returns:
+ If y does not need to be converted, it returns y as is. Else it returns
+ the "converted value" corresponding to y.
+ """
+ if isinstance(y, sparse_tensor.SparseTensor):
+ return self._convert_sparse(y)
+ output = self._convert_helper(y)
+ if isinstance(output, WrappedTensor):
+ assert isinstance(y, ops.Tensor)
+ return self._unwrap_or_tile(output)
+ else:
+ assert isinstance(y, ops.Operation)
+ assert not y.outputs
+ assert isinstance(output, ops.Operation)
+ return output
+
+ def _was_converted(self, t):
+ """True if t is not a conversion of itself."""
+ converted_t = self._conversion_map[t]
+ return converted_t.t is not t
+
+ def _add_conversion(self, old_output, new_output):
+ self._conversion_map[old_output] = new_output
+
+ def _convert_helper(self, op_or_tensor):
+ stack = [op_or_tensor]
+ while stack:
+ y = stack[0]
+ if y in self._conversion_map:
+ assert isinstance(self._conversion_map[y],
+ (WrappedTensor, ops.Operation))
+ stack.pop(0)
+ continue
+ if isinstance(y, ops.Operation):
+ assert not y.outputs, (
+ "We only support converting Operation objects with no outputs. "
+ "Got %s", y)
+ y_op = y
+ else:
+ assert isinstance(y, ops.Tensor), y
+ y_op = y.op
+
+ is_while_loop = y_op.type == "Exit"
+ if is_while_loop:
+ while_op = WhileOp(y, pfor_ops=self._pfor_ops)
+ is_inside_loop = while_op.is_inside_loop
+ # If all nodes in the while_loop graph were created inside the pfor, we
+ # treat the whole loop subgraph as a single op (y_op) and try to convert
+ # it. For while_loops that are created completely or partially outside,
+ # we treat them as external and should be able to simply return the Exit
+ # node output as is without needing any conversion. Note that for
+ # while_loops that are partially constructed inside, we assume they will
+ # be loop invariant. If that is not the case, it will create runtime
+ # errors since the converted graph would depend on the self._loop_var
+ # placeholder.
+ if is_inside_loop:
+ y_op = while_op
+ else:
+ is_inside_loop = self.op_is_inside_loop(y_op)
+
+ # If this op was not created inside the loop body, we will return as is.
+ # 1. Convert inputs and control inputs.
+
+ def _add_to_stack(x):
+ if x not in self._conversion_map:
+ stack.insert(0, x)
+ return True
+ else:
+ return False
+
+ if is_inside_loop:
+ added_to_stack = False
+ for inp in y_op.inputs:
+ added_to_stack |= _add_to_stack(inp)
+ for cinp in y_op.control_inputs:
+ if cinp.outputs:
+ for t in cinp.outputs:
+ added_to_stack |= _add_to_stack(t)
+ else:
+ added_to_stack |= _add_to_stack(cinp)
+ if added_to_stack:
+ continue
+
+ converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
+ some_input_converted = any(
+ [self._was_converted(x) for x in y_op.inputs])
+ some_input_stacked = any([x.is_stacked for x in converted_inputs])
+
+ converted_control_ops = set()
+ some_control_input_converted = False
+ for cinp in y_op.control_inputs:
+ if cinp.outputs:
+ for t in cinp.outputs:
+ converted_t = self._conversion_map[t]
+ if self._was_converted(t):
+ some_control_input_converted = True
+ converted_control_ops.add(converted_t.t.op)
+ else:
+ converted_cinp = self._conversion_map[cinp]
+ assert isinstance(converted_cinp, ops.Operation)
+ if converted_cinp != cinp:
+ some_control_input_converted = True
+ converted_control_ops.add(converted_cinp)
+ converted_control_ops = list(converted_control_ops)
+ is_stateful = _is_stateful_pfor_op(y_op)
+ else:
+ converted_inputs = []
+ converted_control_ops = []
+ logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
+ converted_inputs, converted_control_ops)
+
+ # 2. Convert y_op
+ # If converting a while_loop, we let the while_loop converter deal with
+ # putting the control dependencies appropriately.
+ control_dependencies = [] if is_while_loop else converted_control_ops
+ with ops.control_dependencies(control_dependencies), ops.name_scope(
+ y_op.name + "/pfor/"):
+ # None of the inputs and control inputs were converted.
+ if (not is_inside_loop or
+ (not is_stateful and not some_input_converted and
+ not some_control_input_converted)):
+ if y == y_op:
+ assert not isinstance(y_op, WhileOp)
+ new_outputs = y_op
+ else:
+ new_outputs = [wrap(x, False) for x in y_op.outputs]
+ elif not (is_stateful or is_while_loop or some_input_stacked):
+ # All inputs are unstacked or unconverted but some control inputs are
+ # converted.
+ # TODO(rachelim): Handle the case where some inputs are sparsely
+ # stacked (i.e. any([x.is_sparse_stacked for x in converted_inputs]))
+ new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
+ [x.dtype for x in y_op.outputs],
+ y_op.node_def.attr)
+ if y == y_op:
+ new_outputs = new_op
+ else:
+ new_outputs = [wrap(x, False) for x in new_op.outputs]
+ else:
+ # Either some inputs are not loop invariant or op is stateful.
+ if hasattr(y_op, "pfor_converter"):
+ converter = y_op.pfor_converter
+ else:
+ converter = _pfor_converter_registry.get(y_op.type, None)
+ if converter is None:
+ if flags.FLAGS.op_conversion_fallback_to_while_loop:
+ converter = _fallback_converter
+ else:
+ raise ValueError(
+ "No converter defined for %s\n%s\ninputs: %s. "
+ "\nEither add a converter or set "
+ "--op_conversion_fallback_to_while_loop=True, "
+ "which may run slower" % (y_op.type, y_op, converted_inputs))
+ # TODO(rachelim): Handle the case where some inputs are sparsely
+ # stacked. We should only call the converter if it supports handling
+ # those inputs.
+ new_outputs = converter(_PforInput(self, y_op, converted_inputs))
+ if isinstance(new_outputs, WrappedTensor):
+ new_outputs = [new_outputs]
+ assert isinstance(new_outputs,
+ (list, tuple, ops.Operation)), new_outputs
+ logging.vlog(2, "converted %s %s", y_op, new_outputs)
+
+ # Insert into self._conversion_map
+ if y == y_op:
+ assert isinstance(new_outputs, ops.Operation)
+ self._add_conversion(y_op, new_outputs)
+ else:
+ for old_output, new_output in zip(y_op.outputs, new_outputs):
+ assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
+ self._add_conversion(old_output, new_output)
+ stack.pop(0)
+
+ return self._conversion_map[op_or_tensor]
+
+ @property
+ def loop_len_vector(self):
+ """Returns a single element vector whose value is number of iterations."""
+ return self._loop_len_vector
+
+ @property
+ def loop_var(self):
+ """Returns placeholder loop variable."""
+ return self._loop_var
+
+ @property
+ def pfor_ops(self):
+ return self._pfor_ops
+
+ @property
+ def all_indices_partitioned(self):
+ """all_indices_partitioned property.
+
+ Returns:
+ True if we are inside a control flow construct and not all pfor iterations
+ may be active.
+ """
+ return self._all_indices_partitioned
+
+# nn_ops
+
+
+def _flatten_first_two_dims(x):
+ """Merges first two dimensions."""
+ old_shape = array_ops.shape(x)
+ new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0)
+ return array_ops.reshape(x, new_shape)
+
+
+def _unflatten_first_dim(x, first_dim):
+ """Splits first dimension into [first_dim, -1]."""
+ old_shape = array_ops.shape(x)
+ new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0)
+ return array_ops.reshape(x, new_shape)
+
+
+def _inputs_with_flattening(pfor_input, input_indices):
+ """Stacks and flattens first dim of inputs at indices `input_indices`."""
+ if input_indices is None:
+ input_indices = []
+ pfor_input.stack_inputs(stack_indices=input_indices)
+ inputs = []
+ for i in range(pfor_input.num_inputs):
+ if i in input_indices:
+ inp = pfor_input.stacked_input(i)
+ inp = _flatten_first_two_dims(inp)
+ else:
+ inp = pfor_input.unstacked_input(i)
+ inputs.append(inp)
+ return inputs
+
+
+@RegisterPForWithArgs("Conv2D", dims=[0])
+@RegisterPForWithArgs("AvgPool", dims=[0])
+@RegisterPForWithArgs("MaxPool", dims=[0])
+@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
+@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
+def _convert_flatten_batch(pfor_input, op_type, dims):
+ del op_type
+ inputs = _inputs_with_flattening(pfor_input, dims)
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ n = pfor_input.pfor.loop_len_vector
+ outputs = [_unflatten_first_dim(x, n) for x in outputs]
+ return [wrap(x, True) for x in outputs]
+
+
+_channel_flatten_input_cache = {}
+
+
+def _channel_flatten_input(x, data_format):
+ """Merge the stack dimension with the channel dimension.
+
+ If S is pfor's stacking dimension, then,
+ - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose
+ should be cheap.
+ - for SNHWC, we transpose to NHWCS.
+ We then merge the S and C dimension.
+
+ Args:
+ x: ops.Tensor to transform.
+ data_format: "NCHW" or "NHWC".
+
+ Returns:
+ A 3-element tuple with the transformed value, along with the shape for
+ reshape and order for transpose required to transform back.
+ """
+
+ graph = ops.get_default_graph()
+ cache_key = (graph, x, data_format)
+ if cache_key not in _channel_flatten_input_cache:
+ x_shape = array_ops.shape(x)
+ if data_format == b"NCHW":
+ order = [1, 0, 2, 3, 4]
+ shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0)
+ reverse_order = order
+ else:
+ order = [1, 2, 3, 0, 4]
+ shape = array_ops.concat([x_shape[1:4], [-1]], axis=0)
+ reverse_order = [3, 0, 1, 2, 4]
+ # Move S dimension next to C dimension.
+ x = array_ops.transpose(x, order)
+ reverse_shape = array_ops.shape(x)
+ # Reshape to merge the S and C dimension.
+ x = array_ops.reshape(x, shape)
+ outputs = x, reverse_order, reverse_shape
+ _channel_flatten_input_cache[cache_key] = outputs
+ else:
+ outputs = _channel_flatten_input_cache[cache_key]
+ return outputs
+
+
+# Note that with training=True, running FusedBatchNorm on individual examples
+# is very different from running FusedBatchNorm on a batch of those examples.
+# This is because, for the latter case, the operation can be considered as first
+# computing the mean and variance over all the examples and then using these
+# to scale all those examples. This creates a data dependency between these
+# different "iterations" since the inputs to the scaling step depends on the
+# statistics coming from all these inputs.
+# As with other kernels, the conversion here effectively runs the kernel
+# independently for each iteration, and returns outputs by stacking outputs from
+# each of those iterations.
+@RegisterPFor("FusedBatchNorm")
+def _convert_fused_batch_norm(pfor_input):
+ is_training = pfor_input.get_attr("is_training")
+ # When BatchNorm is used with training=False, mean and variance are provided
+ # externally and used as is by the op. Thus, we can merge the S and N
+ # dimensions as we do for regular operations.
+ # When BatchNorm is used with training=True, mean and variance are computed
+ # for each channel across the batch dimension (first one). If we merge S and N
+ # dimensions, mean and variances will be computed over a larger set. So, we
+ # merge the S and C dimensions instead.
+ if not is_training:
+ # We return zeros for batch_mean and batch_variance output. Note that CPU
+ # and GPU seem to have different behavior for those two outputs. CPU outputs
+ # zero because these values are not used during inference. GPU outputs
+ # something, probably real means and variances.
+ inputs = _inputs_with_flattening(pfor_input, [0])
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ y = outputs[0]
+ n = pfor_input.pfor.loop_len_vector
+ y = _unflatten_first_dim(y, n)
+ mean = pfor_input.unstacked_input(3)
+ zeros = array_ops.zeros_like(mean)
+ return [wrap(y, True), wrap(zeros, False), wrap(zeros, False)]
+
+ pfor_input.stack_inputs()
+ data_format = pfor_input.get_attr("data_format")
+ # We merge the first dimension with the "C" dimension, run FusedBatchNorm, and
+ # then transpose back.
+ x = pfor_input.stacked_input(0)
+ x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
+ # Note that we stack all the other inputs as well so that they are the same
+ # size as the new size of the channel dimension.
+ inputs = [x] + [
+ array_ops.reshape(pfor_input.stacked_input(i), [-1])
+ for i in range(1, pfor_input.num_inputs)
+ ]
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ y = outputs[0]
+ y = array_ops.reshape(y, reverse_shape)
+ y = array_ops.transpose(y, reverse_order)
+ n = pfor_input.pfor.loop_len_vector
+ outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
+ outputs = [y] + outputs
+ return [wrap(x, True) for x in outputs]
+
+
+@RegisterPFor("FusedBatchNormGrad")
+def _convert_fused_batch_norm_grad(pfor_input):
+ pfor_input.stack_inputs()
+ data_format = pfor_input.get_attr("data_format")
+ y_backprop = pfor_input.stacked_input(0)
+ y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format)
+ x = pfor_input.stacked_input(1)
+ x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format)
+ inputs = [y_backprop, x] + [
+ array_ops.reshape(pfor_input.stacked_input(i), [-1])
+ for i in range(2, pfor_input.num_inputs)
+ ]
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ x_backprop = outputs[0]
+ x_backprop = array_ops.reshape(x_backprop, x_reverse_shape)
+ x_backprop = array_ops.transpose(x_backprop, x_reverse_order)
+ n = pfor_input.pfor.loop_len_vector
+ outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
+ outputs = [x_backprop] + outputs
+ return [wrap(output, True) for output in outputs]
+
+
+@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
+@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
+def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
+ shape_dim):
+ del op_type
+ inputs = _inputs_with_flattening(pfor_input, flatten_dims)
+ n = pfor_input.pfor.loop_len_vector
+ # Adjust the `input_sizes` input.
+ ones = array_ops.ones(
+ [array_ops.shape(inputs[shape_dim])[0] - 1], dtype=n.dtype)
+ inputs[shape_dim] *= array_ops.concat([n, ones], axis=0)
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ outputs = [_unflatten_first_dim(x, n) for x in outputs]
+ return [wrap(x, True) for x in outputs]
+
+
+@RegisterPFor("Conv2DBackpropFilter")
+def _convert_conv2d_backprop_filter(pfor_input):
+ pfor_input.stack_inputs(stack_indices=[2])
+ inputs, inputs_stacked, _ = pfor_input.input(0)
+ filter_sizes = pfor_input.unstacked_input(1)
+ grads = pfor_input.stacked_input(2)
+ strides = pfor_input.get_attr("strides")
+ padding = pfor_input.get_attr("padding")
+ use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu")
+ data_format = pfor_input.get_attr("data_format")
+ dilations = pfor_input.get_attr("dilations")
+ if inputs_stacked:
+ # TODO(agarwal): Implement this efficiently.
+ logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!")
+
+ def while_body(i, ta):
+ inp_i = inputs[i, ...]
+ grad_i = grads[i, ...]
+ output = nn_ops.conv2d_backprop_filter(
+ inp_i,
+ filter_sizes,
+ grad_i,
+ strides=strides,
+ padding=padding,
+ use_cudnn_on_gpu=use_cudnn_on_gpu,
+ data_format=data_format,
+ dilations=dilations)
+ return i + 1, ta.write(i, array_ops.expand_dims(output, 0))
+
+ n = array_ops.reshape(pfor_input.pfor.loop_len_vector, [])
+ _, ta = control_flow_ops.while_loop(
+ lambda i, ta: i < n, while_body,
+ (0, tensor_array_ops.TensorArray(inputs.dtype, n)))
+ output = ta.concat()
+ return wrap(output, True)
+ else:
+ # We merge the stack dimension with the channel dimension of the gradients
+ # and pretend we had a larger filter (see change to filter_sizes below).
+ # Once the filter backprop is computed, we reshape and transpose back
+ # appropriately.
+ grads, _, _ = _channel_flatten_input(grads, data_format)
+ n = pfor_input.pfor.loop_len_vector
+ old_filter_sizes = filter_sizes
+ filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0)
+ output = nn_ops.conv2d_backprop_filter(
+ inputs,
+ filter_sizes,
+ grads,
+ strides=strides,
+ padding=padding,
+ use_cudnn_on_gpu=use_cudnn_on_gpu,
+ data_format=data_format,
+ dilations=dilations)
+ new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0)
+ output = array_ops.reshape(output, new_filter_shape)
+ output = array_ops.transpose(output, [3, 0, 1, 2, 4])
+ return wrap(output, True)
+
+
+# array_ops
+
+
+@RegisterPForWithArgs("Identity", array_ops.identity)
+@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
+def _convert_identity(pfor_input, op_type, op_func):
+ del op_type
+ return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
+
+
+@RegisterPFor("Reshape")
+def _convert_reshape(pfor_input):
+ t = pfor_input.stacked_input(0)
+ shape = pfor_input.unstacked_input(1)
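+ # Prepend the loop dimension so that the requested shape is applied to each
+ # per-iteration slice.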
+ new_dim = array_ops.shape(t)[:1]
+ new_shape = array_ops.concat([new_dim, shape], axis=0)
+ return wrap(array_ops.reshape(t, new_shape), True)
+
+
+@RegisterPFor("ExpandDims")
+def _convert_expanddims(pfor_input):
+ t = pfor_input.stacked_input(0)
+ dim = pfor_input.unstacked_input(1)
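+ # A non-negative axis must be shifted by one to account for the new leading
+ # loop dimension; negative axes already count from the end and need no
+ # adjustment.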
+ dim += math_ops.cast(dim >= 0, dtypes.int32)
+ return wrap(array_ops.expand_dims(t, axis=dim), True)
+
+
+@RegisterPFor("Slice")
+def _convert_slice(pfor_input):
+ t = pfor_input.stacked_input(0)
+ begin = pfor_input.unstacked_input(1)
+ size = pfor_input.unstacked_input(2)
+ begin = array_ops.concat([[0], begin], axis=0)
+ size = array_ops.concat([[-1], size], axis=0)
+ return wrap(array_ops.slice(t, begin, size), True)
+
+
+@RegisterPFor("Tile")
+def _convert_tile(pfor_input):
+ t = pfor_input.stacked_input(0)
+ multiples = pfor_input.unstacked_input(1)
+ multiples = array_ops.concat([[1], multiples], 0)
+ return wrap(array_ops.tile(t, multiples), True)
+
+
+@RegisterPFor("Pack")
+def _convert_pack(pfor_input):
+ pfor_input.stack_inputs()
+ axis = pfor_input.get_attr("axis")
+ if axis >= 0:
+ axis += 1
+ return wrap(
+ array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True)
+
+
+@RegisterPFor("Unpack")
+def _convert_unpack(pfor_input):
+ value = pfor_input.stacked_input(0)
+ axis = pfor_input.get_attr("axis")
+ if axis >= 0:
+ axis += 1
+ num = pfor_input.get_attr("num")
+ return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)]
+
+
+@RegisterPFor("Pad")
+def _convert_pad(pfor_input):
+ t = pfor_input.stacked_input(0)
+ paddings = pfor_input.unstacked_input(1)
+ paddings = array_ops.concat([[[0, 0]], paddings], 0)
+ return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True)
+
+
+@RegisterPFor("Split")
+def _convert_split(pfor_input):
+ split_dim = pfor_input.unstacked_input(0)
+ t = pfor_input.stacked_input(1)
+ num_split = pfor_input.get_attr("num_split")
+ split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
+ return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
+
+
+@RegisterPFor("Transpose")
+def _convert_transpose(pfor_input):
+ t = pfor_input.stacked_input(0)
+ perm = pfor_input.unstacked_input(1)
+ new_perm = array_ops.concat([[0], perm + 1], axis=0)
+ return wrap(array_ops.transpose(t, new_perm), True)
+
+
+@RegisterPFor("ZerosLike")
+def _convert_zeroslike(pfor_input):
+ t = pfor_input.stacked_input(0)
+ shape = array_ops.shape(t)[1:]
+ return wrap(array_ops.zeros(shape, dtype=t.dtype), False)
+
+
+@RegisterPFor("Gather")
+@RegisterPFor("GatherV2")
+def _convert_gather(pfor_input):
+ param, param_stacked, _ = pfor_input.input(0)
+ indices, indices_stacked, _ = pfor_input.input(1)
+ op_type = pfor_input.op_type
+ if op_type == "Gather":
+ validate_indices = pfor_input.get_attr("validate_indices")
+ axis = 0
+ else:
+ validate_indices = None
+ axis = pfor_input.unstacked_input(2)
+ axis_value = tensor_util.constant_value(axis)
+ if axis_value is not None:
+ axis = axis_value
+ if indices_stacked and not param_stacked:
+ if indices == pfor_input.pfor.all_indices and axis == 0:
+ param_shape0 = param.shape[0].value
+ indices_shape0 = indices.shape[0].value
+ if param_shape0 is not None and indices_shape0 == param_shape0:
+ # Note that with loops and conditionals, indices may not be contiguous.
+ # However, they will be sorted and unique, so if the shape matches, then
+ # it must be picking up all the rows of param.
+ return wrap(param, True)
+ # TODO(agarwal): use array_ops.slice here.
+ output = array_ops.gather(
+ param, indices, validate_indices=validate_indices, axis=axis)
+ if axis != 0:
+ axis = control_flow_ops.cond(
+ axis < 0, lambda: axis + array_ops.rank(param), lambda: axis)
+ order = array_ops.concat(
+ [[axis],
+ math_ops.range(axis),
+ math_ops.range(axis + 1, array_ops.rank(output))],
+ axis=0)
+ output = control_flow_ops.cond(
+ math_ops.equal(axis, 0), lambda: output,
+ lambda: array_ops.transpose(output, order))
+ return wrap(output, True)
+ if param_stacked:
+ loop_len_vector = pfor_input.pfor.loop_len_vector
+ pfor_input.stack_inputs(stack_indices=[1])
+ indices = pfor_input.stacked_input(1)
+ param_flat = _flatten_first_two_dims(param)
+
+ # Recompute indices to handle stacked param.
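+ # E.g. with loop_len = 2 and param of shape [2, 5, ...], index j gathered in
+ # iteration i maps to row i * 5 + j of param_flat.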
+ indices_offset = math_ops.range(
+ loop_len_vector[0]) * array_ops.shape(param)[1]
+ # Reshape indices_offset to allow broadcast addition
+ ones = array_ops.ones([array_ops.rank(indices) - 1], dtype=dtypes.int32)
+ new_shape = array_ops.concat([loop_len_vector, ones], axis=0)
+ indices_offset = array_ops.reshape(indices_offset, new_shape)
+ indices += indices_offset
+
+ # TODO(agarwal): handle axis != 0. May need to transpose param or
+ # array_ops.gather_nd.
+ if isinstance(axis, ops.Tensor):
+ axis_value = tensor_util.constant_value(axis)
+ else:
+ try:
+ axis_value = int(axis)
+ except TypeError:
+ axis_value = None
+ msg = ("Gather, where indices and param are both loop dependent, currently "
+ "requires axis=0")
+ if axis_value is not None and axis_value != 0:
+ raise ValueError("Error while converting %s. %s. Got axis=%d" %
+ (pfor_input.op, msg, axis))
+ with ops.control_dependencies(
+ [check_ops.assert_equal(axis, 0, message=msg)]):
+ output = array_ops.gather(param_flat, indices)
+ return wrap(output, True)
+
+
+@RegisterPFor("ConcatV2")
+def _convert_concatv2(pfor_input):
+ n = pfor_input.num_inputs
+ pfor_input.stack_inputs(stack_indices=range(n - 1))
+ axis = pfor_input.unstacked_input(n - 1)
+ axis += math_ops.cast(axis >= 0, axis.dtype)
+ return wrap(
+ array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis),
+ True)
+
+
+@RegisterPFor("StridedSlice")
+def _convert_strided_slice(pfor_input):
+ inp = pfor_input.stacked_input(0)
+ begin = pfor_input.unstacked_input(1)
+ end = pfor_input.unstacked_input(2)
+ strides = pfor_input.unstacked_input(3)
+ begin_mask = pfor_input.get_attr("begin_mask")
+ end_mask = pfor_input.get_attr("end_mask")
+ ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
+ new_axis_mask = pfor_input.get_attr("new_axis_mask")
+ shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
+
+ begin = array_ops.concat([[0], begin], axis=0)
+ end = array_ops.concat([[0], end], axis=0)
+ strides = array_ops.concat([[1], strides], axis=0)
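+ # A full slice was prepended over the new leading loop dimension, so shift all
+ # masks left by one bit and set the lowest bit of begin_mask/end_mask to keep
+ # the loop dimension intact.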
+ begin_mask = begin_mask << 1 | 1
+ end_mask = end_mask << 1 | 1
+ ellipsis_mask <<= 1
+ new_axis_mask <<= 1
+ shrink_axis_mask <<= 1
+ return wrap(
+ array_ops.strided_slice(
+ inp,
+ begin,
+ end,
+ strides,
+ begin_mask=begin_mask,
+ end_mask=end_mask,
+ ellipsis_mask=ellipsis_mask,
+ new_axis_mask=new_axis_mask,
+ shrink_axis_mask=shrink_axis_mask), True)
+
+
+@RegisterPFor("StridedSliceGrad")
+def _convert_strided_slice_grad(pfor_input):
+ shape = pfor_input.unstacked_input(0)
+ begin = pfor_input.unstacked_input(1)
+ end = pfor_input.unstacked_input(2)
+ strides = pfor_input.unstacked_input(3)
+ dy = pfor_input.stacked_input(4)
+ begin_mask = pfor_input.get_attr("begin_mask")
+ end_mask = pfor_input.get_attr("end_mask")
+ ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
+ new_axis_mask = pfor_input.get_attr("new_axis_mask")
+ shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
+
+ shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
+ begin = array_ops.concat([[0], begin], axis=0)
+ end = array_ops.concat([[0], end], axis=0)
+ strides = array_ops.concat([[1], strides], axis=0)
+ begin_mask = begin_mask << 1 | 1
+ end_mask = end_mask << 1 | 1
+ ellipsis_mask <<= 1
+ new_axis_mask <<= 1
+ shrink_axis_mask <<= 1
+ return wrap(
+ array_ops.strided_slice_grad(
+ shape,
+ begin,
+ end,
+ strides,
+ dy,
+ begin_mask=begin_mask,
+ end_mask=end_mask,
+ ellipsis_mask=ellipsis_mask,
+ new_axis_mask=new_axis_mask,
+ shrink_axis_mask=shrink_axis_mask), True)
+
+
+# math_ops
+
+
+@RegisterPFor("MatMul")
+def _convert_matmul(pfor_input):
+ # TODO(agarwal): Check if tiling is faster than two transposes.
+ a, a_stacked, _ = pfor_input.input(0)
+ b, b_stacked, _ = pfor_input.input(1)
+ tr_a = pfor_input.get_attr("transpose_a")
+ tr_b = pfor_input.get_attr("transpose_b")
+ if a_stacked and b_stacked:
+ output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True)
+ return output
+ elif a_stacked:
+ if tr_a:
+ a = array_ops.transpose(a, [0, 2, 1])
+ if a.shape.is_fully_defined():
+ x, y, z = a.shape
+ else:
+ x, y, z = [
+ array_ops.reshape(i, [])
+ for i in array_ops.split(array_ops.shape(a), 3)
+ ]
+ a = array_ops.reshape(a, [x * y, z])
+ prod = math_ops.matmul(a, b, transpose_b=tr_b)
+ return wrap(array_ops.reshape(prod, [x, y, -1]), True)
+ else:
+ assert b_stacked
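+ # Fold the loop dimension of b into its last dimension so that a single 2-D
+ # matmul handles all the iterations; the product is then reshaped and
+ # transposed back so that the loop dimension is leading.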
+ if tr_b:
+ perm = [2, 0, 1]
+ b = array_ops.transpose(b, perm)
+ else:
+ # As an optimization, if one of the first two dimensions is 1, then we can
+ # reshape instead of transposing.
+ # TODO(agarwal): This check can be done inside the Transpose kernel.
+ b_shape = array_ops.shape(b)
+ min_dim = math_ops.minimum(b_shape[0], b_shape[1])
+ perm = control_flow_ops.cond(
+ math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2])
+ new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]])
+ b = array_ops.transpose(b, perm)
+ b = array_ops.reshape(b, new_shape)
+
+ if b.shape.is_fully_defined():
+ x, y, z = b.shape
+ else:
+ x, y, z = [
+ array_ops.reshape(i, [])
+ for i in array_ops.split(array_ops.shape(b), 3)
+ ]
+ b = array_ops.reshape(b, [x, y * z])
+ prod = math_ops.matmul(a, b, transpose_a=tr_a)
+ prod = array_ops.reshape(prod, [-1, y, z])
+ prod = array_ops.transpose(prod, [1, 0, 2])
+ return wrap(prod, True)
+
+
+@RegisterPFor("BatchMatMul")
+def _convert_batch_mat_mul(pfor_input):
+ # TODO(agarwal): There may be a more efficient way to do this instead of
+ # stacking the inputs.
+ pfor_input.stack_inputs()
+ x = pfor_input.stacked_input(0)
+ y = pfor_input.stacked_input(1)
+ adj_x = pfor_input.get_attr("adj_x")
+ adj_y = pfor_input.get_attr("adj_y")
+
+ x = _flatten_first_two_dims(x)
+ y = _flatten_first_two_dims(y)
+ output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
+ output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector)
+ return wrap(output, True)
+
+
+@RegisterPForWithArgs("Sum", math_ops.reduce_sum)
+@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
+@RegisterPForWithArgs("Max", math_ops.reduce_max)
+@RegisterPForWithArgs("Min", math_ops.reduce_min)
+def _convert_reduction(pfor_input, _, op_func):
+ t = pfor_input.stacked_input(0)
+ indices = pfor_input.unstacked_input(1)
+ # Shift positive indices by one to account for the extra dimension.
+ indices += math_ops.cast(indices >= 0, dtypes.int32)
+ keep_dims = pfor_input.get_attr("keep_dims")
+ return wrap(op_func(t, indices, keepdims=keep_dims), True)
+
+
+@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
+@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
+def _convert_cumfoo(pfor_input, _, op_func):
+ t = pfor_input.stacked_input(0)
+ axis = pfor_input.unstacked_input(1)
+ # Shift positive indices by one to account for the extra dimension.
+ axis += math_ops.cast(axis >= 0, dtypes.int32)
+ exclusive = pfor_input.get_attr("exclusive")
+ reverse = pfor_input.get_attr("reverse")
+ return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True)
+
+
+@RegisterPFor("BiasAdd")
+def _convert_biasadd(pfor_input):
+ t = pfor_input.stacked_input(0)
+ bias = pfor_input.unstacked_input(1)
+ data_format = pfor_input.get_attr("data_format")
+ if data_format != b"NCHW":
+ return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
+ shape = array_ops.shape(t)
+ flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
+ t = array_ops.reshape(t, flattened_shape)
+ t = nn_ops.bias_add(t, bias, data_format=b"NCHW")
+ t = array_ops.reshape(t, shape)
+ return wrap(t, True)
+
+
+@RegisterPFor("UnsortedSegmentSum")
+def _convert_unsortedsegmentsum(pfor_input):
+ data, data_stacked, _ = pfor_input.input(0)
+ # TODO(agarwal): handle unstacked?
+ segment_ids = pfor_input.stacked_input(1)
+ # TODO(agarwal): handle stacked?
+ num_segments = pfor_input.unstacked_input(2)
+ if not data_stacked:
+ data = _stack(data, pfor_input.pfor.loop_len_vector).t
+ segment_shape = array_ops.shape(segment_ids)
+ n = segment_shape[0]
+ ones = array_ops.ones_like(segment_shape)[1:]
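+ # Offset the segment ids of iteration i by i * num_segments so that different
+ # iterations use disjoint segment ranges; the combined output is reshaped back
+ # to [n, num_segments, ...] below.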
+ segment_offset = num_segments * math_ops.range(n)
+ segment_offset = array_ops.reshape(segment_offset,
+ array_ops.concat([[n], ones], axis=0))
+ segment_ids += segment_offset
+ num_segments *= n
+ output = math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
+ new_output_shape = array_ops.concat(
+ [[n, -1], array_ops.shape(output)[1:]], axis=0)
+ output = array_ops.reshape(output, new_output_shape)
+ return wrap(output, True)
+
+
+@RegisterPFor("Cast")
+def _convert_cast(pfor_input):
+ inp = pfor_input.stacked_input(0)
+ dtype = pfor_input.get_attr("DstT")
+ return wrap(math_ops.cast(inp, dtype), True)
+
+
+# Note that the ops handled here have no attributes other than "T", and hence
+# don't need extra arguments passed to the op_func call below.
+@RegisterPForWithArgs("Add", math_ops.add)
+@RegisterPForWithArgs("Ceil", math_ops.ceil)
+@RegisterPForWithArgs("Equal", math_ops.equal)
+@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
+@RegisterPForWithArgs("Floor", math_ops.floor)
+@RegisterPForWithArgs("Greater", math_ops.greater)
+@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal)
+@RegisterPForWithArgs("Less", math_ops.less)
+@RegisterPForWithArgs("LessEqual", math_ops.less_equal)
+@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
+@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and)
+@RegisterPForWithArgs("LogicalNot", math_ops.logical_not)
+@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor)
+@RegisterPForWithArgs("Maximum", math_ops.maximum)
+@RegisterPForWithArgs("Minimum", math_ops.minimum)
+@RegisterPForWithArgs("Mul", math_ops.multiply)
+@RegisterPForWithArgs("Neg", math_ops.negative)
+@RegisterPForWithArgs("RealDiv", math_ops.divide)
+@RegisterPForWithArgs("Relu", nn_ops.relu)
+@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
+@RegisterPForWithArgs("Square", math_ops.square)
+@RegisterPForWithArgs("Sub", math_ops.subtract)
+@RegisterPForWithArgs("Tanh", math_ops.tanh)
+def _convert_cwise(pfor_input, op_type, op_func):
+ del op_type
+ pfor_input.expanddim_inputs_for_broadcast()
+ return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
+
+
+@RegisterPFor("Shape")
+def _convert_shape(pfor_input):
+ out_type = pfor_input.get_attr("out_type")
+ return wrap(
+ array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:],
+ False)
+
+
+@RegisterPFor("ShapeN")
+def _convert_shape_n(pfor_input):
+ out_type = pfor_input.get_attr("out_type")
+ shapes = [
+ array_ops.shape(x, out_type=out_type)[1:]
+ if stacked else array_ops.shape(x) for x, stacked, _ in pfor_input.inputs
+ ]
+ return [wrap(x, False) for x in shapes]
+
+
+@RegisterPFor("Size")
+def _convert_size(pfor_input):
+ out_type = pfor_input.get_attr("out_type")
+ n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type)
+ return wrap(
+ array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n,
+ False)
+
+
+@RegisterPFor("Rank")
+def _convert_rank(pfor_input):
+ return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False)
+
+
+@RegisterPFor("AddN")
+def _convert_addn(pfor_input):
+ # AddN does not support broadcasting.
+ pfor_input.stack_inputs()
+ return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True)
+
+
+@RegisterPFor("BiasAddGrad")
+def _convert_biasaddgrad(pfor_input):
+ grad = pfor_input.stacked_input(0)
+ fmt = pfor_input.get_attr("data_format")
+ if fmt == b"NCHW":
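+ # grad is stacked as [n, N, C, H, W] (assuming a 4-D NCHW gradient), so the
+ # per-channel bias gradient sums over the per-iteration batch and spatial
+ # dimensions, i.e. axes 1, 3 and 4.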
+ output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
+ else:
+ grad_shape = array_ops.shape(grad)
+ last_dim_shape = grad_shape[-1]
+ first_dim_shape = grad_shape[0]
+ output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
+ output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
+ return wrap(output, True)
+
+
+# Some required ops are not exposed under the tf namespace. Hence we rely on
+# _create_op to create them.
+@RegisterPForWithArgs("ReluGrad")
+@RegisterPForWithArgs("TanhGrad")
+@RegisterPForWithArgs("SigmoidGrad")
+def _convert_grads(pfor_input, op_type, *args, **kw_args):
+ del args
+ del kw_args
+ # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
+ # have to use tiling here.
+ pfor_input.stack_inputs()
+ outputs = _create_op(
+ op_type, [x.t for x in pfor_input.inputs],
+ [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ return [wrap(x, True) for x in outputs]
+
+
+@RegisterPFor("Select")
+def _convert_select(pfor_input):
+ pfor_input.stack_inputs()
+ cond = pfor_input.stacked_input(0)
+ t = pfor_input.stacked_input(1)
+ e = pfor_input.stacked_input(2)
+ cond_rank = array_ops.rank(cond)
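+ # Select allows cond to be either a vector that selects rows of t/e or a
+ # tensor matching their shape. If the stacked cond has rank > 1, flatten the
+ # loop dimension into the next dimension of cond, t and e so that those
+ # semantics are preserved, and unflatten the result afterwards.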
+ cond, t, e = control_flow_ops.cond(
+ cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]),
+ lambda: [cond, t, e])
+ outputs = _create_op(
+ pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ n = pfor_input.pfor.loop_len_vector
+ out = control_flow_ops.cond(cond_rank > 1,
+ lambda: _unflatten_first_dim(outputs[0], n),
+ lambda: outputs[0])
+ return [wrap(out, True) for _ in outputs]
+
+
+# random_ops
+
+
+@RegisterPForWithArgs("RandomUniform")
+@RegisterPForWithArgs("RandomUniformInt")
+@RegisterPForWithArgs("RandomStandardNormal")
+@RegisterPForWithArgs("TruncatedNormal")
+@RegisterPForWithArgs("RandomGamma")
+@RegisterPForWithArgs("RandomPoissonV2")
+def _convert_random(pfor_input, op_type, *args, **kw_args):
+ del args
+ del kw_args
+ inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)]
+ # inputs[0] is "shape"
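+ # Prepend the loop length to the requested shape so that a single call draws
+ # independent samples for all the iterations.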
+ inputs[0] = array_ops.concat(
+ [pfor_input.pfor.loop_len_vector, inputs[0]], axis=0)
+ logging.warning(
+ "Note that %s inside pfor op may not give same output as "
+ "inside a sequential loop.", op_type)
+ outputs = _create_op(
+ op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ return [wrap(x, True) for x in outputs]
+
+
+# logging_ops
+
+
+@RegisterPFor("Assert")
+def _convert_assert(pfor_input):
+ cond, cond_stacked, _ = pfor_input.input(0)
+ if cond_stacked:
+ cond = math_ops.reduce_all(cond)
+
+ data_list = [x.t for x in pfor_input.inputs][1:]
+ return _create_op("Assert", [cond] + data_list, [],
+ attrs=pfor_input.op.node_def.attr)
+
+
+@RegisterPFor("Print")
+def _convert_print(pfor_input):
+ # Note that we don't stack all the inputs. Hence unstacked values are printed
+ # once here vs multiple times in a while_loop.
+ pfor_input.stack_inputs([0])
+ outputs = _create_op(
+ "Print", [x.t for x in pfor_input.inputs],
+ [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ return [wrap(x, True) for x in outputs]
+
+
+# data_flow_ops
+
+# TensorArray conversion is tricky since we don't support arrays of
+# TensorArrays. For converting them, we consider two distinct cases:
+#
+# 1. The array is constructed outside the pfor call, and read/written inside the
+# loop.
+# This is an easier case since we don't need to make an array of TensorArrays.
+# A correctness requirement is that the parallel iterations don't attempt to
+# write to the same location. Hence at conversion time we disallow loop-invariant
+# write indices, since those would guarantee collisions. Even indices that are
+# not loop-invariant may still collide; such collisions surface as runtime
+# errors.
+#
+# 2. The array is constructed and used entirely inside each pfor iteration.
+# For simplicity, here we require that the indices used for write/scatter are
+# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
+# different pfor iterations. We consider two sub-cases:
+#
+# 2a Elements written to the array are "stacked"
+# To simulate multiple TensorArrays, we may increase the dimension of each
+# element of the array, i.e. the i_th row of the j_th entry of the converted
+# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
+# pfor iteration.
+#
+# 2b Elements written to the array are "unstacked"
+# In this case we don't increase the dimensions, in order to avoid redundant
+# tiling. Each iteration tries to write the same value, so we convert that to a
+# single write.
+#
+# Here are some tricks used to implement the above:
+# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of
+# trying to trace whether future writes are stacked or unstacked in order to set
+# this attr, we set it to correspond to an unknown shape.
+# - We use the "flow" output of the different ops to track whether the array
+# elements are stacked or unstacked. If a stacked write/scatter is done, we make
+# the flow stacked as well.
+# - We use some heuristic traversal of the graph to track whether the
+# TensorArray handle was created inside or outside the pfor loop.
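+#
+# As an illustrative sketch of case 2a (not executed here), consider a loop_fn
+# of the following form, where `x` has shape [loop_len, 3]:
+#
+#   def loop_fn(i):
+#     ta = tensor_array_ops.TensorArray(dtypes.float32, size=1)
+#     ta = ta.write(0, array_ops.gather(x, i))  # a "stacked" write
+#     return ta.read(0)
+#
+# The conversion keeps a single TensorArray whose entry 0 has shape
+# [loop_len, 3], i.e. the pfor iteration index becomes the leading dimension of
+# each stacked element.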
+
+
+@RegisterPFor("TensorArrayV3")
+def _convert_tensor_array_v3(pfor_input):
+ size = pfor_input.unstacked_input(0)
+ dtype = pfor_input.get_attr("dtype")
+ dynamic_size = pfor_input.get_attr("dynamic_size")
+ clear_after_read = pfor_input.get_attr("clear_after_read")
+ identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
+ tensor_array_name = pfor_input.get_attr("tensor_array_name")
+ handle, flow = data_flow_ops.tensor_array_v3(
+ size,
+ dtype=dtype,
+ # We don't set element shape since we don't know if writes are stacked or
+ # not yet.
+ element_shape=None,
+ dynamic_size=dynamic_size,
+ clear_after_read=clear_after_read,
+ identical_element_shapes=identical_element_shapes,
+ tensor_array_name=tensor_array_name)
+ # Note we keep flow unstacked for now since we don't know if writes will be
+ # stacked or not.
+ return wrap(handle, False), wrap(flow, False)
+
+
+@RegisterPFor("TensorArraySizeV3")
+def _convert_tensor_array_size_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ flow, flow_stacked, _ = pfor_input.input(1)
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+ size = data_flow_ops.tensor_array_size_v3(handle, flow)
+ return wrap(size, False)
+
+
+def _handle_inside_pfor(pfor_input, handle):
+ """Returns True if handle was created inside the pfor loop."""
+ # We use a heuristic to find the original TensorArray creation op. The logic
+ # should handle the common cases (except cond-based subgraphs). In theory the
+ # user could perform different operations on the handle (like Reshape, stacking
+ # multiple handles, etc.) which could break this logic.
+ # TODO(agarwal): handle Switch/Merge.
+ while handle.op.type in ("Enter", "Identity"):
+ handle = handle.op.inputs[0]
+ if handle.op.type not in [
+ "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"]:
+ raise ValueError("Unable to find source for handle %s" % handle)
+ else:
+ return pfor_input.pfor.op_is_inside_loop(handle.op)
+
+
+def _unstack_flow(value):
+ # TODO(agarwal): consider looking if this is a Tile op then get its input.
+ # This may avoid running the Tile operations.
+ return array_ops.gather(value, 0)
+
+
+@RegisterPFor("TensorArrayReadV3")
+def _convert_tensor_array_read_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ index, index_stacked, _ = pfor_input.input(1)
+ dtype = pfor_input.get_attr("dtype")
+ flow, flow_stacked, _ = pfor_input.input(2)
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+
+ is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
+ if is_inside_pfor:
+ # Note that if we are inside a control flow construct inside the pfor, and
+ # only some of the iterations are doing the read (i.e.
+ # `all_indices_partitioned` is True), then the read operation should only
+ # return values for the currently active pfor iterations (`all_indices`
+ # below). Hence, whenever the returned value is stacked (i.e. `flow` is
+ # stacked), we may need to do an extra gather after reading the values. Also
+ # note that if `is_inside_pfor` is False, then values in the TensorArray are
+ # unstacked, so the check is only needed in this branch.
+ all_indices = pfor_input.pfor.all_indices
+ all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
+ # Note: flow_stacked indicates if values in the TensorArray are stacked or
+ # not.
+ if index_stacked:
+ if flow_stacked:
+ raise ValueError(
+ "It looks like TensorArrayReadV3 was called on a TensorArray whose"
+ " values are not loop-invariant, and the read indices were also"
+ " not loop invariant. This is currently unsupported.")
+ value = data_flow_ops.tensor_array_gather_v3(
+ handle, index, flow, dtype=dtype)
+ return wrap(value, True)
+ value = data_flow_ops.tensor_array_read_v3(
+ handle, index, flow, dtype=dtype)
+ if flow_stacked and all_indices_partitioned:
+ value = array_ops.gather(value, all_indices)
+ return wrap(value, flow_stacked)
+ # Values in the TensorArray should be unstacked (since different iterations
+ # couldn't write to the same location). So whether output is stacked or not
+ # depends on index_stacked.
+ if index_stacked:
+ value = data_flow_ops.tensor_array_gather_v3(
+ handle, index, flow, dtype=dtype)
+ else:
+ value = data_flow_ops.tensor_array_read_v3(
+ handle, index, flow, dtype=dtype)
+ return wrap(value, index_stacked)
+
+
+@RegisterPFor("TensorArrayWriteV3")
+def _convert_tensor_array_write_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ index, index_stacked, _ = pfor_input.input(1)
+ value, value_stacked, _ = pfor_input.input(2)
+ flow, flow_stacked, _ = pfor_input.input(3)
+ if value_stacked and pfor_input.pfor.all_indices_partitioned:
+ # It looks like we are inside a control flow construct in the pfor where not
+ # all iterations are currently active. We don't allow writing a stacked value
+ # in that case since different indices could end up with different shapes,
+ # which would be hard to merge later.
+ raise ValueError("Writing non loop invariant values to TensorArray from "
+ "inside a while_loop/cond not supported.")
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+ is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
+ if is_inside:
+ if index_stacked:
+ raise ValueError("Need indices for %s to be loop invariant" % handle)
+ if not flow_stacked and not value_stacked:
+ flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
+ return wrap(flow_out, False)
+ else:
+ if not value_stacked:
+ value = _stack(value, pfor_input.pfor.loop_len_vector).t
+ # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
+ # this may or may not be a safe situation. flow is unstacked both for a
+ # freshly created TensorArray, as well as after unstacked values are
+ # written to it. If it is the latter, then we cannot write a stacked value
+ # now since that may cause runtime errors due to different shapes in the
+ # array. At the moment we are not able to handle this gracefully and
+ # distinguish between the two cases. That would require some heuristic
+ # traversal of the graph to figure out whether all the writes are
+ # unstacked or not.
+ flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
+ return _stack(flow_out, pfor_input.pfor.loop_len_vector)
+ else:
+ if not index_stacked:
+ raise ValueError("Need indices for %s to be not loop invariant" % handle)
+ # Note that even when index_stacked is true, the actual values in index may
+ # still not be unique. However, that will cause a runtime error when executing
+ # the scatter operation below.
+ if not value_stacked:
+ value = _stack(value, pfor_input.pfor.loop_len_vector).t
+ flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
+ return _stack(flow_out, pfor_input.pfor.loop_len_vector)
+
+
+def _transpose_first_two_dims(value):
+ # TODO(agarwal): optimize if one of the dims == 1.
+ value_shape = array_ops.shape(value)
+ v0 = value_shape[0]
+ v1 = value_shape[1]
+ value = array_ops.reshape(value, [v0, v1, -1])
+ value = array_ops.transpose(value, [1, 0, 2])
+ new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
+ return array_ops.reshape(value, new_shape)
+
+
+@RegisterPFor("TensorArrayGatherV3")
+def _convert_tensor_array_gather_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ indices, indices_stacked, _ = pfor_input.input(1)
+ indices = array_ops.reshape(indices, [-1])
+ flow, flow_stacked, _ = pfor_input.input(2)
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+ dtype = pfor_input.get_attr("dtype")
+ # TODO(agarwal): support element_shape attr?
+
+ n = pfor_input.pfor.loop_len_vector
+ value = data_flow_ops.tensor_array_gather_v3(
+ handle, indices, flow, dtype=dtype)
+ is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
+ if is_inside:
+ # flow_stacked indicates if values in the TensorArray are stacked or not.
+ if indices_stacked:
+ if flow_stacked:
+ raise ValueError(
+ "It looks like TensorArrayGatherV3 was called on a TensorArray "
+ "whose values are not loop-invariant, and the indices were also "
+ "not loop invariant. This is currently unsupported.")
+ else:
+ value = _unflatten_first_dim(value, n)
+ return wrap(value, True)
+ else:
+ if flow_stacked:
+ # Since elements in this array are stacked and `value` was produced by
+ # gather, its first two dims are "gathered elements" and "stack
+ # dimension". Our semantics require these two to be flipped.
+ value = _transpose_first_two_dims(value)
+ return wrap(value, flow_stacked)
+ else:
+ # Values in the TensorArray should be unstacked (since different iterations
+ # couldn't write to the same location). So whether output is stacked or not
+ # depends on indices_stacked.
+ if indices_stacked:
+ value = _unflatten_first_dim(value, n)
+ return wrap(value, indices_stacked)
+
+
+@RegisterPFor("TensorArrayScatterV3")
+def _convert_tensor_array_scatter_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ indices, indices_stacked, _ = pfor_input.input(1)
+ indices = array_ops.reshape(indices, [-1])
+ value, value_stacked, _ = pfor_input.input(2)
+ flow, flow_stacked, _ = pfor_input.input(3)
+
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+
+ is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
+ if is_inside:
+ if indices_stacked:
+ raise ValueError("Need indices for %s to be loop invariant" % handle)
+ # Note that flow_stacked indicates if existing values in the array are
+ # stacked or not.
+ if not flow_stacked and not value_stacked:
+ flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
+ flow)
+ return wrap(flow_out, False)
+ if not value_stacked:
+ # TODO(agarwal): tile in the second dimension directly instead of
+ # transposing below.
+ value = _stack(value, pfor_input.pfor.loop_len_vector).t
+
+ value = _transpose_first_two_dims(value)
+ # TODO(agarwal): Note that if a previous write was unstacked, flow will be
+ # unstacked, and a stacked value may be written here which may cause
+ # runtime error due to different elements having different shape. We do
+ # not try to prevent that.
+ flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
+ flow)
+ return _stack(flow_out, pfor_input.pfor.loop_len_vector)
+ if not indices_stacked:
+ raise ValueError("Need indices for %s to be not loop invariant" % handle)
+ if not value_stacked:
+ value = _stack(value, pfor_input.pfor.loop_len_vector).t
+ value = _flatten_first_two_dims(value)
+ flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
+ flow)
+ return _stack(flow_out, pfor_input.pfor.loop_len_vector)
+
+
+@RegisterPFor("TensorArrayGradV3")
+def _convert_tensor_array_grad_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ flow, flow_stacked, _ = pfor_input.input(1)
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+ source = pfor_input.get_attr("source")
+ # TODO(agarwal): For now, we assume that gradients are stacked if the
+ # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
+ # will give a runtime error due to an incorrect shape being written to the
+ # accumulator. It is difficult to know in advance whether the gradients written
+ # will be stacked or not. Note that flow being stacked is not indicative of the
+ # gradient being stacked or not. Revisit this later.
+ shape_to_prepend = pfor_input.pfor.loop_len_vector
+ grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
+ handle=handle,
+ flow_in=flow,
+ shape_to_prepend=shape_to_prepend,
+ source=source)
+ flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
+ return [wrap(grad_handle, False), wrap(flow_out, True)]
+
+
+# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar
+# to TensorArrays, we convert them by changing the dimension of the elements
+# inside the stack.
+#
+# We consider two cases:
+#
+# 1. StackV2 is constructed and used entirely inside the pfor loop.
+# We keep a single Stack and perform the push/pop operations of all the
+# iterations in lock-step. We also assume that all the iterations perform these
+# operations. In case of dynamic control flow, if only some of the iterations
+# try to perform a push/pop, then the conversion may not work correctly and may
+# cause undefined behavior.
+# TODO(agarwal): test StackV2 with dynamic control flow.
+#
+# 2. StackV2 is constructed outside the pfor loop.
+# Performing stack push/pop in a parallel fashion is ill-defined. However, given
+# that reading stacks created externally is a common operation when computing
+# Jacobians, we provide the following special semantics:
+# - disallow push operations to the stack
+# - pop operations are performed in lock step by all iterations, similar to the
+# case when the stack is created inside. A single value is popped during the
+# lock-step operation and broadcast to all the iterations. Values in the stack
+# are assumed to be loop-invariant.
+#
+# Some other implementation details:
+# We use a somewhat brittle heuristic to determine whether the values in the
+# Stack data structure are loop invariant or not. When converting push/pop
+# operations, we keep track of whether the last conversion used a stacked value
+# or not (see _stack_cache below). As a result, if an unstacked value is written
+# first, subsequent stacked writes are disallowed even when they could have been
+# allowed in theory.
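+#
+# For example (illustrative): if a stack is created outside pfor and pre-filled
+# with loop-invariant values, then each pop inside the loop body is converted to
+# a single lock-step pop whose (unstacked) result is shared by all iterations.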
+
+# Map from cache key based on StackV2 handle to a bool indicating whether values
+# are stacked or not.
+# TODO(agarwal): move _stack_cache inside pfor?
+_stack_cache = {}
+
+
+def _stack_cache_key(pfor_input):
+ """Create cache key corresponding to a stack handle."""
+ op_type = pfor_input.op_type
+ assert op_type in ["StackPushV2", "StackPopV2"], op_type
+ orig_handle = pfor_input.op.inputs[0]
+ while orig_handle.op.type in ["Identity", "Enter"]:
+ orig_handle = orig_handle.op.inputs[0]
+ assert orig_handle.op.type == "StackV2", orig_handle.op
+ return ops.get_default_graph(), pfor_input.pfor, orig_handle
+
+
+def _stack_handle_inside_pfor(handle, pfor_input):
+ while handle.op.type in ["Identity", "Enter"]:
+ handle = handle.op.inputs[0]
+ assert handle.op.type == "StackV2", (
+ "Unable to find StackV2 op. Got %s" % handle.op)
+ return pfor_input.pfor.op_is_inside_loop(handle.op)
+
+
+@RegisterPFor("StackPushV2")
+def _convert_stack_push_v2(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ elem, elem_stacked, _ = pfor_input.input(1)
+ swap_memory = pfor_input.get_attr("swap_memory")
+
+ if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
+ raise ValueError("StackPushV2 not allowed on stacks created outside pfor")
+ stack_cache_key = _stack_cache_key(pfor_input)
+ stacked = _stack_cache.get(stack_cache_key, None)
+ if stacked is None:
+ stacked = elem_stacked
+ _stack_cache[stack_cache_key] = stacked
+ else:
+ # If we previously made it unstacked then we can't revert to being stacked.
+ if not stacked and elem_stacked:
+ raise ValueError(
+ "It looks like the stack was previously determined to be loop"
+ " invariant, but we are now trying to push a loop dependent value"
+ " to it. This is currently unsupported.")
+ if stacked and not elem_stacked:
+ elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
+ out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
+ return wrap(out, stacked)
+
+
+# Note that inputs to this converter will be unstacked. However, it should still
+# get called since it is a stateful op.
+@RegisterPFor("StackPopV2")
+def _convert_stack_pop_v2(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ stack_cache_key = _stack_cache_key(pfor_input)
+ stacked = _stack_cache.get(stack_cache_key, None)
+ # If a StackPushV2 has not been converted yet, we default to unstacked since
+ # the push could be outside of pfor, or the converter may not be called if the
+ # inputs are unconverted.
+ if stacked is None:
+ stacked = False
+ _stack_cache[stack_cache_key] = False
+ elem_type = pfor_input.get_attr("elem_type")
+ out = data_flow_ops.stack_pop_v2(handle, elem_type)
+ return wrap(out, stacked)
+
+
+# parsing_ops
+
+
+@RegisterPFor("DecodeCSV")
+def _convert_decode_csv(pfor_input):
+ lines = pfor_input.stacked_input(0)
+ record_defaults = [
+ pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
+ ]
+ field_delim = pfor_input.get_attr("field_delim")
+ use_quote_delim = pfor_input.get_attr("use_quote_delim")
+ select_cols = pfor_input.get_attr("select_cols")
+ if not select_cols:
+ select_cols = None
+ return [
+ wrap(t, True) for t in parsing_ops.decode_csv(
+ lines,
+ record_defaults,
+ field_delim=field_delim,
+ use_quote_delim=use_quote_delim,
+ select_cols=select_cols)
+ ]
+
+
+@RegisterPFor("ParseSingleExample")
+def _convert_parse_single_example(pfor_input):
+ serialized = pfor_input.stacked_input(0)
+ dense_defaults = [
+ pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
+ ]
+ sparse_keys = pfor_input.get_attr("sparse_keys")
+ dense_keys = pfor_input.get_attr("dense_keys")
+ sparse_types = pfor_input.get_attr("sparse_types")
+ dense_shapes = pfor_input.get_attr("dense_shapes")
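+ # The stacked scalar inputs form a batch of serialized protos, so the batched
+ # ParseExample kernel can be used directly (with empty `names`).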
+ output = gen_parsing_ops.parse_example(
+ serialized=serialized,
+ names=[],
+ dense_defaults=dense_defaults,
+ sparse_keys=sparse_keys,
+ dense_keys=dense_keys,
+ sparse_types=sparse_types,
+ dense_shapes=dense_shapes)
+ return [wrap(t, True, True) for t in nest.flatten(output)]