Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r-- | tensorflow/python/eager/function.py | 87
1 file changed, 45 insertions(+), 42 deletions(-)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 93168826b1..99bf375ea7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -46,6 +46,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import cond_v2_impl
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import resource_variable_ops
@@ -81,49 +82,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
   with ops.control_dependencies(None):
     placeholder = graph_placeholder(
         dtype=dtype or value.dtype, shape=value.shape, name=name)
-  _copy_handle_data(value, placeholder)
+  custom_gradient.copy_handle_data(value, placeholder)
   return placeholder
 
 
-def _copy_handle_data(source_t, target_t):
-  """Copies HandleData for variant and resource type tensors if available.
-
-  The CppShapeInferenceResult::HandleData proto contains information about the
-  shapes and types of the element tensors of resource/variant type tensors.
-  We need to copy this across function boundaries, i.e., when capturing a
-  placeholder or when returning a function tensor as output. If we don't do this
-  the element tensors will have unknown shapes, e.g., if a TensorList variant
-  tensor is captured as a placeholder, elements popped from that list would have
-  unknown shape.
-
-  Args:
-    source_t: The tensor to copy HandleData from.
-    target_t: The tensor to copy HandleData to.
-  """
-  if (target_t.dtype == dtypes_module.resource or
-      target_t.dtype == dtypes_module.variant):
-    if isinstance(source_t, ops.EagerTensor):
-      handle_data = source_t._handle_data  # pylint: disable=protected-access
-    else:
-      handle_data = resource_variable_ops.get_resource_handle_data(source_t)
-    if handle_data is not None and handle_data.is_set:
-      # pylint: disable=protected-access
-      pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
-                                              target_t._as_tf_output(),
-                                              handle_data.SerializeToString())
-      # pylint: enable=protected-access
-      # Ensure that shapes and dtypes are propagated.
-      shapes, types = zip(*[(pair.shape, pair.dtype)
-                            for pair in handle_data.shape_and_type])
-      ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
-      shapes = [[d.size for d in s.dim]
-                if not s.unknown_rank else None for s in shapes]
-      pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
-          target_t._op._graph._c_graph,  # pylint: disable=protected-access
-          target_t._as_tf_output(),  # pylint: disable=protected-access
-          shapes, ranks, types)
-
-
 def _get_device_functions(ctx, graph):
   """Returns a tuple of device functions representing the device stack."""
   if ctx.executing_eagerly():
@@ -547,7 +509,7 @@ class _EagerDefinedFunction(object):
     for i, shape in enumerate(self._output_shapes):
       outputs[i].set_shape(shape)
     for i, func_graph_output in enumerate(self._func_graph_outputs):
-      _copy_handle_data(func_graph_output, outputs[i])
+      custom_gradient.copy_handle_data(func_graph_output, outputs[i])
     return outputs
 
 
@@ -658,7 +620,48 @@ class Function(object):
     if tape.should_record(tensor_inputs) or tape.should_record(captures):
       return self._backprop_call(args)
 
-    outputs = self._inference_function.call(ctx, args)
+    # Only need to override the gradient in graph mode and when we have outputs.
+    if context.executing_eagerly() or not self.outputs:
+      outputs = self._inference_function.call(ctx, args)
+    else:
+      name = "PartitionedCall-%s" % ops.uid()
+
+      @ops.RegisterGradient(name)
+      def grad_fn(op, *doutputs):  # pylint: disable=unused-variable
+        """Gradients of this function."""
+        if op.graph is not ops.get_default_graph():
+          # TODO(apassos) this will still emit SymbolicGradient ops when
+          # nested defuns are being differentiated. We need to somehow figure
+          # out a way to update the FunctionDef corresponding to the calling
+          # function when mutating a call to the forward pass.
+          return gradients_impl._SymGrad(op, list(doutputs))  # pylint: disable=protected-access
+        if self._backward_graph_function is None:
+          self._construct_backprop_function()
+        self._forward_function.add_to_graph(op.graph)
+        func = attr_value_pb2.AttrValue(
+            func=attr_value_pb2.NameAttrList(
+                name=self._forward_function.name))
+        # pylint: disable=protected-access
+        op._set_attr("f", func)
+        types = attr_value_pb2.AttrValue.ListValue(
+            type=self._forward_function._output_types)
+        op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
+        for i in range(
+            len(outputs), len(self._forward_function._output_types)):
+          t = ops.Tensor(op, i, self._forward_function._output_types[i])
+          t.set_shape(self._forward_function._output_shapes[i])
+          func_graph_output = self._forward_function._func_graph_outputs[i]
+          custom_gradient.copy_handle_data(func_graph_output, t)
+          op._outputs.append(t)
+        # pylint: enable=protected-access
+        side_outputs = op.outputs[len(outputs):]
+        return self._backward_graph_function(
+            *(list(doutputs) + list(side_outputs)))
+
+      with ops.get_default_graph().gradient_override_map(
+          {"PartitionedCall": name}):
+        outputs = self._inference_function.call(ctx, args)
+
     return self._build_call_outputs(outputs)
 
   @property
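The mechanism the new __call__ path leans on is TensorFlow's graph-level gradient override: register a gradient function under a fresh, unique name with ops.RegisterGradient, then use Graph.gradient_override_map to route an existing op type ("PartitionedCall" here) to that name for any ops created inside the with block. A minimal sketch of the same pattern using the public TF 1.x API; the Identity op, the "ClipGradToOne" name, and the clipping gradient are illustrative assumptions, not part of this change:

    import tensorflow as tf  # TF 1.x-era graph APIs assumed throughout

    # Register a gradient under a unique name, mirroring how the diff
    # registers "PartitionedCall-%s" % ops.uid() per Function instance.
    @tf.RegisterGradient("ClipGradToOne")  # hypothetical override name
    def _clip_grad_to_one(op, grad):
      """Overrides Identity's gradient: clip incoming gradient to [-1, 1]."""
      del op  # unused here; the diff's grad_fn inspects and mutates the op
      return tf.clip_by_value(grad, -1.0, 1.0)

    x = tf.placeholder(tf.float32)
    # Inside this context, Identity ops differentiate via "ClipGradToOne",
    # just as the diff maps "PartitionedCall" to its per-call gradient name.
    with tf.get_default_graph().gradient_override_map(
        {"Identity": "ClipGradToOne"}):
      y = tf.identity(x)

    # Without the override d(100*y)/dx would be 100; the override clips it to 1.
    dy_dx, = tf.gradients(100.0 * y, [x])

    with tf.Session() as sess:
      print(sess.run(dy_dx, feed_dict={x: 3.0}))  # prints 1.0

The grad_fn in the diff does strictly more work than this sketch: before delegating to self._backward_graph_function, it rewrites the forward PartitionedCall op in place (swapping its "f" and "Tout" attrs and appending outputs) so the op runs a forward function that also emits the intermediate "side output" tensors the backward function consumes.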