aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.cc5
-rw-r--r--tensorflow/core/framework/shape_inference.cc9
-rw-r--r--tensorflow/core/framework/shape_inference.h9
-rw-r--r--tensorflow/core/graph/graph.cc13
-rw-r--r--tensorflow/core/graph/graph.h5
-rw-r--r--tensorflow/core/graph/node_builder.cc8
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc3
7 files changed, 44 insertions, 8 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index fa4d1eda62..9488a44778 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -288,6 +288,11 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
"output_port '", output_port, "' is out of range, ", "node '",
node->name(), "' has ", node->num_outputs(), " outputs");
}
+ // Note: it's possible, if the node's been updated, that the shape inference
+ // context doesn't have the right number of outputs.
+ if (node->num_outputs() > c->num_outputs()) {
+ TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
+ }
// Check compatibility, and merge the shapes.
ShapeHandle existing_shape = c->output(output_port);
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3e77028a5f..4dcc80680f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,6 +239,15 @@ void InferenceContext::PreInputInit(
output_handle_shapes_and_types_.resize(num_outputs);
}
+Status InferenceContext::ExpandOutputs(int new_output_size) {
+ if (new_output_size < outputs_.size()) {
+ return errors::InvalidArgument("Trying to reduce number of outputs of op.");
+ }
+ outputs_.resize(new_output_size, nullptr);
+ output_handle_shapes_and_types_.resize(new_output_size);
+ return Status::OK();
+}
+
void InferenceContext::PostInputInit(
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
int num_inputs_from_node_def = 0;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 81258b55b3..e3885b7d9e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -323,13 +323,13 @@ class InferenceContext {
return input_tensors_as_shapes_;
}
- ShapeHandle output(int64 idx) const { return outputs_[idx]; }
- void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
+ ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
+ void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
Status set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes);
int num_outputs() const { return outputs_.size(); }
- ShapeHandle output(int idx) const { return outputs_[idx]; }
+ ShapeHandle output(int idx) const { return outputs_.at(idx); }
Status output(StringPiece output_name,
std::vector<ShapeHandle>* output) const;
@@ -645,6 +645,9 @@ class InferenceContext {
return merged_dims_;
}
+ // Adds new outputs; useful when mutating the graph.
+ Status ExpandOutputs(int new_output_size);
+
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 7a4a0096fa..6f068546d2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -142,6 +142,19 @@ void Node::Clear() {
assigned_device_name_index_ = 0;
}
+void Node::UpdateProperties() {
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ Status status =
+ InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed at updating node: " << status;
+ return;
+ }
+ props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
+ inputs, outputs);
+}
+
const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 2944951f82..228b1331d9 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -171,6 +171,7 @@ class Node {
template <typename T>
void AddAttr(const string& name, const T& val) {
SetAttrValue(val, AddAttrHelper(name));
+ UpdateProperties();
}
void ClearAttr(const string& name);
@@ -211,6 +212,10 @@ class Node {
// e.g. in AddAttr.
void MaybeCopyOnWrite();
+ // Called after an attr has changed. Decides whether we need to update some
+ // property of the node (stored in props_).
+ void UpdateProperties();
+
AttrValue* AddAttrHelper(const string& name);
// A set of mutually exclusive classes for different kinds of nodes,
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index d92874909f..68a20fcc5f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) {
strings::StrCat("Attempt to add nullptr Node to node with type ",
def_builder_.op_def().name()));
} else {
- errors_.emplace_back(
- strings::StrCat("Attempt to add output ", i, " of ", node->name(),
- " not in range [0, ", node->num_outputs(),
- ") to node with type ", def_builder_.op_def().name()));
+ errors_.emplace_back(strings::StrCat(
+ "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
+ node->num_outputs(), ") to node with type ",
+ def_builder_.op_def().name(), ". Node: ", node->DebugString()));
}
}
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index adc9cd1486..65bdde375b 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -216,7 +216,8 @@ REGISTER_OP("VarIsInitializedOp")
Status VariableShapeShapeFn(InferenceContext* c) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr || handle_data->empty()) {
- return errors::InvalidArgument("Handle doesn't have shape information.");
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
}
ShapeHandle var_shape = (*handle_data)[0].shape;
int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)