| author | Max Galkin <maxgalkin@google.com> | 2018-03-12 18:35:15 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-12 18:39:38 -0700 |
| commit | f5efe97603855c517795e3fe9fc6364b59502d8a | |
| tree | 28485d121096d7ab8029f1f1aa2ebb899b210ef4 | |
| parent | a2643a983694a91ef0027650bc0ce28f2f760067 | |
Demystify MaterializeShapes a bit.
PiperOrigin-RevId: 188812445
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 230 |
1 file changed, 123 insertions(+), 107 deletions(-)
```diff
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 4c9431deac..a4d8376667 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -244,44 +244,41 @@ string ConstantFolding::AddControlDependency(const string& input_name,
   }
 }
 
-Status ConvertShapeToConstant(const string& op, const DataType& type,
-                              const PartialTensorShape& shp, Tensor* value) {
+// Puts the given value into the tensor at the given "flat" index.
+static Status PutValueIntoTensor(const int64 value, const DataType& type,
+                                 const int index, Tensor* tensor) {
+  if (type == DT_INT32) {
+    if (value >= INT_MAX) {
+      return Status(error::INVALID_ARGUMENT, "int32 overflow");
+    }
+    tensor->flat<int32>()(index) = static_cast<int32>(value);
+  } else {
+    tensor->flat<int64>()(index) = value;
+  }
+  return Status::OK();
+}
+
+// Writes the given tensor shape into the given tensor.
+// Op is assumed to be Shape, ShapeN, Size or Rank.
+static Status ConvertShapeToConstant(const string& op, const DataType& type,
+                                     const PartialTensorShape& shp,
+                                     Tensor* tensor) {
   if (op == "Shape" || op == "ShapeN") {
-    *value = Tensor(type, TensorShape({shp.dims()}));
+    *tensor = Tensor(type, TensorShape({shp.dims()}));
     for (int i = 0; i < shp.dims(); ++i) {
-      if (type == DT_INT32) {
-        if (shp.dim_size(i) >= INT_MAX) {
-          return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-        }
-        value->flat<int32>()(i) = shp.dim_size(i);
-      } else {
-        value->flat<int64>()(i) = shp.dim_size(i);
-      }
+      TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
     }
   } else if (op == "Size") {
     int64 size = 1;
     for (int i = 0; i < shp.dims(); ++i) {
       size *= shp.dim_size(i);
     }
-    *value = Tensor(type, TensorShape({}));
-    if (type == DT_INT32) {
-      if (size >= INT_MAX) {
-        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-      }
-      value->flat<int32>()(0) = size;
-    } else {
-      value->flat<int64>()(0) = size;
-    }
+    *tensor = Tensor(type, TensorShape({}));
+    TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
   } else {
-    *value = Tensor(type, TensorShape({}));
-    if (type == DT_INT32) {
-      if (shp.dims() >= INT_MAX) {
-        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-      }
-      value->flat<int32>()(0) = shp.dims();
-    } else {
-      value->flat<int64>()(0) = shp.dims();
-    }
+    CHECK_EQ(op, "Rank");
+    *tensor = Tensor(type, TensorShape({}));
+    TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
   }
   return Status::OK();
 }
@@ -306,13 +303,14 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
   return feed_nodes_.find(node.name()) == feed_nodes_.end();
 }
 
+// Materialize the shapes using constants whenever possible.
 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
-  // We may add some nodes to the graph to encode control dependencies: there is
-  // no need to process these, so only iterate over the nodes of the input
-  // graph.
+  // We may add some nodes to the graph to encode control dependencies and hold
+  // the materialized shapes: there is no need to process these added nodes, so
+  // only iterate over the nodes of the input graph.
   const int node_count = graph_->node_size();
-  for (int i = 0; i < node_count; ++i) {
-    NodeDef* node = graph_->mutable_node(i);
+  for (int node_idx = 0; node_idx < node_count; ++node_idx) {
+    NodeDef* node = graph_->mutable_node(node_idx);
     const string op = node->op();
     if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
       continue;
@@ -325,91 +323,109 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
     if (input.empty() || output.empty()) {
       continue;
     }
+
     if (op == "Shape" || op == "Size" || op == "Rank") {
       CHECK_EQ(1, output.size());
       CHECK_EQ(1, input.size());
+
+      const DataType type = output[0].dtype();
+      CHECK(type == DT_INT32 || type == DT_INT64);
+      const PartialTensorShape shape(input[0].shape());
+
+      if ((op != "Rank" && !shape.IsFullyDefined()) ||
+          (op == "Rank" && shape.unknown_rank())) {
+        continue;
+      }
+
+      Tensor constant_value(type);
+      if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
+        continue;
+      }
+
+      // Repurpose the existing node to be the constant.
+      // Device placement is preserved.
+      node->set_op("Const");
+      node->clear_attr();
+      (*node->mutable_attr())["dtype"].set_type(type);
+      constant_value.AsProtoTensorContent(
+          (*node->mutable_attr())["value"].mutable_tensor());
+
+      // Turn the data input into a control dependency: this is needed to
+      // ensure that the constant value will only be run in the
+      // cases where the shape/rank/size would have been run in
+      // the original graph.
+      string ctrl_dep =
+          AddControlDependency(node->input(0), graph_, node_map_.get());
+      node->set_input(0, ctrl_dep);
+      node_map_->AddOutput(NodeName(ctrl_dep), node->name());
+
+      // Done with the Shape/Size/Rank node, move to the next node.
+      continue;
     }
-    CHECK_EQ(input.size(), output.size());
-    for (int j = 0; j < output.size(); ++j) {
-      const DataType type = output[j].dtype();
+
+    // Handle ShapeN materialization case.
+    // It's possible that not all input tensors have known shapes.
+    CHECK_EQ(op, "ShapeN");
+    CHECK_EQ(input.size(), output.size());
+    const NodeDef* const shape_n_node = node;
+    for (int port_idx = 0; port_idx < output.size(); ++port_idx) {
+      const DataType type = output[port_idx].dtype();
       CHECK(type == DT_INT32 || type == DT_INT64);
-      const TensorShapeProto shape = input[j].shape();
-      // Materialize the shapes using constants whenever possible.
-      PartialTensorShape shp(shape);
-      if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
-        Tensor value(type);
-        auto status = ConvertShapeToConstant(op, type, shp, &value);
-        if (!status.ok()) {
-          continue;
-        }
-        // We rewrite the existing node for the first const output and
-        // create new nodes for the remaining const outputs (Note that ShapeN
-        // could have multiple outputs).
-        if (op == "Shape" || op == "Size" || op == "Rank") {
-          // Replace the node with the corresponding constant.
-          node->set_op("Const");
-          node->clear_attr();
-          (*node->mutable_attr())["dtype"].set_type(type);
-          value.AsProtoTensorContent(
-              (*node->mutable_attr())["value"].mutable_tensor());
-
-          // Turn the data input into a control dependency: this is needed to
-          // ensure that the constant value will only be run in the
-          // cases where the shape/rank/size would have been run in
-          // the original graph. Additional inputs are extra control
-          string ctrl_dep =
-              AddControlDependency(node->input(0), graph_, node_map_.get());
-          node->set_input(0, ctrl_dep);
-          node_map_->AddOutput(NodeName(ctrl_dep), node->name());
-        } else {
-          auto outputs = node_map_->GetOutputs(node->name());
-          for (NodeDef* output : outputs) {
-            for (int k = 0; k < output->input_size(); ++k) {
-              int port;
-              string node_name = ParseNodeName(output->input(k), &port);
-              if (node_name == node->name() && port == j) {
-                // Create a const node as ShapeN's output if not already.
-                const string const_name =
-                    OptimizedNodeName(*node, strings::StrCat("-matshapes-", j));
-                if (node_map_->GetNode(const_name) == nullptr) {
-                  NodeDef* added_node = graph_->add_node();
-                  added_node->set_name(const_name);
-                  added_node->set_op("Const");
-                  added_node->set_device(node->device());
-                  node_map_->AddNode(added_node->name(), added_node);
-                  (*added_node->mutable_attr())["dtype"].set_type(type);
-                  value.AsProtoTensorContent(
-                      (*added_node->mutable_attr())["value"].mutable_tensor());
-                  // We add a control dependency to the original ShapeN node,
-                  // so that the node will only be run if all inputs of the
-                  // original ShapeN node are run.
-                  string ctrl_dep = AddControlDependency(node->name(), graph_,
-                                                         node_map_.get());
-                  *added_node->add_input() = ctrl_dep;
-                  node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
-                }
-                *output->mutable_input(k) = const_name;
-                node_map_->AddOutput(const_name, output->name());
-              }
-            }
-            bool remove_output = true;
-            for (int k = 0; k < output->input_size(); ++k) {
-              int port;
-              string node_name = ParseNodeName(output->input(k), &port);
-              if (node_name == node->name()) {
-                remove_output = false;
-                break;
-              }
-            }
-            if (remove_output) {
-              node_map_->RemoveOutput(node->name(), output->name());
+      const PartialTensorShape shape(input[port_idx].shape());
+      if (!shape.IsFullyDefined()) {
+        continue;
+      }
+      Tensor constant_value(type);
+      auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
+      if (!status.ok()) {
+        continue;
+      }
+
+      // Find all nodes consuming this shape and connect them through the new
+      // constant node instead.
+      auto outputs = node_map_->GetOutputs(shape_n_node->name());
+      for (NodeDef* output : outputs) {
+        // Track whether there are any direct edges left between shape_n_node
+        // and this output node after the transformation.
+        bool direct_edges_exist = false;
+        for (int k = 0; k < output->input_size(); ++k) {
+          int port;
+          const string node_name = ParseNodeName(output->input(k), &port);
+          if (node_name == shape_n_node->name() && port == port_idx) {
+            // Create a const node as ShapeN's output if not already.
+            const string const_name = OptimizedNodeName(
+                *shape_n_node, strings::StrCat("-matshapes-", port_idx));
+            if (node_map_->GetNode(const_name) == nullptr) {
+              NodeDef* added_node = graph_->add_node();
+              added_node->set_name(const_name);
+              added_node->set_op("Const");
+              added_node->set_device(shape_n_node->device());
+              node_map_->AddNode(added_node->name(), added_node);
+              (*added_node->mutable_attr())["dtype"].set_type(type);
+              constant_value.AsProtoTensorContent(
+                  (*added_node->mutable_attr())["value"].mutable_tensor());
+              // We add a control dependency to the original ShapeN node,
+              // so that the node will only be run if all inputs of the
+              // original ShapeN node are run.
+              string ctrl_dep = AddControlDependency(shape_n_node->name(),
                                                     graph_, node_map_.get());
+              *added_node->add_input() = ctrl_dep;
+              node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
             }
+            *output->mutable_input(k) = const_name;
+            node_map_->AddOutput(const_name, output->name());
           }
+          if (node_name == shape_n_node->name() && port != port_idx) {
+            direct_edges_exist = true;
+          }
+        }
+        if (!direct_edges_exist) {
+          node_map_->RemoveOutput(node->name(), output->name());
         }
       }
     }
   }
+
   return Status::OK();
 }
```
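The heart of the cleanup is the new `PutValueIntoTensor` helper, which centralizes the int32 overflow guard that was previously copy-pasted across the Shape, Size, and Rank branches of `ConvertShapeToConstant`. The sketch below is a minimal standalone illustration of that guard, not TensorFlow code: `Status` here is a hypothetical stand-in struct, and `PutValueAsInt32` is an invented name for the int32 path of the helper.

```cpp
#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for tensorflow::Status, just enough to show the flow.
struct Status {
  bool ok;
  const char* message;
};

// Mirrors the guard in PutValueIntoTensor: shape dimensions are int64 in
// TensorFlow, so a value only fits into a DT_INT32 constant if it is below
// INT_MAX; otherwise the caller must give up on folding this node.
Status PutValueAsInt32(int64_t value, int index, std::vector<int32_t>* flat) {
  if (value >= INT_MAX) {
    return {false, "int32 overflow"};
  }
  (*flat)[index] = static_cast<int32_t>(value);
  return {true, ""};
}

int main() {
  std::vector<int32_t> tensor(2, 0);
  // An ordinary dimension size is written through unchanged.
  std::cout << PutValueAsInt32(128, 0, &tensor).ok << "\n";  // prints 1
  // A dimension too large for int32 is rejected; in the optimizer this makes
  // ConvertShapeToConstant return early, so the node is simply left unfolded.
  std::cout << PutValueAsInt32(int64_t{1} << 40, 1, &tensor).ok << "\n";  // 0
}
```

Centralizing the check means Shape, ShapeN, Size, and Rank now reject un-representable values uniformly through `TF_RETURN_IF_ERROR`, instead of each branch re-implementing the same bound test.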