author    Saurabh Saxena <srbs@google.com>    2018-10-02 13:18:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-10-02 13:22:56 -0700
commit 8d12c635cc48e896da0bcac1cd568bd6381ca64e (patch)
tree   d651bbcfdd325e649c230c19424acc62c28de725 /tensorflow/python/ops
parent 78e4ce52aeda5a10ddaf5e64ea8958f439a2f9f2 (diff)
Support shape_invariants in while_v2. Note that this arg is temporary and may be replaced by automatic shape inference in TF 2.0 (or before).
Add an output_shapes attr to the While op to allow output shapes to differ from those of the incoming loop_vars.

PiperOrigin-RevId: 215446737
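For context, a minimal usage sketch of what shape_invariants enables (not part of this change; it assumes only the public tf.while_loop API): a loop var whose shape changes across iterations must be declared with a less-specific invariant.

import tensorflow as tf

def cond(i, acc):
  return i < 5

def body(i, acc):
  # Concatenation grows acc's leading dimension on every iteration.
  return i + 1, tf.concat([acc, tf.zeros([1])], axis=0)

# Relax acc's first dimension to None so the changing shape is accepted.
i_final, acc_final = tf.while_loop(
    cond, body, [tf.constant(0), tf.zeros([1])],
    shape_invariants=[tf.TensorShape([]), tf.TensorShape([None])])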
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py |  3
-rw-r--r--  tensorflow/python/ops/while_v2.py          | 59
2 files changed, 53 insertions(+), 9 deletions(-)
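The core mechanism of the patch, visible in the while_v2.py hunks below, is to pair each flattened loop var with its shape invariant (or with its own shape when no invariant is given) and build a TensorSpec signature for tracing cond and body. A standalone sketch of that pairing, using the public tf.TensorSpec in place of the internal tensor_spec module:

import tensorflow as tf

loop_vars = [tf.constant(0), tf.zeros([1])]
shape_invariants = [tf.TensorShape([]), tf.TensorShape([None])]

# One TensorSpec per loop var: the invariant's shape, the var's dtype.
# The traced body then sees an input of shape (None,) instead of (1,),
# so shape inference inside the loop cannot over-specialize.
signature = [
    tf.TensorSpec(shape, t.dtype)
    for shape, t in zip(shape_invariants, loop_vars)
]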
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 8ad71fe00c..f779c3d273 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -3225,7 +3225,8 @@ def while_loop(cond,
raise ValueError("The while_v2 module is not set. Did you forget to "
"import tensorflow.python.ops."
"while_v2?")
- return _while_v2.while_loop(cond, body, loop_vars, name)
+ return _while_v2.while_loop(
+ cond, body, loop_vars, shape_invariants=shape_invariants, name=name)
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 6791e1cd61..8e88a84d60 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
from tensorflow.python.ops import control_flow_ops
@@ -52,8 +53,17 @@ control_flow_ops._while_v2 = sys.modules[__name__]
# handled in the CapturingGraph itself.
-def while_loop(cond, body, loop_vars, name=None):
+def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
"""Like tf.while_loop, except emits a single While op."""
+ flattened_loop_vars = nest.flatten(loop_vars)
+ if shape_invariants is not None:
+ nest.assert_same_structure(loop_vars, shape_invariants)
+ flattened_shapes = nest.flatten(shape_invariants)
+ else:
+ flattened_shapes = [t.shape for t in flattened_loop_vars]
+
+ del shape_invariants
+
if not name:
name = "while"
@@ -62,25 +72,33 @@ def while_loop(cond, body, loop_vars, name=None):
cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
- flattened_loop_vars = nest.flatten(loop_vars)
num_outputs = len(flattened_loop_vars)
# Add loop counter needed for computing gradients.
flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
] + flattened_loop_vars
+ flattened_shapes = [tensor_shape.scalar()] + flattened_shapes
+
# Build a `cond` wrapper that can handle the extra counter loop_var.
def wrapped_cond(unused_loop_counter, *loop_vars):
return cond(*loop_vars)
- cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
- flattened_loop_vars, {})
+ signature = [
+ tensor_spec.TensorSpec(shape, t.dtype)
+ for shape, t in zip(flattened_shapes, flattened_loop_vars)
+ ]
+ cond_graph = function.func_graph_from_py_func(
+ cond_name, wrapped_cond, flattened_loop_vars, {}, signature=signature)
# Add external_captures of cond to the list of loop vars.
# Note that external tensors will be treated as loop invariants, i.e.,
# the value of that tensor in each iteration is the same as it was at the
# beginning of the loop execution.
flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+ flattened_shapes = flattened_shapes + [
+ t.shape for t in cond_graph.external_captures
+ ]
def wrapped_body(loop_counter, *args):
"""Loop body augmented with counter update.
@@ -105,8 +123,12 @@ def while_loop(cond, body, loop_vars, name=None):
# is_constant=True for inputs that are directly passed to outputs.
return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
- body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
- flattened_loop_vars, {})
+ signature = [
+ tensor_spec.TensorSpec(shape, t.dtype)
+ for shape, t in zip(flattened_shapes, flattened_loop_vars)
+ ]
+ body_graph = function.func_graph_from_py_func(
+ body_name, wrapped_body, flattened_loop_vars, {}, signature=signature)
# Add external captures of body to the list of loop vars.
# Note that external tensors will be treated as loop invariants, i.e.,
# the value of that tensor in each iteration is the same as it was at the
@@ -149,10 +171,17 @@ def while_loop(cond, body, loop_vars, name=None):
# Add this modified tensor list to the list of outputs.
body_graph.outputs.append(appended_tensor_list)
+ # Make sure that the shapes of the loop outputs are compatible with the
+ # shape invariants, or the shapes of the loop vars if the invariants are not
+ # specified.
+ _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
+ flattened_shapes[1:1 + num_outputs],
+ flattened_loop_vars[1:1 + num_outputs])
outputs = gen_functional_ops._while(
flattened_loop_vars,
cond_v2._create_new_tf_function(cond_graph),
cond_v2._create_new_tf_function(body_graph),
+ output_shapes=[t.shape for t in body_graph.outputs],
name=scope)
_copy_handle_data(body_graph.outputs, outputs)
@@ -216,6 +245,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
loop_vars,
cond_v2._create_new_tf_function(cond_grad_graph),
cond_v2._create_new_tf_function(body_grad_graph),
+ output_shapes=[t.shape for t in body_grad_graph.outputs],
name=_get_unique_name("%s_grad" % op.name))
_copy_handle_data(body_grad_graph.outputs, outputs)
@@ -236,8 +266,10 @@ def _get_body_graph(while_op):
Returns:
`FuncGraph` for the while body.
"""
- extra_inputs = list(while_op.inputs)
- input_shapes = [t.shape for t in extra_inputs]
+ # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
+ input_shapes = [
+ tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
+ ]
func_name = while_op.get_attr("body").name
fdef = while_op.graph._get_function(func_name).definition
func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
@@ -535,6 +567,17 @@ class _WhileBodyGradFuncGraph(function.FuncGraph):
return captured_tensor
+def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
+ for (t, shape, input_t) in zip(output_tensors, shape_invariants,
+ input_tensors):
+ if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
+ raise ValueError(
+ "Input tensor '%s' enters the loop with shape %s, but has "
+ "shape %s after one iteration. To allow the shape to vary across "
+ "iterations, use the `shape_invariants` argument of tf.while_loop to "
+ "specify a less-specific shape." % (input_t.name, shape, t.shape))
+
+
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
function._copy_handle_data(src_t, tgt_t)
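Finally, a sketch of the failure mode _check_shapes_compat guards against (assuming while_v2 is active, i.e. tensorflow.python.ops.while_v2 has been imported so the control_flow_ops hook dispatches to it): omitting shape_invariants when a loop var's shape changes now fails with the error constructed above, rather than producing a While op whose output_shapes disagree with its inputs.

import tensorflow as tf

def cond(i, acc):
  return i < 5

def body(i, acc):
  # acc enters with shape (1,) and leaves with shape (2,).
  return i + 1, tf.concat([acc, tf.zeros([1])], axis=0)

# With no shape_invariants, the body output's shape is checked against the
# initial loop var's shape, raising a ValueError like the one in the diff:
#   Input tensor '...' enters the loop with shape (1,), but has shape (2,)
#   after one iteration. To allow the shape to vary across iterations, use
#   the `shape_invariants` argument of tf.while_loop to specify a
#   less-specific shape.
tf.while_loop(cond, body, [tf.constant(0), tf.zeros([1])])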