aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-10-08 13:50:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 13:58:40 -0700
commiteec9ca8f0baccd249a49046fe31b460903e44850 (patch)
treeb6397af544af7c05abca4bea08bd6354f90bedf1 /tensorflow/python
parent494bbdfced3fd8596721d12e73676c4967f452e4 (diff)
Partial support tfe.defun in tf.gradients.
Doesn't attempt to deal with cases where we might have already generated the functiondef for the parent function as in that case we cannot easily modify the forward pass. PiperOrigin-RevId: 216243224
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/function.py87
-rw-r--r--tensorflow/python/eager/function_test.py18
-rw-r--r--tensorflow/python/framework/op_def_library.py3
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py1
-rw-r--r--tensorflow/python/ops/custom_gradient.py44
-rw-r--r--tensorflow/python/ops/gradients_impl.py30
-rw-r--r--tensorflow/python/ops/while_v2.py3
7 files changed, 125 insertions, 61 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 93168826b1..99bf375ea7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -46,6 +46,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
@@ -81,49 +82,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- _copy_handle_data(value, placeholder)
+ custom_gradient.copy_handle_data(value, placeholder)
return placeholder
-def _copy_handle_data(source_t, target_t):
- """Copies HandleData for variant and resource type tensors if available.
-
- The CppShapeInferenceResult::HandleData proto contains information about the
- shapes and types of the element tensors of resource/variant type tensors.
- We need to copy this across function boundaries, i.e., when capturing a
- placeholder or when returning a function tensor as output. If we don't do this
- the element tensors will have unknown shapes, e.g., if a TensorList variant
- tensor is captured as a placeholder, elements popped from that list would have
- unknown shape.
-
- Args:
- source_t: The tensor to copy HandleData from.
- target_t: The tensor to copy HandleData to.
- """
- if (target_t.dtype == dtypes_module.resource or
- target_t.dtype == dtypes_module.variant):
- if isinstance(source_t, ops.EagerTensor):
- handle_data = source_t._handle_data # pylint: disable=protected-access
- else:
- handle_data = resource_variable_ops.get_resource_handle_data(source_t)
- if handle_data is not None and handle_data.is_set:
- # pylint: disable=protected-access
- pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
- target_t._as_tf_output(),
- handle_data.SerializeToString())
- # pylint: enable=protected-access
- # Ensure that shapes and dtypes are propagated.
- shapes, types = zip(*[(pair.shape, pair.dtype)
- for pair in handle_data.shape_and_type])
- ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
- shapes = [[d.size for d in s.dim]
- if not s.unknown_rank else None for s in shapes]
- pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- target_t._op._graph._c_graph, # pylint: disable=protected-access
- target_t._as_tf_output(), # pylint: disable=protected-access
- shapes, ranks, types)
-
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
if ctx.executing_eagerly():
@@ -547,7 +509,7 @@ class _EagerDefinedFunction(object):
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
for i, func_graph_output in enumerate(self._func_graph_outputs):
- _copy_handle_data(func_graph_output, outputs[i])
+ custom_gradient.copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -658,7 +620,48 @@ class Function(object):
if tape.should_record(tensor_inputs) or tape.should_record(captures):
return self._backprop_call(args)
- outputs = self._inference_function.call(ctx, args)
+ # Only need to override the gradient in graph mode and when we have outputs.
+ if context.executing_eagerly() or not self.outputs:
+ outputs = self._inference_function.call(ctx, args)
+ else:
+ name = "PartitionedCall-%s" % ops.uid()
+
+ @ops.RegisterGradient(name)
+ def grad_fn(op, *doutputs): # pylint: disable=unused-variable
+ """Gradients of this function."""
+ if op.graph is not ops.get_default_graph():
+ # TODO(apassos) this will still emit SymbolicGradient ops when
+ # nested defuns are being differentiated. We need to somehow figure
+ # out a way to update the FunctionDef corresponding to the calling
+ # function when mutating a call to the forward pass.
+ return gradients_impl._SymGrad(op, list(doutputs)) # pylint: disable=protected-access
+ if self._backward_graph_function is None:
+ self._construct_backprop_function()
+ self._forward_function.add_to_graph(op.graph)
+ func = attr_value_pb2.AttrValue(
+ func=attr_value_pb2.NameAttrList(
+ name=self._forward_function.name))
+ # pylint: disable=protected-access
+ op._set_attr("f", func)
+ types = attr_value_pb2.AttrValue.ListValue(
+ type=self._forward_function._output_types)
+ op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
+ for i in range(
+ len(outputs), len(self._forward_function._output_types)):
+ t = ops.Tensor(op, i, self._forward_function._output_types[i])
+ t.set_shape(self._forward_function._output_shapes[i])
+ func_graph_output = self._forward_function._func_graph_outputs[i]
+ custom_gradient.copy_handle_data(func_graph_output, t)
+ op._outputs.append(t)
+ # pylint: enable=protected-access
+ side_outputs = op.outputs[len(outputs):]
+ return self._backward_graph_function(
+ *(list(doutputs) + list(side_outputs)))
+
+ with ops.get_default_graph().gradient_override_map(
+ {"PartitionedCall": name}):
+ outputs = self._inference_function.call(ctx, args)
+
return self._build_call_outputs(outputs)
@property
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 57e545be69..e46bde098b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -286,7 +286,23 @@ class FunctionTest(test.TestCase):
c = constant_op.constant([[2.]])
f_c = f(c)
g, = gradients_impl.gradients(f_c, c)
- self.assertAllEqual(sess.run(g), [[1.0]])
+ self.assertAllEqual(sess.run(g).values, [[1.0]])
+
+ def testNoSymGradNestedDefun(self):
+
+ @function.defun
+ def outer():
+
+ @function.defun
+ def f(x):
+ return array_ops.gather_nd(x, [[0]])
+
+ c = constant_op.constant([[2.]])
+ f_c = f(c)
+ g, = gradients_impl.gradients(f_c, c)
+ self.assertTrue(isinstance(g, ops.IndexedSlices))
+
+ outer()
def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index e85bba11cd..9955a9a2cd 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -482,7 +482,8 @@ class OpDefLibrary(object):
else:
raise TypeError("%s that don't all match." % prefix)
else:
- raise TypeError("%s that are invalid." % prefix)
+ raise TypeError(
+ "%s that are invalid. Tensors: %s" % (prefix, values))
types = [x.dtype for x in values]
inputs.extend(values)
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index ec875aae59..a424a0f219 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -153,6 +153,7 @@ class CondV2Test(test.TestCase):
self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions)
def testDefunInCond(self):
+ self.skipTest("b/117293122")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d7834ba350..bfe23834b7 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+def copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes.resource or
+ target_t.dtype == dtypes.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
+ else:
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
+ if handle_data is not None and handle_data.is_set:
+ # pylint: disable=protected-access
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
+ # pylint: enable=protected-access
+ # Ensure that shapes and dtypes are propagated.
+ shapes, types = zip(*[(pair.shape, pair.dtype)
+ for pair in handle_data.shape_and_type])
+ ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
+ shapes = [[d.size for d in s.dim]
+ if not s.unknown_rank else None for s in shapes]
+ pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
+ shapes, ranks, types)
+
+
@tf_export("custom_gradient")
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
@@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs):
input_grads = nest.flatten(input_grads)
return ([None] * len(flat_result)) + input_grads + variable_grads
+ original_tensors = all_tensors
with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
all_tensors = array_ops.identity_n(all_tensors)
+ for ot, t in zip(original_tensors, all_tensors):
+ copy_handle_data(ot, t)
return nest.pack_sequence_as(
structure=result, flat_sequence=all_tensors[:len(flat_result)])
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index aac95037dc..6909fcaed5 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -800,23 +800,21 @@ def _GradientsHelper(ys,
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op not in stop_ops):
- if is_func_call:
- if is_partitioned_call:
- func_call = src_graph._get_function( # pylint: disable=protected-access
- compat.as_bytes(op.get_attr("f").name))
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ if is_func_call:
+ if is_partitioned_call:
+ func_call = src_graph._get_function( # pylint: disable=protected-access
+ compat.as_bytes(op.get_attr("f").name))
+ else:
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
+ grad_fn = func_call.python_grad_func
else:
- func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
- # Note that __defun is not set if the graph is
- # imported. If it's set, we prefer to access the original
- # defun.
- func_call = getattr(op, "__defun", func_call)
- grad_fn = func_call.python_grad_func
- else:
- # A grad_fn must be defined, either as a function or as None
- # for ops that do not have gradients.
- try:
- grad_fn = ops.get_gradient_function(op)
- except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8e88a84d60..0419656143 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
@@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
- function._copy_handle_data(src_t, tgt_t)
+ custom_gradient.copy_handle_data(src_t, tgt_t)
# TODO(srbs): Move to common utils for cond_v2 and while_v2.