Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--   tensorflow/python/ops/custom_gradient.py | 44
-rw-r--r--   tensorflow/python/ops/gradients_impl.py  | 30
-rw-r--r--   tensorflow/python/ops/while_v2.py        |  3
3 files changed, 60 insertions, 17 deletions
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d7834ba350..bfe23834b7 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape as tape_lib
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
@@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
 
+def copy_handle_data(source_t, target_t):
+  """Copies HandleData for variant and resource type tensors if available.
+
+  The CppShapeInferenceResult::HandleData proto contains information about the
+  shapes and types of the element tensors of resource/variant type tensors.
+  We need to copy this across function boundaries, i.e., when capturing a
+  placeholder or when returning a function tensor as output. If we don't do this
+  the element tensors will have unknown shapes, e.g., if a TensorList variant
+  tensor is captured as a placeholder, elements popped from that list would have
+  unknown shape.
+
+  Args:
+    source_t: The tensor to copy HandleData from.
+    target_t: The tensor to copy HandleData to.
+  """
+  if (target_t.dtype == dtypes.resource or
+      target_t.dtype == dtypes.variant):
+    if isinstance(source_t, ops.EagerTensor):
+      handle_data = source_t._handle_data  # pylint: disable=protected-access
+    else:
+      handle_data = resource_variable_ops.get_resource_handle_data(source_t)
+    if handle_data is not None and handle_data.is_set:
+      # pylint: disable=protected-access
+      pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+                                              target_t._as_tf_output(),
+                                              handle_data.SerializeToString())
+      # pylint: enable=protected-access
+      # Ensure that shapes and dtypes are propagated.
+      shapes, types = zip(*[(pair.shape, pair.dtype)
+                            for pair in handle_data.shape_and_type])
+      ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
+      shapes = [[d.size for d in s.dim]
+                if not s.unknown_rank else None for s in shapes]
+      pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+          target_t._op._graph._c_graph,  # pylint: disable=protected-access
+          target_t._as_tf_output(),  # pylint: disable=protected-access
+          shapes, ranks, types)
+
+
 @tf_export("custom_gradient")
 def custom_gradient(f):
   """Decorator to define a function with a custom gradient.
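Note on the helper added in the hunk above: the element shapes and dtypes of
resource/variant tensors travel out-of-band as HandleData rather than in the
tensor's own shape, so they must be copied explicitly whenever one tensor
stands in for another. A minimal sketch of the situation, assuming a TF 1.x
build that includes this change; the variant placeholder setup is illustrative
and not taken from the diff:

    import tensorflow as tf
    from tensorflow.python.ops import custom_gradient
    from tensorflow.python.ops import list_ops

    with tf.Graph().as_default():
      # A TensorList is a variant-dtype tensor; its element shape lives in
      # handle data, not in the (scalar) shape of the list tensor itself.
      tl = list_ops.tensor_list_reserve(
          element_shape=[2, 3], num_elements=5, element_dtype=tf.float32)
      # A stand-in tensor, e.g. a captured placeholder, starts without
      # handle data, so elements read through it would have unknown shape.
      ph = tf.placeholder(tf.variant, shape=[])
      custom_gradient.copy_handle_data(tl, ph)
      elem = list_ops.tensor_list_get_item(ph, 0, element_dtype=tf.float32)
      # elem.shape should now be (2, 3) rather than <unknown>.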
@@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs):
     input_grads = nest.flatten(input_grads)
     return ([None] * len(flat_result)) + input_grads + variable_grads
 
+  original_tensors = all_tensors
   with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
     all_tensors = array_ops.identity_n(all_tensors)
+  for ot, t in zip(original_tensors, all_tensors):
+    copy_handle_data(ot, t)
   return nest.pack_sequence_as(
       structure=result, flat_sequence=all_tensors[:len(flat_result)])
 
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index aac95037dc..6909fcaed5 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -800,23 +800,21 @@ def _GradientsHelper(ys,
       # pylint: enable=protected-access
       has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
       if has_out_grads and (op not in stop_ops):
-        if is_func_call:
-          if is_partitioned_call:
-            func_call = src_graph._get_function(  # pylint: disable=protected-access
-                compat.as_bytes(op.get_attr("f").name))
+        try:
+          grad_fn = ops.get_gradient_function(op)
+        except LookupError:
+          if is_func_call:
+            if is_partitioned_call:
+              func_call = src_graph._get_function(  # pylint: disable=protected-access
+                  compat.as_bytes(op.get_attr("f").name))
+            else:
+              func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
+            # Note that __defun is not set if the graph is
+            # imported. If it's set, we prefer to access the original
+            # defun.
+            func_call = getattr(op, "__defun", func_call)
+            grad_fn = func_call.python_grad_func
           else:
-            func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
-          # Note that __defun is not set if the graph is
-          # imported. If it's set, we prefer to access the original
-          # defun.
-          func_call = getattr(op, "__defun", func_call)
-          grad_fn = func_call.python_grad_func
-        else:
-          # A grad_fn must be defined, either as a function or as None
-          # for ops that do not have gradients.
-          try:
-            grad_fn = ops.get_gradient_function(op)
-          except LookupError:
             raise LookupError(
                 "No gradient defined for operation '%s' (op type: %s)" %
                 (op.name, op.type))
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8e88a84d60..0419656143 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import cond_v2_impl as cond_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gen_functional_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
@@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
 
 def _copy_handle_data(src_tensors, tgt_tensors):
   for src_t, tgt_t in zip(src_tensors, tgt_tensors):
-    function._copy_handle_data(src_t, tgt_t)
+    custom_gradient.copy_handle_data(src_t, tgt_t)
 
 
 # TODO(srbs): Move to common utils for cond_v2 and while_v2.
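Note on the gradients_impl.py hunk above: the restructuring swaps the lookup
order. ops.get_gradient_function(op), which consults the gradient registry
(including any gradient_override_map in effect when the op was created), is
now tried first, and a called function's python_grad_func is used only as a
fallback when that lookup raises LookupError. A hedged sketch of the registry
path the new code prefers; the "SketchIdentityN" name is made up for
illustration:

    import tensorflow as tf
    from tensorflow.python.framework import ops

    @ops.RegisterGradient("SketchIdentityN")
    def _sketch_identity_n_grad(op, *grads):
      # IdentityN passes its inputs through, so gradients pass through too,
      # one per input.
      return list(grads)

    with tf.Graph().as_default() as g:
      x = tf.constant([1.0, 2.0])
      # The override map stamps the op so the registry lookup resolves to
      # the function registered above, which is the branch tried first.
      with g.gradient_override_map({"IdentityN": "SketchIdentityN"}):
        y = tf.identity_n([x])[0]
      dy = tf.gradients(y, x)[0]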