diff options
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/function.py | 87 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 18 | ||||
-rw-r--r-- | tensorflow/python/framework/op_def_library.py | 3 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/custom_gradient.py | 44 | ||||
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py | 30 | ||||
-rw-r--r-- | tensorflow/python/ops/while_v2.py | 3 |
7 files changed, 125 insertions, 61 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 93168826b1..99bf375ea7 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -46,6 +46,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops @@ -81,49 +82,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None): with ops.control_dependencies(None): placeholder = graph_placeholder( dtype=dtype or value.dtype, shape=value.shape, name=name) - _copy_handle_data(value, placeholder) + custom_gradient.copy_handle_data(value, placeholder) return placeholder -def _copy_handle_data(source_t, target_t): - """Copies HandleData for variant and resource type tensors if available. - - The CppShapeInferenceResult::HandleData proto contains information about the - shapes and types of the element tensors of resource/variant type tensors. - We need to copy this across function boundaries, i.e., when capturing a - placeholder or when returning a function tensor as output. If we don't do this - the element tensors will have unknown shapes, e.g., if a TensorList variant - tensor is captured as a placeholder, elements popped from that list would have - unknown shape. - - Args: - source_t: The tensor to copy HandleData from. - target_t: The tensor to copy HandleData to. - """ - if (target_t.dtype == dtypes_module.resource or - target_t.dtype == dtypes_module.variant): - if isinstance(source_t, ops.EagerTensor): - handle_data = source_t._handle_data # pylint: disable=protected-access - else: - handle_data = resource_variable_ops.get_resource_handle_data(source_t) - if handle_data is not None and handle_data.is_set: - # pylint: disable=protected-access - pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph, - target_t._as_tf_output(), - handle_data.SerializeToString()) - # pylint: enable=protected-access - # Ensure that shapes and dtypes are propagated. - shapes, types = zip(*[(pair.shape, pair.dtype) - for pair in handle_data.shape_and_type]) - ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] - shapes = [[d.size for d in s.dim] - if not s.unknown_rank else None for s in shapes] - pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - target_t._op._graph._c_graph, # pylint: disable=protected-access - target_t._as_tf_output(), # pylint: disable=protected-access - shapes, ranks, types) - - def _get_device_functions(ctx, graph): """Returns a tuple of device functions representing the device stack.""" if ctx.executing_eagerly(): @@ -547,7 +509,7 @@ class _EagerDefinedFunction(object): for i, shape in enumerate(self._output_shapes): outputs[i].set_shape(shape) for i, func_graph_output in enumerate(self._func_graph_outputs): - _copy_handle_data(func_graph_output, outputs[i]) + custom_gradient.copy_handle_data(func_graph_output, outputs[i]) return outputs @@ -658,7 +620,48 @@ class Function(object): if tape.should_record(tensor_inputs) or tape.should_record(captures): return self._backprop_call(args) - outputs = self._inference_function.call(ctx, args) + # Only need to override the gradient in graph mode and when we have outputs. + if context.executing_eagerly() or not self.outputs: + outputs = self._inference_function.call(ctx, args) + else: + name = "PartitionedCall-%s" % ops.uid() + + @ops.RegisterGradient(name) + def grad_fn(op, *doutputs): # pylint: disable=unused-variable + """Gradients of this function.""" + if op.graph is not ops.get_default_graph(): + # TODO(apassos) this will still emit SymbolicGradient ops when + # nested defuns are being differentiated. We need to somehow figure + # out a way to update the FunctionDef corresponding to the calling + # function when mutating a call to the forward pass. + return gradients_impl._SymGrad(op, list(doutputs)) # pylint: disable=protected-access + if self._backward_graph_function is None: + self._construct_backprop_function() + self._forward_function.add_to_graph(op.graph) + func = attr_value_pb2.AttrValue( + func=attr_value_pb2.NameAttrList( + name=self._forward_function.name)) + # pylint: disable=protected-access + op._set_attr("f", func) + types = attr_value_pb2.AttrValue.ListValue( + type=self._forward_function._output_types) + op._set_attr("Tout", attr_value_pb2.AttrValue(list=types)) + for i in range( + len(outputs), len(self._forward_function._output_types)): + t = ops.Tensor(op, i, self._forward_function._output_types[i]) + t.set_shape(self._forward_function._output_shapes[i]) + func_graph_output = self._forward_function._func_graph_outputs[i] + custom_gradient.copy_handle_data(func_graph_output, t) + op._outputs.append(t) + # pylint: enable=protected-access + side_outputs = op.outputs[len(outputs):] + return self._backward_graph_function( + *(list(doutputs) + list(side_outputs))) + + with ops.get_default_graph().gradient_override_map( + {"PartitionedCall": name}): + outputs = self._inference_function.call(ctx, args) + return self._build_call_outputs(outputs) @property diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 57e545be69..e46bde098b 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -286,7 +286,23 @@ class FunctionTest(test.TestCase): c = constant_op.constant([[2.]]) f_c = f(c) g, = gradients_impl.gradients(f_c, c) - self.assertAllEqual(sess.run(g), [[1.0]]) + self.assertAllEqual(sess.run(g).values, [[1.0]]) + + def testNoSymGradNestedDefun(self): + + @function.defun + def outer(): + + @function.defun + def f(x): + return array_ops.gather_nd(x, [[0]]) + + c = constant_op.constant([[2.]]) + f_c = f(c) + g, = gradients_impl.gradients(f_c, c) + self.assertTrue(isinstance(g, ops.IndexedSlices)) + + outer() def testNestedInputsGraphFunction(self): matmul = function.defun(math_ops.matmul) diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index e85bba11cd..9955a9a2cd 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -482,7 +482,8 @@ class OpDefLibrary(object): else: raise TypeError("%s that don't all match." % prefix) else: - raise TypeError("%s that are invalid." % prefix) + raise TypeError( + "%s that are invalid. Tensors: %s" % (prefix, values)) types = [x.dtype for x in values] inputs.extend(values) diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index ec875aae59..a424a0f219 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -153,6 +153,7 @@ class CondV2Test(test.TestCase): self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions) def testDefunInCond(self): + self.skipTest("b/117293122") x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index d7834ba350..bfe23834b7 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -18,9 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import tape as tape_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export +def copy_handle_data(source_t, target_t): + """Copies HandleData for variant and resource type tensors if available. + + The CppShapeInferenceResult::HandleData proto contains information about the + shapes and types of the element tensors of resource/variant type tensors. + We need to copy this across function boundaries, i.e., when capturing a + placeholder or when returning a function tensor as output. If we don't do this + the element tensors will have unknown shapes, e.g., if a TensorList variant + tensor is captured as a placeholder, elements popped from that list would have + unknown shape. + + Args: + source_t: The tensor to copy HandleData from. + target_t: The tensor to copy HandleData to. + """ + if (target_t.dtype == dtypes.resource or + target_t.dtype == dtypes.variant): + if isinstance(source_t, ops.EagerTensor): + handle_data = source_t._handle_data # pylint: disable=protected-access + else: + handle_data = resource_variable_ops.get_resource_handle_data(source_t) + if handle_data is not None and handle_data.is_set: + # pylint: disable=protected-access + pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph, + target_t._as_tf_output(), + handle_data.SerializeToString()) + # pylint: enable=protected-access + # Ensure that shapes and dtypes are propagated. + shapes, types = zip(*[(pair.shape, pair.dtype) + for pair in handle_data.shape_and_type]) + ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] + shapes = [[d.size for d in s.dim] + if not s.unknown_rank else None for s in shapes] + pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( + target_t._op._graph._c_graph, # pylint: disable=protected-access + target_t._as_tf_output(), # pylint: disable=protected-access + shapes, ranks, types) + + @tf_export("custom_gradient") def custom_gradient(f): """Decorator to define a function with a custom gradient. @@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs): input_grads = nest.flatten(input_grads) return ([None] * len(flat_result)) + input_grads + variable_grads + original_tensors = all_tensors with ops.get_default_graph().gradient_override_map({"IdentityN": name}): all_tensors = array_ops.identity_n(all_tensors) + for ot, t in zip(original_tensors, all_tensors): + copy_handle_data(ot, t) return nest.pack_sequence_as( structure=result, flat_sequence=all_tensors[:len(flat_result)]) diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index aac95037dc..6909fcaed5 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -800,23 +800,21 @@ def _GradientsHelper(ys, # pylint: enable=protected-access has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) if has_out_grads and (op not in stop_ops): - if is_func_call: - if is_partitioned_call: - func_call = src_graph._get_function( # pylint: disable=protected-access - compat.as_bytes(op.get_attr("f").name)) + try: + grad_fn = ops.get_gradient_function(op) + except LookupError: + if is_func_call: + if is_partitioned_call: + func_call = src_graph._get_function( # pylint: disable=protected-access + compat.as_bytes(op.get_attr("f").name)) + else: + func_call = src_graph._get_function(op.type) # pylint: disable=protected-access + # Note that __defun is not set if the graph is + # imported. If it's set, we prefer to access the original + # defun. + func_call = getattr(op, "__defun", func_call) + grad_fn = func_call.python_grad_func else: - func_call = src_graph._get_function(op.type) # pylint: disable=protected-access - # Note that __defun is not set if the graph is - # imported. If it's set, we prefer to access the original - # defun. - func_call = getattr(op, "__defun", func_call) - grad_fn = func_call.python_grad_func - else: - # A grad_fn must be defined, either as a function or as None - # for ops that do not have gradients. - try: - grad_fn = ops.get_gradient_function(op) - except LookupError: raise LookupError( "No gradient defined for operation '%s' (op type: %s)" % (op.name, op.type)) diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 8e88a84d60..0419656143 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2_impl as cond_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gen_functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import list_ops @@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors): def _copy_handle_data(src_tensors, tgt_tensors): for src_t, tgt_t in zip(src_tensors, tgt_tensors): - function._copy_handle_data(src_t, tgt_t) + custom_gradient.copy_handle_data(src_t, tgt_t) # TODO(srbs): Move to common utils for cond_v2 and while_v2. |