diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/constant_folding.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 172 |
1 files changed, 11 insertions, 161 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 02a732b092..cb02314183 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/public/version.h" -#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace grappler { @@ -96,15 +95,11 @@ class DeviceSimple : public DeviceBase { }; } // namespace -ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level, - DeviceBase* cpu_device) - : opt_level_(opt_level), cpu_device_(cpu_device) { +ConstantFolding::ConstantFolding(DeviceBase* cpu_device) + : cpu_device_(cpu_device) { resource_mgr_.reset(new ResourceMgr()); } -ConstantFolding::ConstantFolding(DeviceBase* cpu_device) - : ConstantFolding(RewriterConfig::ON, cpu_device) {} - // static string ConstantFolding::AddControlDependency(const string& input_name, GraphDef* graph, @@ -286,149 +281,6 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, return Status::OK(); } -bool ShapesEqual(const TensorShapeProto& shape1, - const TensorShapeProto& shape2) { - if (shape1.unknown_rank() || shape2.unknown_rank()) { - return false; - } - if (shape1.dim_size() != shape2.dim_size()) { - return false; - } - for (int i = 0; i < shape1.dim_size(); ++i) { - if (shape1.dim(i).size() != shape2.dim(i).size()) { - return false; - } - } - return true; -} - -namespace { -bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties, - BCast::Vec* shape, int64* min_id) { - if (shape_node.op() == "Shape") { - const std::vector<OpInfo::TensorProperties>& prop1 = - properties.GetInputProperties(shape_node.name()); - if (prop1.size() != 1) { - return false; - } - const TensorShapeProto& shp = prop1[0].shape(); - if (shp.unknown_rank()) { - return false; - } - for (const auto& dim : shp.dim()) { - shape->push_back(dim.size()); - *min_id = std::min<int64>(*min_id, dim.size()); - } - } else { - const TensorProto& raw_val = shape_node.attr().at("value").tensor(); - if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) { - return false; - } - Tensor value(raw_val.dtype(), raw_val.tensor_shape()); - if (!value.FromProto(raw_val)) { - return false; - } - for (int j = 0; j < value.NumElements(); ++j) { - if (raw_val.dtype() == DT_INT64) { - shape->push_back(value.vec<int64>()(j)); - } else { - shape->push_back(value.vec<int>()(j)); - } - } - } - return true; -} -} // namespace - -Status ConstantFolding::MaterializeConstants( - const GrapplerItem& item, const GraphProperties& properties) { - const int node_count = graph_.node_size(); - for (int i = 0; i < node_count; ++i) { - NodeDef& node = *graph_.mutable_node(i); - const string& op = node.op(); - if (op != "BroadcastGradientArgs") { - continue; - } - const NodeDef* shape_node1 = node_map_->GetNode(node.input(0)); - const NodeDef* shape_node2 = node_map_->GetNode(node.input(1)); - if (shape_node1 == nullptr || - (shape_node1->op() != "Shape" && shape_node1->op() != "Const") || - shape_node2 == nullptr || - (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) { - continue; - } - int64 min_id = 0; - BCast::Vec shape1; - if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) { - continue; - } - BCast::Vec shape2; - if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) { - continue; - } - // A value of -1 means we don't known anything about the dimension. Replace - // the -1 values with unique dimension ids since we don't want two '-1' - // dimensions to be considered equal. - for (auto& id : shape1) { - if (id == -1) { - id = --min_id; - } - } - for (auto& id : shape2) { - if (id == -1) { - id = --min_id; - } - } - BCast bcast(shape1, shape2); - if (!bcast.IsValid()) { - continue; - } - BCast::Vec reduce_dims[2]; - reduce_dims[0] = bcast.grad_x_reduce_idx(); - reduce_dims[1] = bcast.grad_y_reduce_idx(); - - const DataType type = node.attr().at("T").type(); - NodeDef* out[2]; - for (int j = 0; j < 2; ++j) { - if (!reduce_dims[j].empty()) { - // This is the case when a tensor dimension 1 is matched against an - // unknown dimension. The unknown dimension could also be equal to 1, in - // which case there would be no reduction. - out[j] = nullptr; - } else { - Tensor value(type, TensorShape({0})); - string const_name = AddPrefixToNodeName( - strings::StrCat(node.name(), "-", j), kConstantFoldingConst); - out[j] = node_map_->GetNode(const_name); - if (!out[j]) { - out[j] = graph_.add_node(); - *out[j] = CreateNodeDef(const_name, TensorValue(&value)); - out[j]->set_device(node.device()); - node_map_->AddNode(const_name, out[j]); - string ctrl_dep = - AddControlDependency(node.name(), &graph_, node_map_.get()); - *out[j]->add_input() = ctrl_dep; - node_map_->AddOutput(NodeName(ctrl_dep), const_name); - } - } - } - - auto outputs = node_map_->GetOutputs(node.name()); - for (const auto& output : outputs) { - for (int k = 0; k < output->input_size(); ++k) { - int port; - string node_name = ParseNodeName(output->input(k), &port); - if (node_name == node.name() && port >= 0 && port < 2 && out[port]) { - *output->mutable_input(k) = out[port]->name(); - node_map_->UpdateInput(output->name(), node_name, out[port]->name()); - } - } - } - } - - return Status::OK(); -} - bool ConstantFolding::IsFoldable(const NodeDef& node) const { // Folding not applicable to ops with no inputs. if (node.input().empty()) { @@ -1069,23 +921,23 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, } GraphProperties properties(item); - Status s = properties.InferStatically(); bool has_feed = !item.feed.empty(); - - if (!has_feed && s.ok()) { + if (!has_feed) { // Only use static shape information when there is no feed in the // graph. That's because it's possible to feed a placeholder with a tensor // of any shape, which could make the static information inconsistent with // the shapes actually fed. - TF_RETURN_IF_ERROR(MaterializeShapes(item, properties)); - } - if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) { - TF_RETURN_IF_ERROR(MaterializeConstants(item, properties)); + Status s = properties.InferStatically(); + if (!s.ok()) { + VLOG(1) << "Failed to infer graph shapes: " << s; + } else { + TF_RETURN_IF_ERROR(MaterializeShapes(item, properties)); + } } TF_RETURN_IF_ERROR(FoldGraph(output)); - if (!has_feed && s.ok()) { + if (!has_feed) { TF_RETURN_IF_ERROR(SimplifyGraph(output, properties)); } return Status::OK(); @@ -1104,14 +956,12 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GrapplerItem item_to_optimize = item; *output = item.graph; - int64 node_count; do { graph_.Swap(output); item_to_optimize.graph = graph_; *output = GraphDef(); - node_count = graph_.node_size(); TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output)); - } while (output->node_size() != node_count); + } while (output->node_size() < graph_.node_size()); *output->mutable_library() = item.graph.library(); *output->mutable_versions() = item.graph.versions(); |