author     Alexandre Passos <apassos@google.com>            2018-10-08 13:50:12 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-08 13:58:40 -0700
commit     eec9ca8f0baccd249a49046fe31b460903e44850 (patch)
tree       b6397af544af7c05abca4bea08bd6354f90bedf1 /tensorflow/core
parent     494bbdfced3fd8596721d12e73676c4967f452e4 (diff)
Partial support for tfe.defun in tf.gradients.

Doesn't attempt to deal with cases where we might have already generated the FunctionDef for the parent function, since in that case we cannot easily modify the forward pass.

PiperOrigin-RevId: 216243224
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/common_runtime/shape_refiner.cc  |  5
-rw-r--r--  tensorflow/core/framework/shape_inference.cc     |  9
-rw-r--r--  tensorflow/core/framework/shape_inference.h      |  9
-rw-r--r--  tensorflow/core/graph/graph.cc                    | 13
-rw-r--r--  tensorflow/core/graph/graph.h                     |  5
-rw-r--r--  tensorflow/core/graph/node_builder.cc             |  8
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc      |  3
7 files changed, 44 insertions, 8 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index fa4d1eda62..9488a44778 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -288,6 +288,11 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
"output_port '", output_port, "' is out of range, ", "node '",
node->name(), "' has ", node->num_outputs(), " outputs");
}
+ // Note: it's possible, if the node's been updated, that the shape inference
+ // context doesn't have the right number of outputs.
+ if (node->num_outputs() > c->num_outputs()) {
+ TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
+ }
// Check compatibility, and merge the shapes.
ShapeHandle existing_shape = c->output(output_port);
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3e77028a5f..4dcc80680f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,6 +239,15 @@ void InferenceContext::PreInputInit(
output_handle_shapes_and_types_.resize(num_outputs);
}
+Status InferenceContext::ExpandOutputs(int new_output_size) {
+ if (new_output_size < outputs_.size()) {
+ return errors::InvalidArgument("Trying to reduce number of outputs of op.");
+ }
+ outputs_.resize(new_output_size, nullptr);
+ output_handle_shapes_and_types_.resize(new_output_size);
+ return Status::OK();
+}
+
void InferenceContext::PostInputInit(
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
int num_inputs_from_node_def = 0;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 81258b55b3..e3885b7d9e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -323,13 +323,13 @@ class InferenceContext {
return input_tensors_as_shapes_;
}
- ShapeHandle output(int64 idx) const { return outputs_[idx]; }
- void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
+ ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
+ void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
Status set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes);
int num_outputs() const { return outputs_.size(); }
- ShapeHandle output(int idx) const { return outputs_[idx]; }
+ ShapeHandle output(int idx) const { return outputs_.at(idx); }
Status output(StringPiece output_name,
std::vector<ShapeHandle>* output) const;
@@ -645,6 +645,9 @@ class InferenceContext {
return merged_dims_;
}
+ // Adds new outputs; useful when mutating the graph.
+ Status ExpandOutputs(int new_output_size);
+
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {
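For context, here is a minimal sketch, not part of this change, of how code that has just mutated a node might use the new ExpandOutputs() API. The RefreshOutputs helper and its surroundings are hypothetical; it only assumes a Node* and an InferenceContext* describing the same op.

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

using tensorflow::Node;
using tensorflow::Status;
using tensorflow::shape_inference::InferenceContext;

// Hypothetical helper: after `node` gained outputs (e.g. via an attr update),
// grow its InferenceContext to match and give the new slots unknown shapes.
Status RefreshOutputs(const Node* node, InferenceContext* c) {
  const int old_num_outputs = c->num_outputs();
  if (node->num_outputs() > old_num_outputs) {
    // ExpandOutputs only grows the output vector; it rejects shrinking.
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
    for (int i = old_num_outputs; i < node->num_outputs(); ++i) {
      // New slots start out empty; mark them as fully unknown so later
      // shape merges have something to work against.
      c->set_output(i, c->UnknownShape());
    }
  }
  return Status::OK();
}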
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 7a4a0096fa..6f068546d2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -142,6 +142,19 @@ void Node::Clear() {
assigned_device_name_index_ = 0;
}
+void Node::UpdateProperties() {
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ Status status =
+ InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed at updating node: " << status;
+ return;
+ }
+ props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
+ inputs, outputs);
+}
+
const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 2944951f82..228b1331d9 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -171,6 +171,7 @@ class Node {
template <typename T>
void AddAttr(const string& name, const T& val) {
SetAttrValue(val, AddAttrHelper(name));
+ UpdateProperties();
}
void ClearAttr(const string& name);
@@ -211,6 +212,10 @@ class Node {
// e.g. in AddAttr.
void MaybeCopyOnWrite();
+ // Called after an attr has changed. Decides whether we need to update some
+ // property of the node (stored in props_).
+ void UpdateProperties();
+
AttrValue* AddAttrHelper(const string& name);
// A set of mutually exclusive classes for different kinds of nodes,
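As an illustration of why AddAttr() now refreshes the cached properties, consider an attr that feeds into the op's output signature. The sketch below is hypothetical; it assumes `split_node` is an existing "Split" node, whose "num_split" attr controls how many outputs the op has.

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/logging.h"

using tensorflow::Node;

// Hypothetical: bump an arity-controlling attr on an existing node.
void GrowSplit(Node* split_node) {
  split_node->AddAttr("num_split", 3);
  // AddAttr() now calls UpdateProperties(), which recomputes the cached
  // input/output types from the NodeDef, so num_outputs() reflects the
  // new attr value instead of a stale one.
  CHECK_EQ(split_node->num_outputs(), 3);
}

Existing edges are not touched by this; the sketch only demonstrates that the cached types now follow the NodeDef after an attr change.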
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index d92874909f..68a20fcc5f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) {
strings::StrCat("Attempt to add nullptr Node to node with type ",
def_builder_.op_def().name()));
} else {
- errors_.emplace_back(
- strings::StrCat("Attempt to add output ", i, " of ", node->name(),
- " not in range [0, ", node->num_outputs(),
- ") to node with type ", def_builder_.op_def().name()));
+ errors_.emplace_back(strings::StrCat(
+ "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
+ node->num_outputs(), ") to node with type ",
+ def_builder_.op_def().name(), ". Node: ", node->DebugString()));
}
}
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index adc9cd1486..65bdde375b 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -216,7 +216,8 @@ REGISTER_OP("VarIsInitializedOp")
Status VariableShapeShapeFn(InferenceContext* c) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr || handle_data->empty()) {
- return errors::InvalidArgument("Handle doesn't have shape information.");
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
}
ShapeHandle var_shape = (*handle_data)[0].shape;
int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
                                     : InferenceContext::kUnknownRank;
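The resource_variable_ops.cc change makes VariableShape degrade gracefully: when the handle carries no shape metadata (as can happen when the handle crosses a function boundary), the op's output is now inferred as a vector of unknown length rather than an error. Below is a hedged sketch of the same fallback pattern applied to a hypothetical resource-consuming op's shape function; it is not part of this change.

#include "tensorflow/core/framework/shape_inference.h"

using tensorflow::Status;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

// Hypothetical shape function following the same fallback pattern as above.
Status MyResourceShapeFn(InferenceContext* c) {
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data == nullptr || handle_data->empty()) {
    // No shape metadata on the handle: emit a 1-D shape of unknown length
    // instead of failing shape inference.
    c->set_output(0, c->Vector(c->UnknownDim()));
    return Status::OK();
  }
  ShapeHandle var_shape = (*handle_data)[0].shape;
  if (c->RankKnown(var_shape)) {
    c->set_output(0, c->Vector(c->Rank(var_shape)));
  } else {
    c->set_output(0, c->Vector(c->UnknownDim()));
  }
  return Status::OK();
}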