diff options
author | 2018-10-02 13:18:27 -0700 | |
---|---|---|
committer | 2018-10-02 13:22:56 -0700 | |
commit | 8d12c635cc48e896da0bcac1cd568bd6381ca64e (patch) | |
tree | d651bbcfdd325e649c230c19424acc62c28de725 /tensorflow/python/ops | |
parent | 78e4ce52aeda5a10ddaf5e64ea8958f439a2f9f2 (diff) |
Support shape_invariants in while_v2. Note that this arg is temporary and may be replaced by automatic shape inference in TF 2.0 (or before).
Add an output_shapes attr to While op to allow output shapes to be different from the incoming loop_vars.
PiperOrigin-RevId: 215446737
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 3 | ||||
-rw-r--r-- | tensorflow/python/ops/while_v2.py | 59 |
2 files changed, 53 insertions, 9 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 8ad71fe00c..f779c3d273 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -3225,7 +3225,8 @@ def while_loop(cond, raise ValueError("The while_v2 module is not set. Did you forget to " "import tensorflow.python.ops." "while_v2?") - return _while_v2.while_loop(cond, body, loop_vars, name) + return _while_v2.while_loop( + cond, body, loop_vars, shape_invariants=shape_invariants, name=name) with ops.name_scope(name, "while", loop_vars): if not loop_vars: diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 6791e1cd61..8e88a84d60 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import function_def_to_graph from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2_impl as cond_v2 from tensorflow.python.ops import control_flow_ops @@ -52,8 +53,17 @@ control_flow_ops._while_v2 = sys.modules[__name__] # handled in the CapturingGraph itself. 
-def while_loop(cond, body, loop_vars, name=None): +def while_loop(cond, body, loop_vars, shape_invariants=None, name=None): """Like tf.while_loop, except emits a single While op.""" + flattened_loop_vars = nest.flatten(loop_vars) + if shape_invariants is not None: + nest.assert_same_structure(loop_vars, shape_invariants) + flattened_shapes = nest.flatten(shape_invariants) + else: + flattened_shapes = [t.shape for t in flattened_loop_vars] + + del shape_invariants + if not name: name = "while" @@ -62,25 +72,33 @@ def while_loop(cond, body, loop_vars, name=None): cond_name = _get_unique_name(("%scond" % scope).replace("/", "_")) body_name = _get_unique_name(("%sbody" % scope).replace("/", "_")) - flattened_loop_vars = nest.flatten(loop_vars) num_outputs = len(flattened_loop_vars) # Add loop counter needed for computing gradients. flattened_loop_vars = [constant_op.constant(0., name="loop_counter") ] + flattened_loop_vars + flattened_shapes = [tensor_shape.scalar()] + flattened_shapes + # Build a `cond` wrapper that can handle the extra counter loop_var. def wrapped_cond(unused_loop_counter, *loop_vars): return cond(*loop_vars) - cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond, - flattened_loop_vars, {}) + signature = [ + tensor_spec.TensorSpec(shape, t.dtype) + for shape, t in zip(flattened_shapes, flattened_loop_vars) + ] + cond_graph = function.func_graph_from_py_func( + cond_name, wrapped_cond, flattened_loop_vars, {}, signature=signature) # Add external_captures of cond to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures + flattened_shapes = flattened_shapes + [ + t.shape for t in cond_graph.external_captures + ] def wrapped_body(loop_counter, *args): """Loop body augmented with counter update. 
@@ -105,8 +123,12 @@ def while_loop(cond, body, loop_vars, name=None): # is_constant=True for inputs that are directly passed to outputs. return [loop_counter + 1] + list(outputs) + list(args[num_outputs:]) - body_graph = function.func_graph_from_py_func(body_name, wrapped_body, - flattened_loop_vars, {}) + signature = [ + tensor_spec.TensorSpec(shape, t.dtype) + for shape, t in zip(flattened_shapes, flattened_loop_vars) + ] + body_graph = function.func_graph_from_py_func( + body_name, wrapped_body, flattened_loop_vars, {}, signature=signature) # Add external captures of body to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the @@ -149,10 +171,17 @@ def while_loop(cond, body, loop_vars, name=None): # Add this modified tensor list to the list of outputs. body_graph.outputs.append(appended_tensor_list) + # Make sure that the shapes of the loop outputs are compatible with the + # shape invariants, or the shapes of the loop vars if the invariants are not + # specified. + _check_shapes_compat(body_graph.outputs[1:1 + num_outputs], + flattened_shapes[1:1 + num_outputs], + flattened_loop_vars[1:1 + num_outputs]) outputs = gen_functional_ops._while( flattened_loop_vars, cond_v2._create_new_tf_function(cond_graph), cond_v2._create_new_tf_function(body_graph), + output_shapes=[t.shape for t in body_graph.outputs], name=scope) _copy_handle_data(body_graph.outputs, outputs) @@ -216,6 +245,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name loop_vars, cond_v2._create_new_tf_function(cond_grad_graph), cond_v2._create_new_tf_function(body_grad_graph), + output_shapes=[t.shape for t in body_grad_graph.outputs], name=_get_unique_name("%s_grad" % op.name)) _copy_handle_data(body_grad_graph.outputs, outputs) @@ -236,8 +266,10 @@ def _get_body_graph(while_op): Returns: `FuncGraph` for the while body. 
""" - extra_inputs = list(while_op.inputs) - input_shapes = [t.shape for t in extra_inputs] + # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes. + input_shapes = [ + tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes") + ] func_name = while_op.get_attr("body").name fdef = while_op.graph._get_function(func_name).definition func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes) @@ -535,6 +567,17 @@ class _WhileBodyGradFuncGraph(function.FuncGraph): return captured_tensor +def _check_shapes_compat(output_tensors, shape_invariants, input_tensors): + for (t, shape, input_t) in zip(output_tensors, shape_invariants, + input_tensors): + if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape): + raise ValueError( + "Input tensor '%s' enters the loop with shape %s, but has " + "shape %s after one iteration. To allow the shape to vary across " + "iterations, use the `shape_invariants` argument of tf.while_loop to " + "specify a less-specific shape." % (input_t.name, shape, t.shape)) + + def _copy_handle_data(src_tensors, tgt_tensors): for src_t, tgt_t in zip(src_tensors, tgt_tensors): function._copy_handle_data(src_t, tgt_t) |