path: root/tensorflow/python/ops/custom_gradient.py
Diffstat (limited to 'tensorflow/python/ops/custom_gradient.py')
-rw-r--r--  tensorflow/python/ops/custom_gradient.py | 44
1 file changed, 44 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d7834ba350..bfe23834b7 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+def copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do
+ this, the element tensors will have unknown shapes. For example, if a
+ TensorList variant tensor is captured as a placeholder, elements popped
+ from that list would have unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes.resource or
+ target_t.dtype == dtypes.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
+ else:
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
+ if handle_data is not None and handle_data.is_set:
+ # pylint: disable=protected-access
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
+ # pylint: enable=protected-access
+ # Ensure that shapes and dtypes are propagated.
+ shapes, types = zip(*[(pair.shape, pair.dtype)
+ for pair in handle_data.shape_and_type])
+ ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
+ shapes = [[d.size for d in s.dim]
+ if not s.unknown_rank else None for s in shapes]
+ pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
+ shapes, ranks, types)
+
+
@tf_export("custom_gradient")
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
@@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs):
input_grads = nest.flatten(input_grads)
return ([None] * len(flat_result)) + input_grads + variable_grads
+ original_tensors = all_tensors
with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
all_tensors = array_ops.identity_n(all_tensors)
+ for ot, t in zip(original_tensors, all_tensors):
+ copy_handle_data(ot, t)
return nest.pack_sequence_as(
structure=result, flat_sequence=all_tensors[:len(flat_result)])
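
For context, a sketch of the graph-mode call path the second hunk touches
(this mirrors the canonical example from the custom_gradient docstring; TF
1.x graph mode assumed). Every output of the decorated function flows through
the IdentityN inserted under the gradient override map, which is exactly why
the new copy_handle_data loop is needed for resource/variant outputs.

  import tensorflow as tf

  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      # Numerically stable gradient of log(1 + e^x).
      return dy * (1 - 1 / (1 + e))
    return tf.log(1 + e), grad

  with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[])
    y = log1pexp(x)          # routed through IdentityN by _graph_mode_decorator
    dy = tf.gradients(y, x)  # grad() fires via the "IdentityN" override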