aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-10-08 13:50:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 13:58:40 -0700
commiteec9ca8f0baccd249a49046fe31b460903e44850 (patch)
treeb6397af544af7c05abca4bea08bd6354f90bedf1 /tensorflow/python/eager
parent494bbdfced3fd8596721d12e73676c4967f452e4 (diff)
Partial support tfe.defun in tf.gradients.
Doesn't attempt to deal with cases where we might have already generated the functiondef for the parent function as in that case we cannot easily modify the forward pass. PiperOrigin-RevId: 216243224
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py87
-rw-r--r--tensorflow/python/eager/function_test.py18
2 files changed, 62 insertions, 43 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 93168826b1..99bf375ea7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -46,6 +46,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
@@ -81,49 +82,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- _copy_handle_data(value, placeholder)
+ custom_gradient.copy_handle_data(value, placeholder)
return placeholder
-def _copy_handle_data(source_t, target_t):
- """Copies HandleData for variant and resource type tensors if available.
-
- The CppShapeInferenceResult::HandleData proto contains information about the
- shapes and types of the element tensors of resource/variant type tensors.
- We need to copy this across function boundaries, i.e., when capturing a
- placeholder or when returning a function tensor as output. If we don't do this
- the element tensors will have unknown shapes, e.g., if a TensorList variant
- tensor is captured as a placeholder, elements popped from that list would have
- unknown shape.
-
- Args:
- source_t: The tensor to copy HandleData from.
- target_t: The tensor to copy HandleData to.
- """
- if (target_t.dtype == dtypes_module.resource or
- target_t.dtype == dtypes_module.variant):
- if isinstance(source_t, ops.EagerTensor):
- handle_data = source_t._handle_data # pylint: disable=protected-access
- else:
- handle_data = resource_variable_ops.get_resource_handle_data(source_t)
- if handle_data is not None and handle_data.is_set:
- # pylint: disable=protected-access
- pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
- target_t._as_tf_output(),
- handle_data.SerializeToString())
- # pylint: enable=protected-access
- # Ensure that shapes and dtypes are propagated.
- shapes, types = zip(*[(pair.shape, pair.dtype)
- for pair in handle_data.shape_and_type])
- ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
- shapes = [[d.size for d in s.dim]
- if not s.unknown_rank else None for s in shapes]
- pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- target_t._op._graph._c_graph, # pylint: disable=protected-access
- target_t._as_tf_output(), # pylint: disable=protected-access
- shapes, ranks, types)
-
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
if ctx.executing_eagerly():
@@ -547,7 +509,7 @@ class _EagerDefinedFunction(object):
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
for i, func_graph_output in enumerate(self._func_graph_outputs):
- _copy_handle_data(func_graph_output, outputs[i])
+ custom_gradient.copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -658,7 +620,48 @@ class Function(object):
if tape.should_record(tensor_inputs) or tape.should_record(captures):
return self._backprop_call(args)
- outputs = self._inference_function.call(ctx, args)
+ # Only need to override the gradient in graph mode and when we have outputs.
+ if context.executing_eagerly() or not self.outputs:
+ outputs = self._inference_function.call(ctx, args)
+ else:
+ name = "PartitionedCall-%s" % ops.uid()
+
+ @ops.RegisterGradient(name)
+ def grad_fn(op, *doutputs): # pylint: disable=unused-variable
+ """Gradients of this function."""
+ if op.graph is not ops.get_default_graph():
+ # TODO(apassos) this will still emit SymbolicGradient ops when
+ # nested defuns are being differentiated. We need to somehow figure
+ # out a way to update the FunctionDef corresponding to the calling
+ # function when mutating a call to the forward pass.
+ return gradients_impl._SymGrad(op, list(doutputs)) # pylint: disable=protected-access
+ if self._backward_graph_function is None:
+ self._construct_backprop_function()
+ self._forward_function.add_to_graph(op.graph)
+ func = attr_value_pb2.AttrValue(
+ func=attr_value_pb2.NameAttrList(
+ name=self._forward_function.name))
+ # pylint: disable=protected-access
+ op._set_attr("f", func)
+ types = attr_value_pb2.AttrValue.ListValue(
+ type=self._forward_function._output_types)
+ op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
+ for i in range(
+ len(outputs), len(self._forward_function._output_types)):
+ t = ops.Tensor(op, i, self._forward_function._output_types[i])
+ t.set_shape(self._forward_function._output_shapes[i])
+ func_graph_output = self._forward_function._func_graph_outputs[i]
+ custom_gradient.copy_handle_data(func_graph_output, t)
+ op._outputs.append(t)
+ # pylint: enable=protected-access
+ side_outputs = op.outputs[len(outputs):]
+ return self._backward_graph_function(
+ *(list(doutputs) + list(side_outputs)))
+
+ with ops.get_default_graph().gradient_override_map(
+ {"PartitionedCall": name}):
+ outputs = self._inference_function.call(ctx, args)
+
return self._build_call_outputs(outputs)
@property
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 57e545be69..e46bde098b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -286,7 +286,23 @@ class FunctionTest(test.TestCase):
c = constant_op.constant([[2.]])
f_c = f(c)
g, = gradients_impl.gradients(f_c, c)
- self.assertAllEqual(sess.run(g), [[1.0]])
+ self.assertAllEqual(sess.run(g).values, [[1.0]])
+
+ def testNoSymGradNestedDefun(self):
+
+ @function.defun
+ def outer():
+
+ @function.defun
+ def f(x):
+ return array_ops.gather_nd(x, [[0]])
+
+ c = constant_op.constant([[2.]])
+ f_c = f(c)
+ g, = gradients_impl.gradients(f_c, c)
+ self.assertTrue(isinstance(g, ops.IndexedSlices))
+
+ outer()
def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)