-rw-r--r--   tensorflow/core/common_runtime/shape_refiner.cc    5
-rw-r--r--   tensorflow/core/framework/shape_inference.cc        9
-rw-r--r--   tensorflow/core/framework/shape_inference.h         9
-rw-r--r--   tensorflow/core/graph/graph.cc                     13
-rw-r--r--   tensorflow/core/graph/graph.h                       5
-rw-r--r--   tensorflow/core/graph/node_builder.cc               8
-rw-r--r--   tensorflow/core/ops/resource_variable_ops.cc        3
-rw-r--r--   tensorflow/python/eager/function.py                87
-rw-r--r--   tensorflow/python/eager/function_test.py           18
-rw-r--r--   tensorflow/python/framework/op_def_library.py       3
-rw-r--r--   tensorflow/python/kernel_tests/cond_v2_test.py      1
-rw-r--r--   tensorflow/python/ops/custom_gradient.py           44
-rw-r--r--   tensorflow/python/ops/gradients_impl.py            30
-rw-r--r--   tensorflow/python/ops/while_v2.py                   3
14 files changed, 169 insertions, 69 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index fa4d1eda62..9488a44778 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -288,6 +288,11 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
"output_port '", output_port, "' is out of range, ", "node '",
node->name(), "' has ", node->num_outputs(), " outputs");
}
+ // Note: it's possible, if the node's been updated, that the shape inference
+ // context doesn't have the right number of outputs.
+ if (node->num_outputs() > c->num_outputs()) {
+ TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
+ }
// Check compatibility, and merge the shapes.
ShapeHandle existing_shape = c->output(output_port);
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3e77028a5f..4dcc80680f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,6 +239,15 @@ void InferenceContext::PreInputInit(
output_handle_shapes_and_types_.resize(num_outputs);
}
+Status InferenceContext::ExpandOutputs(int new_output_size) {
+ if (new_output_size < outputs_.size()) {
+ return errors::InvalidArgument("Trying to reduce number of outputs of op.");
+ }
+ outputs_.resize(new_output_size, nullptr);
+ output_handle_shapes_and_types_.resize(new_output_size);
+ return Status::OK();
+}
+
void InferenceContext::PostInputInit(
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
int num_inputs_from_node_def = 0;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 81258b55b3..e3885b7d9e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -323,13 +323,13 @@ class InferenceContext {
return input_tensors_as_shapes_;
}
- ShapeHandle output(int64 idx) const { return outputs_[idx]; }
- void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
+ ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
+ void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
Status set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes);
int num_outputs() const { return outputs_.size(); }
- ShapeHandle output(int idx) const { return outputs_[idx]; }
+ ShapeHandle output(int idx) const { return outputs_.at(idx); }
Status output(StringPiece output_name,
std::vector<ShapeHandle>* output) const;
@@ -645,6 +645,9 @@ class InferenceContext {
return merged_dims_;
}
+ // Adds new outputs; useful when mutating the graph.
+ Status ExpandOutputs(int new_output_size);
+
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 7a4a0096fa..6f068546d2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -142,6 +142,19 @@ void Node::Clear() {
assigned_device_name_index_ = 0;
}
+void Node::UpdateProperties() {
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ Status status =
+ InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed at updating node: " << status;
+ return;
+ }
+ props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
+ inputs, outputs);
+}
+
const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 2944951f82..228b1331d9 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -171,6 +171,7 @@ class Node {
template <typename T>
void AddAttr(const string& name, const T& val) {
SetAttrValue(val, AddAttrHelper(name));
+ UpdateProperties();
}
void ClearAttr(const string& name);
@@ -211,6 +212,10 @@ class Node {
// e.g. in AddAttr.
void MaybeCopyOnWrite();
+ // Called after an attr has changed. Decides whether we need to update some
+ // property of the node (stored in props_).
+ void UpdateProperties();
+
AttrValue* AddAttrHelper(const string& name);
// A set of mutually exclusive classes for different kinds of nodes,
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index d92874909f..68a20fcc5f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) {
strings::StrCat("Attempt to add nullptr Node to node with type ",
def_builder_.op_def().name()));
} else {
- errors_.emplace_back(
- strings::StrCat("Attempt to add output ", i, " of ", node->name(),
- " not in range [0, ", node->num_outputs(),
- ") to node with type ", def_builder_.op_def().name()));
+ errors_.emplace_back(strings::StrCat(
+ "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
+ node->num_outputs(), ") to node with type ",
+ def_builder_.op_def().name(), ". Node: ", node->DebugString()));
}
}
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index adc9cd1486..65bdde375b 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -216,7 +216,8 @@ REGISTER_OP("VarIsInitializedOp")
Status VariableShapeShapeFn(InferenceContext* c) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr || handle_data->empty()) {
- return errors::InvalidArgument("Handle doesn't have shape information.");
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
}
ShapeHandle var_shape = (*handle_data)[0].shape;
int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
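The effect of this change is visible from Python: shape inference for VariableShape no longer fails when a resource handle carries no handle data, it simply reports a 1-D shape of unknown length. A hedged sketch of that behavior, assuming the generated wrapper gen_resource_variable_ops.variable_shape and a bare resource placeholder (neither appears in this commit):

import tensorflow as tf
from tensorflow.python.ops import gen_resource_variable_ops

# A bare resource placeholder has no handle data attached, so the shape
# function cannot know the rank of the variable it points to.
handle = tf.placeholder(tf.resource)
s = gen_resource_variable_ops.variable_shape(handle)
# Before this change: graph construction failed with "Handle doesn't have
# shape information."  After: it succeeds and s has static shape (?,).
print(s.shape)

diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py (continued below)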
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 93168826b1..99bf375ea7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -46,6 +46,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
@@ -81,49 +82,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- _copy_handle_data(value, placeholder)
+ custom_gradient.copy_handle_data(value, placeholder)
return placeholder
-def _copy_handle_data(source_t, target_t):
- """Copies HandleData for variant and resource type tensors if available.
-
- The CppShapeInferenceResult::HandleData proto contains information about the
- shapes and types of the element tensors of resource/variant type tensors.
- We need to copy this across function boundaries, i.e., when capturing a
- placeholder or when returning a function tensor as output. If we don't do this
- the element tensors will have unknown shapes, e.g., if a TensorList variant
- tensor is captured as a placeholder, elements popped from that list would have
- unknown shape.
-
- Args:
- source_t: The tensor to copy HandleData from.
- target_t: The tensor to copy HandleData to.
- """
- if (target_t.dtype == dtypes_module.resource or
- target_t.dtype == dtypes_module.variant):
- if isinstance(source_t, ops.EagerTensor):
- handle_data = source_t._handle_data # pylint: disable=protected-access
- else:
- handle_data = resource_variable_ops.get_resource_handle_data(source_t)
- if handle_data is not None and handle_data.is_set:
- # pylint: disable=protected-access
- pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
- target_t._as_tf_output(),
- handle_data.SerializeToString())
- # pylint: enable=protected-access
- # Ensure that shapes and dtypes are propagated.
- shapes, types = zip(*[(pair.shape, pair.dtype)
- for pair in handle_data.shape_and_type])
- ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
- shapes = [[d.size for d in s.dim]
- if not s.unknown_rank else None for s in shapes]
- pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- target_t._op._graph._c_graph, # pylint: disable=protected-access
- target_t._as_tf_output(), # pylint: disable=protected-access
- shapes, ranks, types)
-
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
if ctx.executing_eagerly():
@@ -547,7 +509,7 @@ class _EagerDefinedFunction(object):
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
for i, func_graph_output in enumerate(self._func_graph_outputs):
- _copy_handle_data(func_graph_output, outputs[i])
+ custom_gradient.copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -658,7 +620,48 @@ class Function(object):
if tape.should_record(tensor_inputs) or tape.should_record(captures):
return self._backprop_call(args)
- outputs = self._inference_function.call(ctx, args)
+ # Only need to override the gradient in graph mode and when we have outputs.
+ if context.executing_eagerly() or not self.outputs:
+ outputs = self._inference_function.call(ctx, args)
+ else:
+ name = "PartitionedCall-%s" % ops.uid()
+
+ @ops.RegisterGradient(name)
+ def grad_fn(op, *doutputs): # pylint: disable=unused-variable
+ """Gradients of this function."""
+ if op.graph is not ops.get_default_graph():
+ # TODO(apassos) this will still emit SymbolicGradient ops when
+ # nested defuns are being differentiated. We need to somehow figure
+ # out a way to update the FunctionDef corresponding to the calling
+ # function when mutating a call to the forward pass.
+ return gradients_impl._SymGrad(op, list(doutputs)) # pylint: disable=protected-access
+ if self._backward_graph_function is None:
+ self._construct_backprop_function()
+ self._forward_function.add_to_graph(op.graph)
+ func = attr_value_pb2.AttrValue(
+ func=attr_value_pb2.NameAttrList(
+ name=self._forward_function.name))
+ # pylint: disable=protected-access
+ op._set_attr("f", func)
+ types = attr_value_pb2.AttrValue.ListValue(
+ type=self._forward_function._output_types)
+ op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
+ for i in range(
+ len(outputs), len(self._forward_function._output_types)):
+ t = ops.Tensor(op, i, self._forward_function._output_types[i])
+ t.set_shape(self._forward_function._output_shapes[i])
+ func_graph_output = self._forward_function._func_graph_outputs[i]
+ custom_gradient.copy_handle_data(func_graph_output, t)
+ op._outputs.append(t)
+ # pylint: enable=protected-access
+ side_outputs = op.outputs[len(outputs):]
+ return self._backward_graph_function(
+ *(list(doutputs) + list(side_outputs)))
+
+ with ops.get_default_graph().gradient_override_map(
+ {"PartitionedCall": name}):
+ outputs = self._inference_function.call(ctx, args)
+
return self._build_call_outputs(outputs)
@property
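The per-call override above leans entirely on TensorFlow's generic gradient-override machinery: register a gradient function under a fresh name, then map the op type to that name while building the call op. A minimal sketch of that mechanism using the public TF 1.x API (the ZeroGrad name and the Identity override are purely illustrative, not part of this change):

import tensorflow as tf

@tf.RegisterGradient("ZeroGrad")  # illustrative name, referenced in the map below
def _zero_grad(op, grad):
  # Pretend the op is constant with respect to its input.
  return tf.zeros_like(op.inputs[0])

g = tf.Graph()
with g.as_default():
  x = tf.constant([1.0, 2.0])
  # While the map is active, gradients of Identity ops resolve to ZeroGrad.
  with g.gradient_override_map({"Identity": "ZeroGrad"}):
    y = tf.identity(x)
  dx, = tf.gradients(y, x)  # computed by _zero_grad, not the Identity gradient
  with tf.Session(graph=g) as sess:
    print(sess.run(dx))  # [0. 0.]

Function.__call__ does the same thing, keyed on "PartitionedCall" and with a name generated per call, so each Function instance can close over its own backward graph function.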
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 57e545be69..e46bde098b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -286,7 +286,23 @@ class FunctionTest(test.TestCase):
c = constant_op.constant([[2.]])
f_c = f(c)
g, = gradients_impl.gradients(f_c, c)
- self.assertAllEqual(sess.run(g), [[1.0]])
+ self.assertAllEqual(sess.run(g).values, [[1.0]])
+
+ def testNoSymGradNestedDefun(self):
+
+ @function.defun
+ def outer():
+
+ @function.defun
+ def f(x):
+ return array_ops.gather_nd(x, [[0]])
+
+ c = constant_op.constant([[2.]])
+ f_c = f(c)
+ g, = gradients_impl.gradients(f_c, c)
+ self.assertTrue(isinstance(g, ops.IndexedSlices))
+
+ outer()
def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)
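For context, the behavior these tests exercise, written out as a small end-to-end sketch (graph mode, using the internal function.defun exactly as the tests do; the printed value follows from ordinary calculus, everything else is an assumption about typical usage):

import tensorflow as tf
from tensorflow.python.eager import function

@function.defun
def square(x):
  return x * x

with tf.Graph().as_default(), tf.Session() as sess:
  x = tf.constant(3.0)
  y = square(x)            # emitted as a PartitionedCall op in graph mode
  g, = tf.gradients(y, x)  # resolved via the registered override,
                           # not a SymbolicGradient node
  print(sess.run(g))       # 6.0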
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index e85bba11cd..9955a9a2cd 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -482,7 +482,8 @@ class OpDefLibrary(object):
else:
raise TypeError("%s that don't all match." % prefix)
else:
- raise TypeError("%s that are invalid." % prefix)
+ raise TypeError(
+ "%s that are invalid. Tensors: %s" % (prefix, values))
types = [x.dtype for x in values]
inputs.extend(values)
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index ec875aae59..a424a0f219 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -153,6 +153,7 @@ class CondV2Test(test.TestCase):
self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions)
def testDefunInCond(self):
+ self.skipTest("b/117293122")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d7834ba350..bfe23834b7 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+def copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes.resource or
+ target_t.dtype == dtypes.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
+ else:
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
+ if handle_data is not None and handle_data.is_set:
+ # pylint: disable=protected-access
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
+ # pylint: enable=protected-access
+ # Ensure that shapes and dtypes are propagated.
+ shapes, types = zip(*[(pair.shape, pair.dtype)
+ for pair in handle_data.shape_and_type])
+ ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
+ shapes = [[d.size for d in s.dim]
+ if not s.unknown_rank else None for s in shapes]
+ pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
+ shapes, ranks, types)
+
+
@tf_export("custom_gradient")
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
@@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs):
input_grads = nest.flatten(input_grads)
return ([None] * len(flat_result)) + input_grads + variable_grads
+ original_tensors = all_tensors
with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
all_tensors = array_ops.identity_n(all_tensors)
+ for ot, t in zip(original_tensors, all_tensors):
+ copy_handle_data(ot, t)
return nest.pack_sequence_as(
structure=result, flat_sequence=all_tensors[:len(flat_result)])
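As the docstring above explains, handle data is what lets shape inference see "through" a DT_RESOURCE or DT_VARIANT handle. A small sketch of the idea with a TF 1.x resource variable (illustrative only, not part of this change):

import tensorflow as tf

v = tf.get_variable("v", shape=[3, 4], use_resource=True)
# v.handle is a scalar DT_RESOURCE tensor; its handle data records the
# dtype and shape of the variable it points to.
print(v.handle.dtype)        # <dtype: 'resource'>
print(v.read_value().shape)  # (3, 4), inferred from the handle data

The change to _graph_mode_decorator copies that handle data across the IdentityN op it inserts, so resource or variant outputs of a @tf.custom_gradient function keep their element shapes instead of degrading to unknown.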
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index aac95037dc..6909fcaed5 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -800,23 +800,21 @@ def _GradientsHelper(ys,
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op not in stop_ops):
- if is_func_call:
- if is_partitioned_call:
- func_call = src_graph._get_function( # pylint: disable=protected-access
- compat.as_bytes(op.get_attr("f").name))
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ if is_func_call:
+ if is_partitioned_call:
+ func_call = src_graph._get_function( # pylint: disable=protected-access
+ compat.as_bytes(op.get_attr("f").name))
+ else:
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
+ grad_fn = func_call.python_grad_func
else:
- func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
- # Note that __defun is not set if the graph is
- # imported. If it's set, we prefer to access the original
- # defun.
- func_call = getattr(op, "__defun", func_call)
- grad_fn = func_call.python_grad_func
- else:
- # A grad_fn must be defined, either as a function or as None
- # for ops that do not have gradients.
- try:
- grad_fn = ops.get_gradient_function(op)
- except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
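Because the +/- interleaving above is hard to read, here is the resulting lookup order restated as a condensed sketch (simplified; pylint pragmas and exact error text omitted):

# A gradient registered for the op - including one attached via
# gradient_override_map, as function.py now does for PartitionedCall -
# always wins; the function-call fallback only applies when no gradient
# is registered.
try:
  grad_fn = ops.get_gradient_function(op)
except LookupError:
  if is_func_call:
    if is_partitioned_call:
      func_call = src_graph._get_function(compat.as_bytes(op.get_attr("f").name))
    else:
      func_call = src_graph._get_function(op.type)
    # __defun is not set if the graph is imported; prefer the original defun.
    func_call = getattr(op, "__defun", func_call)
    grad_fn = func_call.python_grad_func
  else:
    raise LookupError("No gradient defined for operation '%s' (op type: %s)"
                      % (op.name, op.type))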
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8e88a84d60..0419656143 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
@@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
- function._copy_handle_data(src_t, tgt_t)
+ custom_gradient.copy_handle_data(src_t, tgt_t)
# TODO(srbs): Move to common utils for cond_v2 and while_v2.