Diffstat (limited to 'tensorflow/python/ops/control_flow_ops.py')
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 1561
1 file changed, 1561 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
new file mode 100644
index 0000000000..068e3b5553
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -0,0 +1,1561 @@
+"""## Control Flow Operations
+
+TensorFlow provides several operations and classes that you can use to control
+the execution of operations and add conditional dependencies to your graph.
+
+@@identity
+@@tuple
+@@group
+@@no_op
+@@count_up_to
+
+## Logical Operators
+
+TensorFlow provides several operations that you can use to add logical operators
+to your graph.
+
+@@logical_and
+@@logical_not
+@@logical_or
+@@logical_xor
+
+## Comparison Operators
+
+TensorFlow provides several operations that you can use to add comparison
+operators to your graph.
+
+@@equal
+@@not_equal
+@@less
+@@less_equal
+@@greater
+@@greater_equal
+@@select
+@@where
+
+## Debugging Operations
+
+TensorFlow provides several operations that you can use to validate values and
+debug your graph.
+
+@@is_finite
+@@is_inf
+@@is_nan
+@@verify_tensor_all_finite
+@@check_numerics
+@@add_check_numerics_ops
+@@Assert
+@@Print
+"""
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import logging
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_control_flow_ops import *
+
+
+# We override Python's built-in 'tuple' with a control flow op below, so keep
+# a reference to the built-in 'tuple' for later use in this module.
+_basetuple = tuple
+
+
+# pylint: disable=protected-access
+def _Identity(data, name=None):
+ """Return a tensor with the same shape and contents as the input tensor.
+
+ Args:
+ data: A Tensor.
+ name: A name for this operation (optional).
+
+ Returns:
+ A Tensor with the same type and value as the input Tensor.
+ """
+ if not data.dtype.is_ref_dtype:
+ return array_ops.identity(data, name=name)
+ else:
+ return gen_array_ops._ref_identity(data, name=name)
+
+
+def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
+ name=None):
+ """Creates or finds a child frame, and makes 'data' available to it.
+
+ The unique `frame_name` is used by the `Executor` to identify frames. If
+ `is_constant` is true, `output` is a constant in the child frame; otherwise
+ it may be changed in the child frame. At most `parallel_iterations` iterations
+ are run in parallel in the child frame.
+
+ Args:
+ data: The tensor to be made available to the child frame.
+ frame_name: The name of the child frame.
+ is_constant: If true, the output is constant within the child frame.
+ parallel_iterations: The number of iterations allowed to run in parallel.
+ name: A name for this operation (optional).
+
+ Returns:
+ The same tensor as 'data'.
+ """
+ if not data.dtype.is_ref_dtype:
+ return enter(data, frame_name, is_constant, parallel_iterations,
+ name=name)
+ else:
+ return ref_enter(data, frame_name, is_constant, parallel_iterations,
+ name=name)
+
+
+def exit(data, name=None):
+ """Exits the current frame to its parent frame.
+
+ Exit makes its input `data` available to the parent frame.
+
+ Args:
+ data: The tensor to be made available to the parent frame.
+ name: A name for this operation (optional).
+
+ Returns:
+ The same tensor as `data`.
+ """
+ return gen_control_flow_ops._exit(data, name)
+
+
+def switch(data, pred, name=None):
+ """Forwards `data` to an output determined by `pred`.
+
+  If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
+  the data goes to `output_false`.
+
+ This op handles `Tensor`s and `IndexedSlices`.
+
+ Args:
+ data: The tensor to be forwarded to the appropriate output.
+ pred: A scalar that specifies which output port will receive data.
+ name: A name for this operation (optional).
+
+ Returns:
+    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
+    `output_true`, otherwise it goes to `output_false`.
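+
+  For example (a minimal sketch; `x`, `y` and `data` are assumed to be tensors
+  defined elsewhere):
+
+  ```python
+  pred = math_ops.less(x, y)
+  # `data` is forwarded to exactly one of the two outputs; the other output
+  # carries no value at runtime.
+  output_false, output_true = switch(data, pred)
+  ```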
+ """
+ with ops.op_scope([data, pred], name, "Switch") as name:
+ data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
+ pred = ops.convert_to_tensor(pred, name="pred")
+ if isinstance(data, ops.Tensor):
+ return gen_control_flow_ops._switch(data, pred, name=name)
+ else:
+ val, ind, dense_shape = data.values, data.indices, data.dense_shape
+ val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)
+ ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")
+ if dense_shape:
+ dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
+ dense_shape, pred, name="dense_shape")
+ else:
+ dense_shape_f, dense_shape_t = None, None
+ return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
+ ops.IndexedSlices(val_t, ind_t, dense_shape_t))
+
+
+def merge(inputs, name=None):
+ """Returns the value of an available element of `inputs`.
+
+ This op tests each of the tensors in `inputs` in turn to determine if any of
+ them is available. If it finds an available tensor, it returns it and its
+ index in `inputs`.
+
+ It is an error if more than one tensor in `inputs` is available. If no tensor
+ in `inputs` is available, the returned tensor and index are not set.
+
+ This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
+ `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
+ before merging.
+
+ Args:
+ inputs: The input tensors, at most one of which is available.
+ name: A name for this operation (optional).
+
+ Returns:
+ A tuple containing the chosen input tensor and its index in `inputs`.
+
+ Raises:
+ ValueError: If inputs are IndexedSlices and some but not all have a
+ dense_shape property.
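+
+  For example, `merge` can join the two outputs of a `switch` back into a
+  single value (a minimal sketch; only one branch is actually produced at
+  runtime, and `merge` forwards whichever one is available):
+
+  ```python
+  output_false, output_true = switch(data, pred)
+  result, chosen_index = merge([math_ops.neg(output_false),
+                                math_ops.square(output_true)])
+  ```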
+ """
+ with ops.op_scope(inputs, name, "Merge") as name:
+ inputs = [ops.convert_to_tensor_or_indexed_slices(inp) for inp in inputs]
+ if all([isinstance(inp, ops.Tensor) for inp in inputs]):
+ return gen_control_flow_ops._merge(inputs, name=name)
+ else:
+ inputs = math_ops._as_indexed_slices_list(inputs)
+ values, _ = gen_control_flow_ops._merge([inp.values for inp in inputs],
+ name=name)
+ indices, chosen_index = gen_control_flow_ops._merge(
+ [inp.indices for inp in inputs], name="indices")
+ if any(inp.dense_shape for inp in inputs):
+ if not all(inp.dense_shape for inp in inputs):
+ raise ValueError("Either all merged IndexedSlices must have a "
+ "dense_shape, or none must have a dense_shape.")
+ dense_shape, _ = gen_control_flow_ops._merge(
+ [inp.dense_shape for inp in inputs], name="dense_shape")
+ else:
+ dense_shape = None
+ return ops.IndexedSlices(values, indices, dense_shape), chosen_index
+
+
+def _SwitchRefOrTensor(data, pred, name="Switch"):
+ """Forwards `data` to an output determined by `pred`.
+
+  If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
+  the data goes to `output_false`.
+
+ This op handles `Tensor`s and `IndexedSlices`.
+
+ Args:
+ data: The tensor to be forwarded to the appropriate output.
+ pred: A scalar that specifies which output port will receive data.
+ name: A name for this operation (optional).
+
+ Returns:
+    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
+    `output_true`, otherwise it goes to `output_false`.
+
+ Raises:
+ TypeError: if data is not a Tensor or IndexedSlices
+ """
+ data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
+ if isinstance(data, ops.Tensor):
+ if not data.dtype.is_ref_dtype:
+ return switch(data, pred, name=name)
+ else:
+ return ref_switch(data, pred, name=name)
+ else:
+ return switch(data, pred, name=name)
+
+
+class ControlFlowOpInputs(object):
+ """An indirection to capture the input tensors needed in backprop."""
+
+ def __init__(self, op):
+ self._op = op
+ self._inputs = None
+
+ def __len__(self):
+ return len(self._op._inputs)
+
+ def __getitem__(self, index):
+ if self._inputs is None:
+ self._inputs = [None for _ in self._op.inputs]
+ if isinstance(index, int):
+ val = self._inputs[index]
+ if val is None:
+ f_val = self._op.inputs[index]
+ val = _GetRealValue(f_val)
+ self._inputs[index] = val
+ return val
+ elif isinstance(index, slice):
+ start, stop, step = index.indices(len(self))
+ vals = [self[i] for i in xrange(start, stop, step)]
+ return vals
+ else:
+ raise TypeError("index must be an integer or slice")
+
+
+class ControlFlowOpOutputs(object):
+ """An indirection to capture the output tensors needed in backprop."""
+
+ def __init__(self, op):
+ self._op = op
+ self._outputs = None
+
+ def __len__(self):
+ return len(self._op._outputs)
+
+ def __getitem__(self, index):
+ if self._outputs is None:
+ self._outputs = [None for _ in self._op.outputs]
+ if isinstance(index, int):
+ val = self._outputs[index]
+ if val is None:
+ f_val = self._op.outputs[index]
+ val = _GetRealValue(f_val)
+ self._outputs[index] = val
+ return val
+ elif isinstance(index, slice):
+ start, stop, step = index.indices(len(self))
+ vals = [self[i] for i in xrange(start, stop, step)]
+ return vals
+ else:
+ raise TypeError("index must be an integer or slice")
+
+
+class ControlFlowOpWrapper(object):
+ """A wrapper class for Operation."""
+
+ def __init__(self, op):
+ self._op = op
+ self._inputs = None
+ self._outputs = None
+
+ @property
+ def inputs(self):
+ if self._inputs is None:
+ self._inputs = ControlFlowOpInputs(self._op)
+ return self._inputs
+
+ @property
+ def outputs(self):
+ if self._outputs is None:
+ self._outputs = ControlFlowOpOutputs(self._op)
+ return self._outputs
+
+ @property
+ def op(self):
+ return self._op
+
+ @property
+ def name(self):
+ """Returns the name of this instance of op."""
+ return self._op.name
+
+ @property
+ def _id(self):
+ """Returns the unique id of this operation."""
+ return self._op._id
+
+ @property
+ def device(self):
+ """Returns the device of this operation.
+
+ Returns:
+ a string or None if the device was not set.
+ """
+ return self._op.device
+
+ @property
+ def output_types(self):
+ return self._op.output_types
+
+ @property
+ def input_types(self):
+ return self._op._input_types
+
+ @property
+ def type(self):
+ """Returns the type of the op."""
+ return self._op.type
+
+ @property
+ def graph(self):
+ """Returns the parent graph."""
+ return self._op.graph
+
+ def GetAttr(self, attr_name):
+ """Returns the value of attribute 'attr_name' of NodeDef."""
+ return self._op.get_attr(attr_name)
+
+ def _get_control_flow_context(self):
+ return self._op._get_control_flow_context()
+
+
+def GetRealOp(op):
+ while isinstance(op, ControlFlowOpWrapper):
+ op = op.op
+ return op
+
+
+def MakeWrapper(op):
+ """Make a wrapper for op if it is in a WhileContext."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ return ControlFlowOpWrapper(op)
+ return op
+
+
+def EnterGradWhileContext(op):
+ """Enter the WhileContext for gradient computation."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ grad_ctxt = forward_ctxt.CreateGradWhileContext()
+ grad_ctxt.Enter()
+
+
+def ExitGradWhileContext(op):
+ """Exit the WhileContext for gradient computation."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ assert forward_ctxt.grad_context
+ forward_ctxt.grad_context.Exit()
+
+
+def _GetRealValue(value):
+ """Get the real value.
+
+ If backprop "uses" a value produced by forward inference, an
+ accumulator is added in the forward loop to accumulate its values,
+ so we use the accumulated value, indexed by the backprop counter.
+
+ Args:
+ value: A tensor to be captured.
+
+ Returns:
+ The same tensor value from the saved history.
+ """
+ real_value = value
+ forward_ctxt = value.op._get_control_flow_context()
+ real_value = forward_ctxt.history_map.get(value.name)
+ assert value.op.type != "Variable"
+ if real_value is None:
+ if value.op.type == "Enter" and value.op.get_attr("is_constant"):
+ # Use the input of this Enter node
+ real_value = GetRealOp(value.op).inputs[0]
+ else:
+ # Accumulate the history of this value.
+ # NOTE(yuanbyu): Don't accumulate for constants. One approach is
+ # to deepcopy the constants for the grad while context.
+ history_value = forward_ctxt.AddForwardAccumulateLoop(value)
+
+ # The shapes of the whole history and a single event element.
+ forward_ctxt.grad_context.Exit()
+ elem_rank = array_ops.rank(history_value) - 1
+ elem_rank_vec = array_ops.expand_dims(elem_rank, 0)
+ elem_shape = array_ops.slice(array_ops.shape(history_value), [1],
+ elem_rank_vec)
+ slice_shape = array_ops.concat(0, [[1], elem_shape])
+ forward_ctxt.grad_context.Enter()
+
+ # The begin position of the slice at slice_index.
+ slice_index = forward_ctxt.grad_context.index
+ b1 = array_ops.zeros(elem_rank_vec, dtype=types.int32)
+ b = array_ops.concat(0, [array_ops.expand_dims(slice_index, 0), b1])
+
+ # The slice at slice_index.
+ # TODO(irving): Replace with gather once that's GPU accelerated
+ real_value = array_ops.squeeze(
+ array_ops.slice(history_value,
+ b,
+ slice_shape,
+ name="real"),
+ squeeze_dims=[0])
+ forward_ctxt.history_map[value.name] = real_value
+ return real_value
+
+
+def IsLoopSwitch(op):
+ """Returns true if `op` is the Switch for a While loop."""
+ if op.type == "Switch":
+ ctxt = op._get_control_flow_context()
+ return ctxt and isinstance(ctxt, WhileContext)
+ return False
+
+
+class ControlFlowContext(object):
+ """The base class for control flow context.
+
+ The usage pattern is a sequence of (Enter, Exit) followed by a final
+ ExitResult.
+ """
+
+ def AddName(self, name):
+ self._values.add(name)
+
+ # pylint: disable=protected-access
+ def Enter(self):
+ """Enter the current context."""
+ self._outer_context = ops.get_default_graph()._get_control_flow_context()
+ ops.get_default_graph()._set_control_flow_context(self)
+
+ def Exit(self):
+ """Exit the current context."""
+ ops.get_default_graph()._set_control_flow_context(self._outer_context)
+ # pylint: enable=protected-access
+
+ def ExitResult(self, result):
+ """Make a list of tensors available in the outer context."""
+ if self._outer_context is not None:
+ for x in result:
+ self._outer_context.AddName(x.name)
+
+ def GetWhileContext(self):
+ """Get the current while context."""
+ if self._outer_context is not None:
+ return self._outer_context.GetWhileContext()
+ return None
+
+ def AddToWhileContext(self, op):
+ """Add a control dependency to the containing WhileContext.
+
+ The added control dependency ensures that the outputs of this op
+ belong to the WhileContext.
+
+ Args:
+ op: An operation.
+ """
+ while_ctxt = self.GetWhileContext()
+ if while_ctxt is not None:
+ # pylint: disable=protected-access
+ op._add_control_input(while_ctxt.GetControlPivot().op)
+ # pylint: enable=protected-access
+
+
+class CondContext(ControlFlowContext):
+ """The context for the conditional construct."""
+
+ def __init__(self, pred, pivot, branch):
+ self._pred = pred
+ self._outer_context = None
+ self._pivot = pivot
+ self._branch = branch
+ self._values = set()
+ self._values.add(pred.name)
+ self._values.add(pivot.name)
+ self._external_values = {}
+
+ @property
+ def pred(self):
+ return self._pred
+
+ @property
+ def pivot(self):
+ return self._pivot
+
+ @property
+ def branch(self):
+ return self._branch
+
+ def AddValue(self, val):
+ """Add 'val' to the current context and its outer context recursively."""
+ result = val
+ if val.name not in self._values:
+ self._values.add(val.name)
+ if self._outer_context is not None:
+ result = self._outer_context.AddValue(val)
+ result = with_dependencies([self._pivot], result)
+ self._external_values[val.name] = result
+ return result
+
+ def AddOp(self, op):
+ """Add 'op' to the current context."""
+ if not op.inputs:
+ # Add this op to the enclosing while context
+ self.AddToWhileContext(op)
+ # pylint: disable=protected-access
+ op._add_control_input(self._pivot.op)
+ # pylint: enable=protected-access
+ for x in op.outputs:
+ self._values.add(x.name)
+ else:
+ for index in range(len(op.inputs)):
+ x = op.inputs[index]
+ if x.name not in self._values:
+ self._values.add(x.name)
+ # Add this value to the parent contexts up to the context that
+ # creates this value.
+ real_x = x
+ if self._outer_context is not None:
+ real_x = self._outer_context.AddValue(x)
+ real_x = _SwitchRefOrTensor(real_x, self._pred)[self._branch]
+ self._external_values[x.name] = real_x
+ x = self._external_values.get(x.name)
+ if x is not None:
+ op._update_input(index, x)
+ for x in op.outputs:
+ self._values.add(x.name)
+
+ def BuildCondBranch(self, fn):
+ """Add the subgraph defined by fn() to the graph."""
+ r = fn()
+ result = []
+ if r is not None:
+ if not isinstance(r, list) and not isinstance(r, _basetuple):
+ r = [r]
+ for v in r:
+ if isinstance(v, ops.Operation):
+ v = with_dependencies([v], self._pivot)
+ elif v.name not in self._values:
+ self._values.add(v.name)
+ if self._outer_context is not None:
+ v = self._outer_context.AddValue(v)
+ v = _SwitchRefOrTensor(v, self._pred)[self._branch]
+ else:
+ external_v = self._external_values.get(v.name)
+ if external_v is not None:
+ v = external_v
+ result.append(v)
+ return result
+
+
+def cond(pred, fn1, fn2, name=None):
+ """Return either 'fn1()' or 'fn2()' based on the boolean predicate 'pred'.
+
+ `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
+ the same number and type of outputs.
+
+ Args:
+ pred: A scalar determining whether to return the result of `fn1` or `fn2`.
+ fn1: The function to be performed if pred is true.
+    fn2: The function to be performed if pred is false.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ Tensors returned by the call to either `fn1` or `fn2`. If the functions
+ return a singleton list, the element is extracted from the list.
+
+ Raises:
+ TypeError: if `fn1` or `fn2` is not callable.
+ ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
+ return tensors of different types.
+
+ Example:
+ ```python
+ x = constant(2)
+ y = constant(5)
+ def f1(): return constant(17)
+ def f2(): return constant(23)
+ r = cond(math_ops.less(x, y), f1, f2)
+ # r is set to f1()
+ ```
+ """
+ with ops.op_scope([pred], name, "Cond") as name:
+ if not callable(fn1):
+ raise TypeError("fn1 must be callable.")
+ if not callable(fn2):
+ raise TypeError("fn2 must be callable.")
+
+ # Add the Switch to the graph.
+ p_2, p_1 = switch(pred, pred)
+ pivot_1 = array_ops.identity(p_1, name="switch_t")
+ pivot_2 = array_ops.identity(p_2, name="switch_f")
+ pred = array_ops.identity(pred, name="pred_id")
+
+ # Build the graph for the true branch in a new context.
+ context_t = CondContext(pred, pivot_1, 1)
+ context_t.Enter()
+ res_t = context_t.BuildCondBranch(fn1)
+ context_t.ExitResult(res_t)
+ context_t.Exit()
+
+ # Build the graph for the false branch in a new context.
+ context_f = CondContext(pred, pivot_2, 0)
+ context_f.Enter()
+ res_f = context_f.BuildCondBranch(fn2)
+    context_f.ExitResult(res_f)
+ context_f.Exit()
+
+ # Add the final merge to the graph.
+ if len(res_t) != len(res_f):
+ raise ValueError("fn1 and fn2 must return the same number of tensors.")
+ for x, y in zip(res_f, res_t):
+ assert ((isinstance(x, ops.IndexedSlices) and
+ isinstance(y, ops.IndexedSlices)) or
+ (isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)))
+ val_x = x if isinstance(x, ops.Tensor) else x.values
+ val_y = y if isinstance(y, ops.Tensor) else y.values
+ if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
+ raise ValueError("Outputs of fn1 and fn2 must have the same type: "
+ "%s, %s" % (val_x.dtype.name, val_y.dtype.name))
+ merges = [merge([x[0], x[1]])[0] for x in zip(res_f, res_t)]
+ return merges[0] if len(merges) == 1 else merges
+
+
+# TODO(yuanbyu): We should probably separate the notion of context so it
+# could be used not only for conditionals and loops but also subgraphs.
+class WhileContext(ControlFlowContext):
+ """The context for the loop construct."""
+
+ def __init__(self, parallel_iterations, back_prop, name):
+ self._name = ops.get_default_graph().unique_name(name)
+ self._parallel_iterations = parallel_iterations
+ self._back_prop = back_prop
+ self._outer_context = None
+ # We use this node to control constants created by the pred lambda.
+ self._pivot_for_pred = None
+ # We use this node to control constants created by the body lambda.
+ self._pivot_for_body = None
+ # The boolean tensor for loop termination condition. Used in code
+ # generation for gradient computation
+ self._pivot = None
+
+ # The tensors for the counters added by AddForwardCounterLoop or
+ # AddBackPropCounterLoop
+ self._index = None
+
+ # Information needed by backprop
+ self._grad_context = None
+ self._total_iterations = None
+ self._history_map = {}
+ self._switch_map = {}
+
+ # values considered to have been already seen in this context
+ self._values = set()
+
+ # values referenced by but external to this context
+ self._external_values = {}
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def parallel_iterations(self):
+ """The number of iterations allowed to run in parallel."""
+ return self._parallel_iterations
+
+ @property
+ def back_prop(self):
+ """True iff backprop is enabled for this While loop."""
+ return self._back_prop
+
+ @property
+ def pivot(self):
+ """The boolean tensor representing the loop termination condition."""
+ return self._pivot
+
+ @property
+ def index(self):
+ """The loop index representing the current iteration."""
+ return self._index
+
+ @property
+ def grad_context(self):
+ """The corresponding WhileContext for gradient."""
+ return self._grad_context
+
+ @property
+ def history_map(self):
+ """The map that records all the tensors needed for backprop."""
+ return self._history_map
+
+ @property
+ def switch_map(self):
+ """The map that records all the Switch ops in the While loop."""
+ return self._switch_map
+
+ @property
+ def total_iterations(self):
+ """The total number of iterations of the while loop."""
+ return self._total_iterations
+
+ def GetWhileContext(self):
+ return self
+
+ def GetControlPivot(self):
+ if self._pivot_for_body:
+ return self._pivot_for_body
+ return self._pivot_for_pred
+
+ def AddValue(self, val):
+ """Add 'val' to the current context and its outer context recursively."""
+ result = val
+ if val.name not in self._values:
+ self._values.add(val.name)
+ if self._outer_context is not None:
+ result = self._outer_context.AddValue(val)
+ # Create an Enter that makes 'result' known to this context.
+ enter = _Enter(result, self._name, is_constant=True,
+ parallel_iterations=self._parallel_iterations)
+ self._values.add(enter.name)
+ self._external_values[val.name] = enter
+ result = enter
+ else:
+ actual_val = self._external_values.get(val.name)
+ if actual_val is not None:
+ result = actual_val
+ return result
+
+ def AddOp(self, op):
+ """Adds 'op' to the current context."""
+ if not op.inputs:
+ if not op.control_inputs:
+ # Add a control edge from the control pivot to this op.
+ # pylint: disable=protected-access
+ op._add_control_input(self.GetControlPivot().op)
+ # pylint: enable=protected-access
+ else:
+ # Control edges must be in the same context.
+ for x in op.control_inputs:
+ assert x._get_control_flow_context() == self, (
+ "Control inputs must come from Operations in the same while "
+ "loop context (not an outer context).")
+ for x in op.outputs:
+ self._values.add(x.name)
+ else:
+ for index in range(len(op.inputs)):
+ x = op.inputs[index]
+ self.AddValue(x)
+ real_x = self._external_values.get(x.name)
+ if real_x is not None:
+ op._update_input(index, real_x)
+ # Add a control dependency to prevent loop invariants from
+ # enabling ops that should not be executed.
+ if real_x.op.type == "RefEnter" and real_x.op.get_attr("is_constant"):
+ # pylint: disable=protected-access
+ op._add_control_input(self.GetControlPivot().op)
+ # pylint: enable=protected-access
+ for x in op.outputs:
+ self._values.add(x.name)
+
+ def CreateGradWhileContext(self):
+ """Creates the WhileContext for backprop gradient computation."""
+ if self._grad_context is None:
+ cnt = self.AddForwardCounterLoop()
+ self._grad_context = WhileContext(self._parallel_iterations,
+ self._back_prop, self._name)
+ self._grad_context.AddBackPropCounterLoop(cnt)
+ return self._grad_context
+
+ def AddForwardCounterLoop(self):
+ """Adds a loop that counts the number of iterations.
+
+ This is added to the forward loop at the time when we start to
+ create the loop for backprop gradient computation.
+
+ The pseudocode is:
+ `n = 0; while (_pivot) { n++; }`
+
+ Returns:
+ The number of iterations taken by the forward loop.
+ """
+ n = constant_op.constant(0, name="f_count")
+ self.Enter()
+ self.AddName(n.name)
+ enter_n = _Enter(n, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_count")
+ merge_n = merge([enter_n, enter_n])[0]
+ switch_n = switch(merge_n, self._pivot)
+ self._index = switch_n[1]
+
+ add_n = math_ops.add(self._index, 1)
+ next_n = next_iteration(add_n)
+ merge_n.op._update_input(1, next_n)
+
+ self._total_iterations = exit(switch_n[0], name="f_count")
+ self.Exit()
+ return self._total_iterations
+
+ def AddForwardAccumulateLoop(self, value):
+ """Add an accumulation loop for each value needed in backprop.
+
+    This is added to the forward loop the first time a value in the
+    forward loop is used by the backprop gradient computation loop.
+
+ The pseudocode is:
+ ```
+ acc;
+ while (_pivot) {
+ if (index == 0) [value] else Concat(acc, [value]);
+ }
+ ```
+
+ Args:
+ value: The tensor that is accumulated.
+
+ Returns:
+ The accumulated history of value.
+
+ Raises:
+ ValueError: If the shape of "value" is not known statically.
+ """
+ if not value.get_shape().is_fully_defined():
+ raise ValueError("Must have known shape: %s" % value)
+ self._grad_context.Exit()
+ # TODO(irving): Now that acc starts out empty, most of the
+ # conditional logic can go away.
+ acc = constant_op.constant([],
+ value.dtype,
+ shape=[0] + value.get_shape().as_list(),
+ name="f_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_acc")
+ merge_acc = merge([enter_acc, enter_acc])[0]
+ switch_acc = switch(merge_acc, self._pivot)
+
+ # If index = 0 then [value] else Concat(acc, [value]).
+ cond = math_ops.greater(self._index, 0)
+ switch_add_acc = switch(switch_acc[1], cond)
+ expand_value = array_ops.expand_dims(value, 0)
+ true_branch = array_ops.concat(0, [switch_add_acc[1], expand_value])
+ false_branch = array_ops.identity(switch_add_acc[0])
+ false_branch = with_dependencies([false_branch], expand_value)
+ add_acc = merge([false_branch, true_branch])[0]
+
+ next_acc = next_iteration(add_acc)
+ merge_acc.op._update_input(1, next_acc)
+
+ exit_acc = exit(switch_acc[0], name="f_acc")
+ self.Exit()
+ self._grad_context.Enter()
+ return exit_acc
+
+ def AddForwardAccumulateCondLoop(self, value):
+ """Add an accumulation loop for each conditional switch.
+
+    This is added to the forward loop the first time a conditional switch
+    in the forward loop is used by the backprop gradient computation loop.
+
+ The pseudocode is:
+ ```
+ acc;
+ while (_pivot) {
+ Concat(acc, value);
+ }
+ ```
+
+ Args:
+ value: The boolean tensor that is accumulated.
+
+ Returns:
+ The accumulated history of value.
+ """
+ self._grad_context.Exit()
+ acc = constant_op.constant(False, name="f_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_acc")
+ merge_acc = merge([enter_acc, enter_acc])[0]
+ switch_acc = switch(merge_acc, self._pivot)
+    acc = array_ops.concat(0, [switch_acc[1], value])
+ next_acc = next_iteration(acc)
+ merge_acc.op._update_input(1, next_acc)
+
+ exit_acc = exit(switch_acc[0], name="f_acc")
+ self.Exit()
+ self._grad_context.Enter()
+ return exit_acc
+
+ def AddBackPropCounterLoop(self, count):
+ """Add the backprop loop that controls the iterations.
+
+ This is added to the backprop loop. It is used to control the loop
+ termination and the slice index.
+
+ The pseudocode is:
+ `n = count; while (n >= 1) { n--; }`
+
+ Args:
+ count: The number of iterations for backprop.
+
+ Returns:
+ always 0.
+ """
+ one = constant_op.constant(1, name="b_count")
+ self.Enter()
+ self.AddName(count.name)
+ enter_count = _Enter(count, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_count")
+ merge_count = merge([enter_count, enter_count])[0]
+ self._pivot_for_pred = merge_count
+
+ cond = math_ops.greater_equal(merge_count, one)
+ self._pivot = loop_cond(cond, name="b_count")
+ switch_count = switch(merge_count, self._pivot)
+
+ # Add next_iteration right after Switch to match the gradient function.
+ next_count = next_iteration(switch_count[1])
+ self._pivot_for_body = next_count
+ self._index = math_ops.sub(next_count, one)
+ merge_count.op._update_input(1, self._index)
+
+ exit_count = exit(switch_count[0], name="b_count")
+ self.Exit()
+ return exit_count
+
+ def AddBackPropAccumulateLoop(self, value):
+ """Add an accumulation loop for every loop invariant.
+
+ This is added to the backprop loop. It is used to accumulate partial
+ gradients for each loop iteration. Called when in the while context
+ for gradient.
+
+ The pseudocode is:
+ ```
+ acc = 0;
+ while (_pivot) {
+ acc += value;
+ }
+ ```
+
+ Args:
+ value: The partial gradient of an iteration for a loop invariant.
+
+ Returns:
+ The gradient for a loop invariant.
+ """
+ self.Exit()
+ acc = constant_op.constant(0, value.dtype, name="b_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_acc")
+ merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
+ switch_acc = switch(merge_acc, self._pivot)
+
+ next_acc = next_iteration(switch_acc[1])
+ add_acc = math_ops.add(next_acc, value)
+ merge_acc.op._update_input(1, add_acc)
+
+ exit_acc = exit(switch_acc[0], name="b_acc")
+ return exit_acc
+
+ def BuildLoop(self, pred, body, loop_vars):
+ """Add the loop termination condition and body to the graph."""
+
+ loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
+    # Let the context know the loop variables so that the _Enter nodes below
+    # are added into the context correctly.
+ self._values = set([x.name for x in loop_vars])
+ if self._outer_context is not None:
+ real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
+ else:
+ real_vars = loop_vars
+ enter_vars = [_Enter(x, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations)
+ for x in real_vars]
+ self._values = set([x.name for x in enter_vars])
+
+ merge_vars = [merge([x, x])[0] for x in enter_vars]
+ self._pivot_for_pred = merge_vars[0]
+
+ # Build the graph for pred.
+ c = ops.convert_to_tensor(pred(*merge_vars))
+ self._pivot = loop_cond(c, name="LoopCond")
+ switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
+
+ # Build the graph for body.
+ vars_for_body = [_Identity(x[1]) for x in switch_vars]
+ self._pivot_for_body = vars_for_body[0]
+
+ body_result = body(*vars_for_body)
+ if not isinstance(body_result, (list, _basetuple)):
+ body_result = [body_result]
+ result = ops.convert_n_to_tensor_or_indexed_slices(body_result)
+ next_vars = [next_iteration(x) for x in result]
+
+ # Add the back edges to complete the loop.
+ assert len(merge_vars) == len(next_vars)
+ for x in zip(merge_vars, next_vars):
+ x[0].op._update_input(1, x[1])
+
+ # Add the exit ops.
+ exit_vars = [exit(x[0]) for x in switch_vars]
+
+ for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars):
+ if m_var.get_shape().is_compatible_with(n_var.get_shape()):
+ e_var.set_shape(m_var.get_shape().merge_with(n_var.get_shape()))
+
+ # Exit the loop.
+ self.ExitResult(exit_vars)
+ self.Exit()
+ return exit_vars[0] if len(exit_vars) == 1 else exit_vars
+
+
+def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
+ name=None):
+ """Repeat `body` while the condition `cond` is true.
+
+ `cond` is a function taking a list of tensors and returning a boolean scalar
+ tensor. `body` is a function taking a list of tensors and returning a list of
+ tensors of the same length and with the same types as the input. `loop_vars`
+ is a list of tensors that is passed to both `cond` and `body`.
+
+ While `cond` evaluates to true, `body` is executed.
+
+ Args:
+ cond: The termination condition of the loop.
+ body: A function that represents the loop body.
+ loop_vars: The list of variable input tensors.
+ parallel_iterations: The number of iterations allowed to run in parallel.
+ back_prop: Whether backprop is enabled for this while loop.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ The output tensors for the loop variables after the loop.
+
+ Raises:
+ TypeError: if `cond` or `body` is not callable.
+    ValueError: if `loop_vars` is empty.
+
+ Example:
+ ```python
+  i = constant(0)
+ c = lambda i: math_ops.less(i, 10)
+ b = lambda i: math_ops.add(i, 1)
+ r = While(c, b, [i])
+ ```
+ """
+ with ops.op_scope(loop_vars, name, "While") as name:
+ if not loop_vars:
+ raise ValueError("No loop variables provided")
+ if not callable(cond):
+ raise TypeError("cond must be callable.")
+ if not callable(body):
+ raise TypeError("body must be callable.")
+
+ context = WhileContext(parallel_iterations, back_prop, name)
+ context.Enter()
+ return context.BuildLoop(cond, body, loop_vars)
+
+
+def _AsTensorList(x, p):
+ """Return x as a list of Tensors or IndexedSlices.
+
+ For entries of `x` that are Operations, this returns an Identity of `p`
+ with a dependency on the operation.
+
+ Args:
+ x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
+ p: A Tensor to return for entries in `x` that are Operations.
+
+ Returns:
+ A list of Tensors or IndexedSlices.
+ """
+ if not isinstance(x, list) and not isinstance(x, _basetuple):
+ x = [x]
+
+ l = []
+ for v in x:
+ if isinstance(v, ops.Operation):
+ v = with_dependencies([v], p)
+ v = ops.convert_to_tensor_or_indexed_slices(v)
+ if isinstance(v, ops.Tensor):
+ l.append(array_ops.identity(v))
+ else:
+ l.append(ops.IndexedSlices(array_ops.identity(v.values),
+ array_ops.identity(v.indices)))
+ return l
+
+
+def _CheckResults(a, b):
+ assert len(a) == len(b), (
+ "Values returned by a() and b() must have the same length.")
+ for x, y in zip(a, b):
+ assert x.dtype == y.dtype, (
+ "Values returned by a() [%s] and b() [%s] must have "
+ "the same type: %s, %s." %
+ (x.name, y.name, x.dtype.name, y.dtype.name))
+
+
+def with_dependencies(dependencies, output_tensor, name=None):
+ """Produces the content of `output_tensor` only after `dependencies`.
+
+ In some cases, a user may want the output of an operation to be
+ consumed externally only after some other dependencies have run
+  first. This function returns `output_tensor`, but only after all
+  operations in `dependencies` have run. Note that this guarantee only applies
+  to the returned tensor: consuming `output_tensor` directly gives no guarantee
+  that it is evaluated after the `dependencies` have run.
+
+ See also `tuple` and `group`.
+
+ Args:
+ dependencies: A list of operations to run before this op finishes.
+ output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
+ name: (Optional) A name for this operation.
+
+ Returns:
+ Same as `output_tensor`.
+
+ Raises:
+ TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
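+
+  Example (a minimal sketch; `update_op` and `total` are hypothetical names for
+  an update operation and a tensor defined elsewhere):
+
+  ```python
+  # Has the same value as `total`, but it is only produced after `update_op`
+  # has run.
+  total_after_update = with_dependencies([update_op], total)
+  ```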
+ """
+ with ops.op_scope(dependencies + [output_tensor], name,
+ "control_dependency") as name:
+ with ops.device(output_tensor.device
+ or ops.get_default_graph().get_default_device()):
+ with ops.control_dependencies(dependencies):
+ output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor)
+ if isinstance(output_tensor, ops.Tensor):
+ return _Identity(output_tensor, name=name)
+ else:
+ return ops.IndexedSlices(_Identity(output_tensor.values, name=name),
+ output_tensor.indices,
+ output_tensor.dense_shape)
+
+
+def _GroupControlDeps(dev, deps, name=None):
+ with ops.control_dependencies(deps):
+ if dev is None:
+ return no_op(name=name)
+ else:
+ with ops.device(dev):
+ return no_op(name=name)
+
+
+# TODO(mdevin): Accept "inputs" as a list.
+def group(*inputs, **kwargs):
+ """Create an op that groups multiple operations.
+
+  When this op finishes, all ops in `inputs` have finished. This op has no
+ output.
+
+ See also `tuple` and `with_dependencies`.
+
+ Args:
+ *inputs: One or more tensors to group.
+ **kwargs: Optional parameters to pass when constructing the NodeDef.
+ name: A name for this operation (optional).
+
+ Returns:
+ An Operation that executes all its inputs.
+
+ Raises:
+ ValueError: If an unknown keyword argument is provided, or if there are
+ no inputs.
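+
+  Example (a sketch; `update_w` and `update_b` are hypothetical update ops):
+
+  ```python
+  # A single op that completes only after both updates have run. It produces
+  # no output value.
+  train_step = group(update_w, update_b, name="train_step")
+  ```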
+ """
+ name = kwargs.pop("name", None)
+ if kwargs:
+ raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
+ if not inputs:
+ # TODO(mdevin): Would make sense to return a NoOp.
+ raise ValueError("No inputs provided")
+ with ops.op_scope(inputs, name, "group_deps") as name:
+ # Sorts *inputs according to their devices.
+ ops_on_device = {} # device -> operations specified on the device.
+ for inp in inputs:
+ dev = inp.device
+ if dev in ops_on_device:
+ ops_on_device[dev].append(inp)
+ else:
+ ops_on_device[dev] = [inp]
+ if len(ops_on_device) == 1:
+ # 1-level tree. The root node is the returned NoOp node.
+ dev, deps = ops_on_device.items()[0]
+ return _GroupControlDeps(dev, deps, name=name)
+ # 2-level tree. The root node is the returned NoOp node.
+ # deps contains 1 NoOp node for each device.
+ deps = []
+ for dev in sorted(ops_on_device.iterkeys()):
+ deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
+ return _GroupControlDeps(None, deps, name=name)
+
+def tuple(tensors, name=None, control_inputs=None):
+ """Group tensors together.
+
+ This creates a tuple of tensors with the same values as the `tensors`
+ argument, except that the value of each tensor is only returned after the
+ values of all tensors have been computed.
+
+ `control_inputs` contains additional ops that have to finish before this op
+ finishes, but whose outputs are not returned.
+
+ This can be used as a "join" mechanism for parallel computations: all the
+ argument tensors can be computed in parallel, but the values of any tensor
+ returned by `tuple` are only available after all the parallel computations
+ are done.
+
+ See also `group` and `with_dependencies`.
+
+ Args:
+ tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
+ name: (optional) A name to use as a `name_scope` for the operation.
+ control_inputs: List of additional ops to finish before returning.
+
+ Returns:
+ Same as `tensors`.
+
+ Raises:
+ ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
+
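+  Example (a sketch; `m1` and `m2` are hypothetical tensors that can be
+  computed independently):
+
+  ```python
+  # `j1` and `j2` have the same values as `m1` and `m2`, but neither is
+  # available until both have been computed.
+  j1, j2 = tuple([m1, m2])
+  ```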
+ """
+ with ops.op_scope(tensors, name, "tuple") as name:
+ gating_ops = [t.op for t in tensors if t]
+ if control_inputs:
+ gating_ops += control_inputs
+ # Note that in order to ensure ordering in the pbtxt, we must take care to
+ # ensure the order here.
+ gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops.
+ if not gating_ops:
+ raise ValueError("Must have at least one Tensor: %s" % tensors)
+ gate = group(*gating_ops)
+ tpl = []
+ for t in tensors:
+ if t:
+ tpl.append(with_dependencies([gate], t))
+ else:
+ tpl.append(None)
+ return tpl
+
+
+# TODO(yuanbyu): It would be nicer if we could have the distributed list
+# support that Derek has been proposing.
+# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
+def fold(fn, elems, elem_shape, name=None):
+ """The fold operator on slices of a tensor.
+
+ This fold operator applies the function `fn` to slices of `elems` on
+ dimension 0. The shape of the slices is specified by `elem_shape`. `elems`
+ must contain at least one slice (`shape(elems)[0] / elem_shape[0] > 0`).
+
+ Args:
+ fn: The function to be performed on each slice of the tensor.
+ elems: The tensor to whose slices we want to apply `fn`.
+ elem_shape: The shape definition for the slices.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ A tensor resulting from applying `fn` consecutively on each slice of
+ `elems`.
+
+ Raises:
+ TypeError: if `fn` is not callable.
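+
+  Example (a minimal sketch that sums the `[1, 2]` row slices of a `[3, 2]`
+  constant):
+
+  ```python
+  elems = constant_op.constant([[1., 2.], [3., 4.], [5., 6.]])
+  total = fold(lambda acc, x: math_ops.add(acc, x), elems, elem_shape=[1, 2])
+  # total == [[9., 12.]]
+  ```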
+ """
+ with ops.op_scope([elems], name, "Fold") as name:
+ if not callable(fn):
+ raise TypeError("fn must be callable.")
+
+ s0 = array_ops.shape(elems)[0]
+ d0 = elem_shape[0]
+ n = math_ops.div(s0, d0)
+ b1 = array_ops.zeros(array_ops.expand_dims(array_ops.rank(elems) - 1, 0),
+ dtype=types.int32)
+ # Initialize the output with slice 0
+ b = array_ops.concat(0, [[0], b1])
+ o = array_ops.slice(elems, b, elem_shape)
+ i = ops.convert_to_tensor(d0)
+
+ def Compute(i, o):
+ b = array_ops.concat(0, [array_ops.expand_dims(i, 0), b1])
+ x = array_ops.slice(elems, b, elem_shape)
+ o = fn(o, x)
+ i = math_ops.add(i, d0)
+ return [i, o]
+ r = While(lambda i, o: math_ops.less(i, n), Compute, [i, o])
+ return r[1]
+
+
+def case(pred_fn_pairs, default, exclusive=False, name="Case"):
+ """Create a Case operation.
+
+ The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
+ Each pair contains a boolean scalar tensor and a python callable that
+ creates the tensors to be returned if the boolean evaluates to True. `default`
+ is a callable generating a list of tensors. All the callables in
+ `pred_fn_pairs` as well as `default` should return the same number and types
+ of tensors.
+
+ If `exclusive==True`, all predicates are evaluated, and a logging operation
+ with an error is returned if more than one of the predicates evaluates to
+True. If `exclusive==False`, execution stops at the first predicate that
+ evaluates to True, and the tensors generated by the corresponding function
+ are returned immediately. If none of the predicates evaluate to True, this
+ operation returns the tensors generated by `default`.
+
+ Example 1:
+ Pseudocode:
+ ```
+ if (x < y) return 17;
+ else return 23;
+ ```
+
+ Expressions:
+ ```
+  f1 = lambda: constant(17)
+  f2 = lambda: constant(23)
+  r = case([(math_ops.less(x, y), f1)], default=f2)
+ ```
+
+ Example 2:
+ Pseudocode:
+ ```
+ if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
+ if (x < y) return 17;
+ else if (x > z) return 23;
+ else return -1;
+ ```
+
+ Expressions:
+ ```
+  def f1(): return constant(17)
+  def f2(): return constant(23)
+  def f3(): return constant(-1)
+  r = case({math_ops.less(x, y): f1, math_ops.greater(x, z): f2},
+ default=f3, exclusive=True)
+ ```
+
+ Args:
+ pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
+ callable which returns a list of tensors.
+ default: A callable that returns a list of tensors.
+    exclusive: True iff at most one predicate is allowed to evaluate to True.
+ name: A name for this operation (optional).
+
+ Returns:
+ The tensors returned by the first pair whose predicate evaluated to True, or
+ those returned by `default` if none does.
+
+ Raises:
+ TypeError: If `pred_fn_pairs` is not a list/dictionary.
+ TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
+ TypeError: If `fns[i]` is not callable for any i, or `default` is not
+ callable.
+ """
+ pfp = pred_fn_pairs # For readability
+ if not (isinstance(pfp, list) or isinstance(pfp, _basetuple)
+ or isinstance(pfp, dict)):
+ raise TypeError("fns must be a list, tuple, or dict")
+ if isinstance(pfp, dict):
+ pfp = pfp.items()
+ if not exclusive:
+ logging.warn("%s: Provided dictionary of predicate/fn pairs, but "
+ "exclusive=False. Order of conditional tests is "
+ "not guaranteed." % name)
+ for tup in pfp:
+ if not isinstance(tup, _basetuple) or len(tup) != 2:
+ raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
+ pred, fn = tup
+ if pred.dtype != types.bool:
+ raise TypeError("pred must be of type bool: %s", pred.name)
+ if not callable(fn):
+ raise TypeError("fn for pred %s must be callable." % pred.name)
+ if not callable(default):
+ raise TypeError("default must be callable.")
+
+ preds, fns = map(list, zip(*pfp))
+ with ops.op_scope([[f() for f in fns] + preds + [default()]], name, "Case"):
+ if not preds:
+ return default()
+ not_preds = []
+ for i, p in enumerate(preds):
+ with ops.name_scope("not_%d" % i):
+ not_preds.append(math_ops.logical_not(p))
+ and_not_preds = [constant_op.constant(True, name="and_not_true")]
+ for i, notp in enumerate(not_preds[:-1]):
+ with ops.name_scope("and_not_%d" % i):
+ and_not_preds.append(math_ops.logical_and(and_not_preds[-1], notp))
+
+ # preds = [p1, p2, p3]
+ # fns = [f1, f2, f3]
+ # not_preds = [~p1, ~p2, ~p3]
+ # case_preds = [p1 & True,
+ # p2 & ~p1,
+ # p3 & ~p1 & ~ p2]
+ case_preds = []
+ for i, (p, and_not_p_prev) in enumerate(zip(preds, and_not_preds)):
+ with ops.name_scope("case_%d" % i):
+ case_preds.append(math_ops.logical_and(p, and_not_p_prev))
+
+ # case_sequence = [Cond(p3 & ..., f3, default),
+ # Cond(p2 & ..., f2, lambda: case_sequence[0]),
+ # ...
+ # Cond(p1 & True, f1, lambda: case_sequence[i-1])]
+ # and prev_case_seq will loop from case_sequence[0] to case_sequence[-1]
+ if exclusive:
+ # TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds))
+ preds_c = array_ops.concat(0, preds, name="preds_c")
+ num_true_conditions = math_ops.reduce_sum(
+ math_ops.cast(preds_c, types.int32), name="num_true_conds")
+ at_most_one_true_condition = math_ops.less(
+ num_true_conditions, constant_op.constant(2, name="two_true_conds"))
+
+ error_msg = [
+ ("More than one condition evaluated as True but "
+ "exclusive=True. Conditions: (%s), Values:"
+ % ", ".join([p.name for p in preds])),
+ preds_c]
+ with ops.control_dependencies([
+ logging_ops.Assert(condition=at_most_one_true_condition,
+ data=error_msg, summarize=len(preds))]):
+ prev_case_seq = default()
+ for i, (cp, fn) in enumerate(zip(case_preds, fns)[::-1]):
+ prev_case_seq = cond(cp, fn, lambda: prev_case_seq, name="If_%d" % i)
+ else:
+ prev_case_seq = default()
+ for i, (cp, fn) in enumerate(zip(case_preds, fns)[::-1]):
+ prev_case_seq = cond(cp, fn, lambda: prev_case_seq, name="If_%d" % i)
+
+ return prev_case_seq
+
+
+ops.RegisterShape("Enter")(common_shapes.unchanged_shape)
+ops.RegisterShape("Exit")(common_shapes.unknown_shape)
+ops.RegisterShape("NextIteration")(common_shapes.unchanged_shape)
+ops.RegisterShape("RefEnter")(common_shapes.unchanged_shape)
+ops.RegisterShape("ControlTrigger")(common_shapes.no_outputs)
+ops.RegisterShape("NoOp")(common_shapes.no_outputs)
+
+
+@ops.RegisterShape("LoopCond")
+def _LoopCondShape(op):
+ """Shape function for the LoopCond op."""
+ return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
+
+
+@ops.RegisterShape("Merge")
+def _MergeShape(op):
+ """Shape function for the Merge op.
+
+ The Merge op takes many inputs of arbitrary shapes, and produces a
+ first output that is one of those inputs, and a second scalar
+ output.
+
+ This function conservatively assumes that if any of its inputs is
+ not fully defined, the output shape is unknown. If all of the inputs
+ have the exact same known shape, the output must have that shape.
+
+ Args:
+ op: A Merge Operation.
+
+ Returns:
+ A single-element list containing the Shape of the Merge op.
+
+ """
+ first_input_shape = op.inputs[0].get_shape()
+ if first_input_shape.is_fully_defined():
+ for input_ in op.inputs[1:]:
+ input_shape = input_.get_shape()
+ if (not input_shape.is_fully_defined()
+ or not input_shape.is_compatible_with(first_input_shape)):
+ return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
+ return [first_input_shape, tensor_shape.scalar()]
+ else:
+ return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
+
+
+@ops.RegisterShape("RefSelect")
+def _RefSelectShape(op):
+ """Shape function for the RefSelect op.
+
+ The RefSelect takes one scalar input and N inputs of arbitrary
+ shapes, and produces one output, which is one of those N inputs.
+
+ This function conservatively assumes that if any of the N inputs is
+ not fully defined, the output shape is unknown. If all of the N
+ inputs have the exact same known shape, the output must have that
+ shape.
+
+ Args:
+ op: A RefSelect Operation.
+
+ Returns:
+ A single-element list containing the Shape of the RefSelect op.
+ """
+ unused_shape = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ first_input_shape = op.inputs[1].get_shape()
+ if first_input_shape.is_fully_defined():
+ for input_ in op.inputs[2:]:
+ input_shape = input_.get_shape()
+ if (not input_shape.is_fully_defined()
+ or not input_shape.is_compatible_with(first_input_shape)):
+ return [tensor_shape.unknown_shape()]
+ return [first_input_shape]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("RefSwitch")
+@ops.RegisterShape("Switch")
+def _SwitchShape(op):
+ input_shape = op.inputs[0].get_shape()
+ unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
+ return [input_shape] * 2