-rw-r--r-- | tensorflow/core/common_runtime/shape_refiner.cc | 5
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 9
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 9
-rw-r--r-- | tensorflow/core/graph/graph.cc | 13
-rw-r--r-- | tensorflow/core/graph/graph.h | 5
-rw-r--r-- | tensorflow/core/graph/node_builder.cc | 8
-rw-r--r-- | tensorflow/core/ops/resource_variable_ops.cc | 3
-rw-r--r-- | tensorflow/python/eager/function.py | 87
-rw-r--r-- | tensorflow/python/eager/function_test.py | 18
-rw-r--r-- | tensorflow/python/framework/op_def_library.py | 3
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 1
-rw-r--r-- | tensorflow/python/ops/custom_gradient.py | 44
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py | 30
-rw-r--r-- | tensorflow/python/ops/while_v2.py | 3 |
14 files changed, 169 insertions, 69 deletions
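
A minimal usage sketch of the behavior this change targets, adapted from the updated tests in function_test.py (assumes graph-mode TF 1.x of this commit's era and the internal defun API; the Session/Graph setup is illustrative, not part of the change):

import tensorflow as tf
from tensorflow.python.eager import function

with tf.Graph().as_default(), tf.Session() as sess:

  @function.defun
  def f(x):
    # gather_nd has a sparse (IndexedSlices) gradient.
    return tf.gather_nd(x, [[0]])

  c = tf.constant([[2.]])
  f_c = f(c)
  g, = tf.gradients(f_c, c)
  # The PartitionedCall op emitted by defun now carries a registered gradient
  # that wires in the forward/backward functions directly, so `g` arrives as
  # tf.IndexedSlices instead of being densified through a SymbolicGradient op.
  print(sess.run(g).values)  # -> [[1.]]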
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index fa4d1eda62..9488a44778 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -288,6 +288,11 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
         "output_port '", output_port, "' is out of range, ", "node '",
         node->name(), "' has ", node->num_outputs(), " outputs");
   }
+  // Note: it's possible, if the node's been updated, that the shape inference
+  // context doesn't have the right number of outputs.
+  if (node->num_outputs() > c->num_outputs()) {
+    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
+  }
 
   // Check compatibility, and merge the shapes.
   ShapeHandle existing_shape = c->output(output_port);
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3e77028a5f..4dcc80680f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,6 +239,15 @@ void InferenceContext::PreInputInit(
   output_handle_shapes_and_types_.resize(num_outputs);
 }
 
+Status InferenceContext::ExpandOutputs(int new_output_size) {
+  if (new_output_size < outputs_.size()) {
+    return errors::InvalidArgument("Trying to reduce number of outputs of op.");
+  }
+  outputs_.resize(new_output_size, nullptr);
+  output_handle_shapes_and_types_.resize(new_output_size);
+  return Status::OK();
+}
+
 void InferenceContext::PostInputInit(
     std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
   int num_inputs_from_node_def = 0;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 81258b55b3..e3885b7d9e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -323,13 +323,13 @@ class InferenceContext {
     return input_tensors_as_shapes_;
   }
 
-  ShapeHandle output(int64 idx) const { return outputs_[idx]; }
-  void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
+  ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
+  void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
   Status set_output(StringPiece output_name,
                     const std::vector<ShapeHandle>& shapes);
 
   int num_outputs() const { return outputs_.size(); }
-  ShapeHandle output(int idx) const { return outputs_[idx]; }
+  ShapeHandle output(int idx) const { return outputs_.at(idx); }
   Status output(StringPiece output_name,
                 std::vector<ShapeHandle>* output) const;
 
@@ -645,6 +645,9 @@ class InferenceContext {
     return merged_dims_;
   }
 
+  // Adds new outputs; useful when mutating the graph.
+  Status ExpandOutputs(int new_output_size);
+
  private:
   // Creates and stores shapes for use in InferenceContext.
   class ShapeManager {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 7a4a0096fa..6f068546d2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -142,6 +142,19 @@ void Node::Clear() {
   assigned_device_name_index_ = 0;
 }
 
+void Node::UpdateProperties() {
+  DataTypeVector inputs;
+  DataTypeVector outputs;
+  Status status =
+      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed at updating node: " << status;
+    return;
+  }
+  props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
+                                            inputs, outputs);
+}
+
 const string& Node::name() const { return props_->node_def.name(); }
 const string& Node::type_string() const { return props_->node_def.op(); }
 const NodeDef& Node::def() const { return props_->node_def; }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 2944951f82..228b1331d9 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -171,6 +171,7 @@ class Node {
   template <typename T>
   void AddAttr(const string& name, const T& val) {
     SetAttrValue(val, AddAttrHelper(name));
+    UpdateProperties();
   }
 
   void ClearAttr(const string& name);
@@ -211,6 +212,10 @@ class Node {
   // e.g. in AddAttr.
   void MaybeCopyOnWrite();
 
+  // Called after an attr has changed. Decides whether we need to update some
+  // property of the node (stored in props_).
+  void UpdateProperties();
+
   AttrValue* AddAttrHelper(const string& name);
 
   // A set of mutually exclusive classes for different kinds of nodes,
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index d92874909f..68a20fcc5f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) {
         strings::StrCat("Attempt to add nullptr Node to node with type ",
                         def_builder_.op_def().name()));
   } else {
-    errors_.emplace_back(
-        strings::StrCat("Attempt to add output ", i, " of ", node->name(),
-                        " not in range [0, ", node->num_outputs(),
-                        ") to node with type ", def_builder_.op_def().name()));
+    errors_.emplace_back(strings::StrCat(
+        "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
+        node->num_outputs(), ") to node with type ",
+        def_builder_.op_def().name(), ". Node: ", node->DebugString()));
   }
 }
 
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index adc9cd1486..65bdde375b 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -216,7 +216,8 @@ REGISTER_OP("VarIsInitializedOp")
 Status VariableShapeShapeFn(InferenceContext* c) {
   auto* handle_data = c->input_handle_shapes_and_types(0);
   if (handle_data == nullptr || handle_data->empty()) {
-    return errors::InvalidArgument("Handle doesn't have shape information.");
+    c->set_output(0, c->Vector(c->UnknownDim()));
+    return Status::OK();
   }
   ShapeHandle var_shape = (*handle_data)[0].shape;
   int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 93168826b1..99bf375ea7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -46,6 +46,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import cond_v2_impl
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import resource_variable_ops
@@ -81,49 +82,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
   with ops.control_dependencies(None):
     placeholder = graph_placeholder(
         dtype=dtype or value.dtype, shape=value.shape, name=name)
-  _copy_handle_data(value, placeholder)
+  custom_gradient.copy_handle_data(value, placeholder)
   return placeholder
 
 
-def _copy_handle_data(source_t, target_t):
-  """Copies HandleData for variant and resource type tensors if available.
-
-  The CppShapeInferenceResult::HandleData proto contains information about the
-  shapes and types of the element tensors of resource/variant type tensors.
-  We need to copy this across function boundaries, i.e., when capturing a
-  placeholder or when returning a function tensor as output. If we don't do this
-  the element tensors will have unknown shapes, e.g., if a TensorList variant
-  tensor is captured as a placeholder, elements popped from that list would have
-  unknown shape.
-
-  Args:
-    source_t: The tensor to copy HandleData from.
-    target_t: The tensor to copy HandleData to.
-  """
-  if (target_t.dtype == dtypes_module.resource or
-      target_t.dtype == dtypes_module.variant):
-    if isinstance(source_t, ops.EagerTensor):
-      handle_data = source_t._handle_data  # pylint: disable=protected-access
-    else:
-      handle_data = resource_variable_ops.get_resource_handle_data(source_t)
-    if handle_data is not None and handle_data.is_set:
-      # pylint: disable=protected-access
-      pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
-                                              target_t._as_tf_output(),
-                                              handle_data.SerializeToString())
-      # pylint: enable=protected-access
-      # Ensure that shapes and dtypes are propagated.
-      shapes, types = zip(*[(pair.shape, pair.dtype)
-                            for pair in handle_data.shape_and_type])
-      ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
-      shapes = [[d.size for d in s.dim]
-                if not s.unknown_rank else None for s in shapes]
-      pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
-          target_t._op._graph._c_graph,  # pylint: disable=protected-access
-          target_t._as_tf_output(),  # pylint: disable=protected-access
-          shapes, ranks, types)
-
-
 def _get_device_functions(ctx, graph):
   """Returns a tuple of device functions representing the device stack."""
   if ctx.executing_eagerly():
@@ -547,7 +509,7 @@ class _EagerDefinedFunction(object):
     for i, shape in enumerate(self._output_shapes):
       outputs[i].set_shape(shape)
     for i, func_graph_output in enumerate(self._func_graph_outputs):
-      _copy_handle_data(func_graph_output, outputs[i])
+      custom_gradient.copy_handle_data(func_graph_output, outputs[i])
     return outputs
 
 
@@ -658,7 +620,48 @@ class Function(object):
     if tape.should_record(tensor_inputs) or tape.should_record(captures):
       return self._backprop_call(args)
 
-    outputs = self._inference_function.call(ctx, args)
+    # Only need to override the gradient in graph mode and when we have outputs.
+    if context.executing_eagerly() or not self.outputs:
+      outputs = self._inference_function.call(ctx, args)
+    else:
+      name = "PartitionedCall-%s" % ops.uid()
+
+      @ops.RegisterGradient(name)
+      def grad_fn(op, *doutputs):  # pylint: disable=unused-variable
+        """Gradients of this function."""
+        if op.graph is not ops.get_default_graph():
+          # TODO(apassos) this will still emit SymbolicGradient ops when
+          # nested defuns are being differentiated. We need to somehow figure
+          # out a way to update the FunctionDef corresponding to the calling
+          # function when mutating a call to the forward pass.
+          return gradients_impl._SymGrad(op, list(doutputs))  # pylint: disable=protected-access
+        if self._backward_graph_function is None:
+          self._construct_backprop_function()
+        self._forward_function.add_to_graph(op.graph)
+        func = attr_value_pb2.AttrValue(
+            func=attr_value_pb2.NameAttrList(
+                name=self._forward_function.name))
+        # pylint: disable=protected-access
+        op._set_attr("f", func)
+        types = attr_value_pb2.AttrValue.ListValue(
+            type=self._forward_function._output_types)
+        op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
+        for i in range(
+            len(outputs), len(self._forward_function._output_types)):
+          t = ops.Tensor(op, i, self._forward_function._output_types[i])
+          t.set_shape(self._forward_function._output_shapes[i])
+          func_graph_output = self._forward_function._func_graph_outputs[i]
+          custom_gradient.copy_handle_data(func_graph_output, t)
+          op._outputs.append(t)
+        # pylint: enable=protected-access
+        side_outputs = op.outputs[len(outputs):]
+        return self._backward_graph_function(
+            *(list(doutputs) + list(side_outputs)))
+
+      with ops.get_default_graph().gradient_override_map(
+          {"PartitionedCall": name}):
+        outputs = self._inference_function.call(ctx, args)
+
     return self._build_call_outputs(outputs)
 
   @property
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 57e545be69..e46bde098b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -286,7 +286,23 @@ class FunctionTest(test.TestCase):
       c = constant_op.constant([[2.]])
       f_c = f(c)
       g, = gradients_impl.gradients(f_c, c)
-      self.assertAllEqual(sess.run(g), [[1.0]])
+      self.assertAllEqual(sess.run(g).values, [[1.0]])
+
+  def testNoSymGradNestedDefun(self):
+
+    @function.defun
+    def outer():
+
+      @function.defun
+      def f(x):
+        return array_ops.gather_nd(x, [[0]])
+
+      c = constant_op.constant([[2.]])
+      f_c = f(c)
+      g, = gradients_impl.gradients(f_c, c)
+      self.assertTrue(isinstance(g, ops.IndexedSlices))
+
+    outer()
 
   def testNestedInputsGraphFunction(self):
     matmul = function.defun(math_ops.matmul)
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index e85bba11cd..9955a9a2cd 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -482,7 +482,8 @@ class OpDefLibrary(object):
           else:
             raise TypeError("%s that don't all match." % prefix)
         else:
-          raise TypeError("%s that are invalid." % prefix)
+          raise TypeError(
Tensors: %s" % (prefix, values)) types = [x.dtype for x in values] inputs.extend(values) diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index ec875aae59..a424a0f219 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -153,6 +153,7 @@ class CondV2Test(test.TestCase): self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions) def testDefunInCond(self): + self.skipTest("b/117293122") x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index d7834ba350..bfe23834b7 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -18,9 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import tape as tape_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export +def copy_handle_data(source_t, target_t): + """Copies HandleData for variant and resource type tensors if available. + + The CppShapeInferenceResult::HandleData proto contains information about the + shapes and types of the element tensors of resource/variant type tensors. + We need to copy this across function boundaries, i.e., when capturing a + placeholder or when returning a function tensor as output. If we don't do this + the element tensors will have unknown shapes, e.g., if a TensorList variant + tensor is captured as a placeholder, elements popped from that list would have + unknown shape. + + Args: + source_t: The tensor to copy HandleData from. + target_t: The tensor to copy HandleData to. + """ + if (target_t.dtype == dtypes.resource or + target_t.dtype == dtypes.variant): + if isinstance(source_t, ops.EagerTensor): + handle_data = source_t._handle_data # pylint: disable=protected-access + else: + handle_data = resource_variable_ops.get_resource_handle_data(source_t) + if handle_data is not None and handle_data.is_set: + # pylint: disable=protected-access + pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph, + target_t._as_tf_output(), + handle_data.SerializeToString()) + # pylint: enable=protected-access + # Ensure that shapes and dtypes are propagated. + shapes, types = zip(*[(pair.shape, pair.dtype) + for pair in handle_data.shape_and_type]) + ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] + shapes = [[d.size for d in s.dim] + if not s.unknown_rank else None for s in shapes] + pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( + target_t._op._graph._c_graph, # pylint: disable=protected-access + target_t._as_tf_output(), # pylint: disable=protected-access + shapes, ranks, types) + + @tf_export("custom_gradient") def custom_gradient(f): """Decorator to define a function with a custom gradient. 
@@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs):
     input_grads = nest.flatten(input_grads)
     return ([None] * len(flat_result)) + input_grads + variable_grads
 
+  original_tensors = all_tensors
   with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
     all_tensors = array_ops.identity_n(all_tensors)
+  for ot, t in zip(original_tensors, all_tensors):
+    copy_handle_data(ot, t)
   return nest.pack_sequence_as(
       structure=result, flat_sequence=all_tensors[:len(flat_result)])
 
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index aac95037dc..6909fcaed5 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -800,23 +800,21 @@ def _GradientsHelper(ys,
       # pylint: enable=protected-access
       has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
       if has_out_grads and (op not in stop_ops):
-        if is_func_call:
-          if is_partitioned_call:
-            func_call = src_graph._get_function(  # pylint: disable=protected-access
-                compat.as_bytes(op.get_attr("f").name))
+        try:
+          grad_fn = ops.get_gradient_function(op)
+        except LookupError:
+          if is_func_call:
+            if is_partitioned_call:
+              func_call = src_graph._get_function(  # pylint: disable=protected-access
+                  compat.as_bytes(op.get_attr("f").name))
+            else:
+              func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
+            # Note that __defun is not set if the graph is
+            # imported. If it's set, we prefer to access the original
+            # defun.
+            func_call = getattr(op, "__defun", func_call)
+            grad_fn = func_call.python_grad_func
+          else:
-            func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
-          # Note that __defun is not set if the graph is
-          # imported. If it's set, we prefer to access the original
-          # defun.
-          func_call = getattr(op, "__defun", func_call)
-          grad_fn = func_call.python_grad_func
-        else:
-          # A grad_fn must be defined, either as a function or as None
-          # for ops that do not have gradients.
-          try:
-            grad_fn = ops.get_gradient_function(op)
-          except LookupError:
            raise LookupError(
                "No gradient defined for operation '%s' (op type: %s)" %
                (op.name, op.type))
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8e88a84d60..0419656143 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import cond_v2_impl as cond_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gen_functional_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
@@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
 
 def _copy_handle_data(src_tensors, tgt_tensors):
   for src_t, tgt_t in zip(src_tensors, tgt_tensors):
-    function._copy_handle_data(src_t, tgt_t)
+    custom_gradient.copy_handle_data(src_t, tgt_t)
 
 
 # TODO(srbs): Move to common utils for cond_v2 and while_v2.
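
A small sketch of the VariableShape shape-function change in resource_variable_ops.cc above (hypothetical snippet; assumes the TF 1.x Python API of this era and the internal gen_resource_variable_ops module):

import tensorflow as tf
from tensorflow.python.ops import gen_resource_variable_ops

with tf.Graph().as_default():
  # A bare resource placeholder carries no handle data, so building this op
  # used to fail during shape inference with
  # "Handle doesn't have shape information."
  handle = tf.placeholder(tf.resource, shape=[])
  shape_t = gen_resource_variable_ops.variable_shape(handle)
  # With this change the output is inferred as a vector of unknown length.
  print(shape_t.shape)  # -> (?,)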