aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2018-03-12 18:35:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-12 18:39:38 -0700
commitf5efe97603855c517795e3fe9fc6364b59502d8a (patch)
tree28485d121096d7ab8029f1f1aa2ebb899b210ef4
parenta2643a983694a91ef0027650bc0ce28f2f760067 (diff)
Demystify MaterializeShapes a bit.
PiperOrigin-RevId: 188812445
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc230
1 files changed, 123 insertions, 107 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 4c9431deac..a4d8376667 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -244,44 +244,41 @@ string ConstantFolding::AddControlDependency(const string& input_name,
}
}
-Status ConvertShapeToConstant(const string& op, const DataType& type,
- const PartialTensorShape& shp, Tensor* value) {
+// Puts the given value into the tensor at the given "flat" index.
+static Status PutValueIntoTensor(const int64 value, const DataType& type,
+ const int index, Tensor* tensor) {
+ if (type == DT_INT32) {
+ if (value >= INT_MAX) {
+ return Status(error::INVALID_ARGUMENT, "int32 overflow");
+ }
+ tensor->flat<int32>()(index) = static_cast<int32>(value);
+ } else {
+ tensor->flat<int64>()(index) = value;
+ }
+ return Status::OK();
+}
+
+// Writes the given tensor shape into the given tensor.
+// Op is assumed to be Shape, ShapeN, Size or Rank.
+static Status ConvertShapeToConstant(const string& op, const DataType& type,
+ const PartialTensorShape& shp,
+ Tensor* tensor) {
if (op == "Shape" || op == "ShapeN") {
- *value = Tensor(type, TensorShape({shp.dims()}));
+ *tensor = Tensor(type, TensorShape({shp.dims()}));
for (int i = 0; i < shp.dims(); ++i) {
- if (type == DT_INT32) {
- if (shp.dim_size(i) >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(i) = shp.dim_size(i);
- } else {
- value->flat<int64>()(i) = shp.dim_size(i);
- }
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
}
} else if (op == "Size") {
int64 size = 1;
for (int i = 0; i < shp.dims(); ++i) {
size *= shp.dim_size(i);
}
- *value = Tensor(type, TensorShape({}));
- if (type == DT_INT32) {
- if (size >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(0) = size;
- } else {
- value->flat<int64>()(0) = size;
- }
+ *tensor = Tensor(type, TensorShape({}));
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
} else {
- *value = Tensor(type, TensorShape({}));
- if (type == DT_INT32) {
- if (shp.dims() >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(0) = shp.dims();
- } else {
- value->flat<int64>()(0) = shp.dims();
- }
+ CHECK_EQ(op, "Rank");
+ *tensor = Tensor(type, TensorShape({}));
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
}
return Status::OK();
}
@@ -306,13 +303,14 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
return feed_nodes_.find(node.name()) == feed_nodes_.end();
}
+// Materialize the shapes using constants whenever possible.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
- // We may add some nodes to the graph to encode control dependencies: there is
- // no need to process these, so only iterate over the nodes of the input
- // graph.
+ // We may add some nodes to the graph to encode control dependencies and hold
+ // the materialized shapes: there is no need to process these added nodes, so
+ // only iterate over the nodes of the input graph.
const int node_count = graph_->node_size();
- for (int i = 0; i < node_count; ++i) {
- NodeDef* node = graph_->mutable_node(i);
+ for (int node_idx = 0; node_idx < node_count; ++node_idx) {
+ NodeDef* node = graph_->mutable_node(node_idx);
const string op = node->op();
if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
continue;
@@ -325,91 +323,109 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
if (input.empty() || output.empty()) {
continue;
}
+
if (op == "Shape" || op == "Size" || op == "Rank") {
CHECK_EQ(1, output.size());
CHECK_EQ(1, input.size());
+
+ const DataType type = output[0].dtype();
+ CHECK(type == DT_INT32 || type == DT_INT64);
+ const PartialTensorShape shape(input[0].shape());
+
+ if ((op != "Rank" && !shape.IsFullyDefined()) ||
+ (op == "Rank" && shape.unknown_rank())) {
+ continue;
+ }
+
+ Tensor constant_value(type);
+ if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
+ continue;
+ }
+
+ // Repurpose the existing node to be the constant.
+ // Device placement is preserved.
+ node->set_op("Const");
+ node->clear_attr();
+ (*node->mutable_attr())["dtype"].set_type(type);
+ constant_value.AsProtoTensorContent(
+ (*node->mutable_attr())["value"].mutable_tensor());
+
+ // Turn the data input into a control dependency: this is needed to
+ // ensure that the constant value will only be run in the
+ // cases where the shape/rank/size would have been run in
+ // the original graph.
+ string ctrl_dep =
+ AddControlDependency(node->input(0), graph_, node_map_.get());
+ node->set_input(0, ctrl_dep);
+ node_map_->AddOutput(NodeName(ctrl_dep), node->name());
+
+ // Done with the Shape/Size/Rank node, move to the next node.
+ continue;
}
- CHECK_EQ(input.size(), output.size());
- for (int j = 0; j < output.size(); ++j) {
- const DataType type = output[j].dtype();
+ // Handle ShapeN materialization case.
+ // It's possible that not all input tensors have known shapes.
+ CHECK_EQ(op, "ShapeN");
+ CHECK_EQ(input.size(), output.size());
+ const NodeDef* const shape_n_node = node;
+ for (int port_idx = 0; port_idx < output.size(); ++port_idx) {
+ const DataType type = output[port_idx].dtype();
CHECK(type == DT_INT32 || type == DT_INT64);
- const TensorShapeProto shape = input[j].shape();
- // Materialize the shapes using constants whenever possible.
- PartialTensorShape shp(shape);
- if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
- Tensor value(type);
- auto status = ConvertShapeToConstant(op, type, shp, &value);
- if (!status.ok()) {
- continue;
- }
- // We rewrite the existing node for the first const output and
- // create new nodes for the remaining const outputs (Note that ShapeN
- // could have multiple outputs).
- if (op == "Shape" || op == "Size" || op == "Rank") {
- // Replace the node with the corresponding constant.
- node->set_op("Const");
- node->clear_attr();
- (*node->mutable_attr())["dtype"].set_type(type);
- value.AsProtoTensorContent(
- (*node->mutable_attr())["value"].mutable_tensor());
-
- // Turn the data input into a control dependency: this is needed to
- // ensure that the constant value will only be run in the
- // cases where the shape/rank/size would have been run in
- // the original graph. Additional inputs are extra control
- string ctrl_dep =
- AddControlDependency(node->input(0), graph_, node_map_.get());
- node->set_input(0, ctrl_dep);
- node_map_->AddOutput(NodeName(ctrl_dep), node->name());
- } else {
- auto outputs = node_map_->GetOutputs(node->name());
- for (NodeDef* output : outputs) {
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node->name() && port == j) {
- // Create a const node as ShapeN's output if not already.
- const string const_name =
- OptimizedNodeName(*node, strings::StrCat("-matshapes-", j));
- if (node_map_->GetNode(const_name) == nullptr) {
- NodeDef* added_node = graph_->add_node();
- added_node->set_name(const_name);
- added_node->set_op("Const");
- added_node->set_device(node->device());
- node_map_->AddNode(added_node->name(), added_node);
- (*added_node->mutable_attr())["dtype"].set_type(type);
- value.AsProtoTensorContent(
- (*added_node->mutable_attr())["value"].mutable_tensor());
- // We add a control dependency to the original ShapeN node,
- // so that the node will only be run if all inputs of the
- // original ShapeN node are run.
- string ctrl_dep = AddControlDependency(node->name(), graph_,
- node_map_.get());
- *added_node->add_input() = ctrl_dep;
- node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
- }
- *output->mutable_input(k) = const_name;
- node_map_->AddOutput(const_name, output->name());
- }
- }
- bool remove_output = true;
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node->name()) {
- remove_output = false;
- break;
- }
- }
- if (remove_output) {
- node_map_->RemoveOutput(node->name(), output->name());
+ const PartialTensorShape shape(input[port_idx].shape());
+ if (!shape.IsFullyDefined()) {
+ continue;
+ }
+ Tensor constant_value(type);
+ auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
+ if (!status.ok()) {
+ continue;
+ }
+
+ // Find all nodes consuming this shape and connect them through the new
+ // constant node instead.
+ auto outputs = node_map_->GetOutputs(shape_n_node->name());
+ for (NodeDef* output : outputs) {
+ // Track whether there are any direct edges left between shape_n_node
+ // and this output node after the transformation.
+ bool direct_edges_exist = false;
+ for (int k = 0; k < output->input_size(); ++k) {
+ int port;
+ const string node_name = ParseNodeName(output->input(k), &port);
+ if (node_name == shape_n_node->name() && port == port_idx) {
+ // Create a const node as ShapeN's output if not already.
+ const string const_name = OptimizedNodeName(
+ *shape_n_node, strings::StrCat("-matshapes-", port_idx));
+ if (node_map_->GetNode(const_name) == nullptr) {
+ NodeDef* added_node = graph_->add_node();
+ added_node->set_name(const_name);
+ added_node->set_op("Const");
+ added_node->set_device(shape_n_node->device());
+ node_map_->AddNode(added_node->name(), added_node);
+ (*added_node->mutable_attr())["dtype"].set_type(type);
+ constant_value.AsProtoTensorContent(
+ (*added_node->mutable_attr())["value"].mutable_tensor());
+ // We add a control dependency to the original ShapeN node,
+ // so that the node will only be run if all inputs of the
+ // original ShapeN node are run.
+ string ctrl_dep = AddControlDependency(shape_n_node->name(),
+ graph_, node_map_.get());
+ *added_node->add_input() = ctrl_dep;
+ node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
}
+ *output->mutable_input(k) = const_name;
+ node_map_->AddOutput(const_name, output->name());
}
+ if (node_name == shape_n_node->name() && port != port_idx) {
+ direct_edges_exist = true;
+ }
+ }
+ if (!direct_edges_exist) {
+ node_map_->RemoveOutput(node->name(), output->name());
}
}
}
}
+
return Status::OK();
}