path: root/tensorflow/python/ops
author    Alexandre Passos <apassos@google.com>  2018-10-08 13:50:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-08 13:58:40 -0700
commit    eec9ca8f0baccd249a49046fe31b460903e44850 (patch)
tree      b6397af544af7c05abca4bea08bd6354f90bedf1 /tensorflow/python/ops
parent    494bbdfced3fd8596721d12e73676c4967f452e4 (diff)
Partial support for tfe.defun in tf.gradients.
Doesn't attempt to deal with cases where we might have already generated the FunctionDef for the parent function, since in that case we cannot easily modify the forward pass.
PiperOrigin-RevId: 216243224
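A minimal sketch of the scenario this change targets (assumes TF 1.x graph mode with tf.contrib.eager available; `square` is a made-up example, not code from this commit): tf.gradients is asked to differentiate through the call op produced by a defun-wrapped function.

  import tensorflow as tf

  @tf.contrib.eager.defun
  def square(x):
    return x * x

  x = tf.constant(3.0)
  y = square(x)                 # a function-call op in the default graph
  dy_dx = tf.gradients(y, [x])  # gradient through the function-call op

  with tf.Session() as sess:
    print(sess.run(dy_dx))      # expect [6.0]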
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/custom_gradient.py  44
-rw-r--r--  tensorflow/python/ops/gradients_impl.py    30
-rw-r--r--  tensorflow/python/ops/while_v2.py           3
3 files changed, 60 insertions(+), 17 deletions(-)
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d7834ba350..bfe23834b7 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+def copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes.resource or
+ target_t.dtype == dtypes.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
+ else:
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
+ if handle_data is not None and handle_data.is_set:
+ # pylint: disable=protected-access
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
+ # pylint: enable=protected-access
+ # Ensure that shapes and dtypes are propagated.
+ shapes, types = zip(*[(pair.shape, pair.dtype)
+ for pair in handle_data.shape_and_type])
+ ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
+ shapes = [[d.size for d in s.dim]
+ if not s.unknown_rank else None for s in shapes]
+ pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
+ shapes, ranks, types)
+
+
@tf_export("custom_gradient")
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
@@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs):
input_grads = nest.flatten(input_grads)
return ([None] * len(flat_result)) + input_grads + variable_grads
+ original_tensors = all_tensors
with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
all_tensors = array_ops.identity_n(all_tensors)
+ for ot, t in zip(original_tensors, all_tensors):
+ copy_handle_data(ot, t)
return nest.pack_sequence_as(
structure=result, flat_sequence=all_tensors[:len(flat_result)])
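For context, a hedged usage sketch (the standard tf.custom_gradient pattern, not code from this commit): in graph mode the decorator routes outputs through an IdentityN op, and the copy_handle_data call added above keeps resource/variant HandleData (such as TensorList element shapes) attached to those identity outputs.

  import tensorflow as tf

  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))   # numerically stable gradient
    return tf.log(1 + e), grad

  x = tf.constant(100.0)
  y = log1pexp(x)
  dy_dx = tf.gradients(y, x)          # finite, instead of NaN from naive autodiff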
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index aac95037dc..6909fcaed5 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -800,23 +800,21 @@ def _GradientsHelper(ys,
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op not in stop_ops):
- if is_func_call:
- if is_partitioned_call:
- func_call = src_graph._get_function( # pylint: disable=protected-access
- compat.as_bytes(op.get_attr("f").name))
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ if is_func_call:
+ if is_partitioned_call:
+ func_call = src_graph._get_function( # pylint: disable=protected-access
+ compat.as_bytes(op.get_attr("f").name))
+ else:
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
+ grad_fn = func_call.python_grad_func
else:
- func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
- # Note that __defun is not set if the graph is
- # imported. If it's set, we prefer to access the original
- # defun.
- func_call = getattr(op, "__defun", func_call)
- grad_fn = func_call.python_grad_func
- else:
- # A grad_fn must be defined, either as a function or as None
- # for ops that do not have gradients.
- try:
- grad_fn = ops.get_gradient_function(op)
- except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
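The hunk above inverts the lookup order: ops.get_gradient_function(op) is tried first, and only a LookupError falls back to the function-call / python_grad_func path. A hedged illustration of the registry lookup for an ordinary op (not a function call):

  import tensorflow as tf
  from tensorflow.python.framework import ops

  with tf.Graph().as_default():
    x = tf.constant(2.0)
    y = tf.square(x)
    # Registered gradient is found directly; no function-call fallback needed.
    grad_fn = ops.get_gradient_function(y.op)
    print(grad_fn)   # the gradient function registered for "Square"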
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8e88a84d60..0419656143 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
@@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
- function._copy_handle_data(src_t, tgt_t)
+ custom_gradient.copy_handle_data(src_t, tgt_t)
# TODO(srbs): Move to common utils for cond_v2 and while_v2.