diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/remapper.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/remapper.cc | 87 |
1 files changed, 73 insertions, 14 deletions
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 2a62871293..efd870b118 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -28,10 +28,71 @@ namespace grappler { void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) { const string& x = fused_node.input(0); - const string& scale = fused_node.input(1); - const string& offset = fused_node.input(2); - const string& mean = fused_node.input(3); - const string& variance = fused_node.input(4); + string scale = fused_node.input(1); + string offset = fused_node.input(2); + string mean = fused_node.input(3); + string variance = fused_node.input(4); + + if (fused_node.attr().at("data_format").s() == "NCHW") { + // Need to reshape the last 4 inputs + NodeDef* new_shape = optimized_graph->add_node(); + new_shape->set_name(AddPrefixToNodeName("NCHWShape", fused_node.name())); + new_shape->set_op("Const"); + new_shape->set_device(fused_node.device()); + *new_shape->add_input() = AsControlDependency(scale); + (*new_shape->mutable_attr())["dtype"].set_type(DT_INT32); + Tensor t(DT_INT32, {4}); + t.flat<int32>()(0) = 1; + t.flat<int32>()(1) = -1; + t.flat<int32>()(2) = 1; + t.flat<int32>()(3) = 1; + t.AsProtoTensorContent( + (*new_shape->mutable_attr())["value"].mutable_tensor()); + + NodeDef* reshaped_scale = optimized_graph->add_node(); + reshaped_scale->set_name( + AddPrefixToNodeName("NCHWShapedScale", fused_node.name())); + reshaped_scale->set_op("Reshape"); + reshaped_scale->set_device(fused_node.device()); + *reshaped_scale->add_input() = scale; + *reshaped_scale->add_input() = new_shape->name(); + (*reshaped_scale->mutable_attr())["T"] = fused_node.attr().at("T"); + (*reshaped_scale->mutable_attr())["Tshape"].set_type(DT_INT32); + scale = reshaped_scale->name(); + + NodeDef* reshaped_offset = optimized_graph->add_node(); + reshaped_offset->set_name( + AddPrefixToNodeName("NCHWShapedOffset", fused_node.name())); + reshaped_offset->set_op("Reshape"); + reshaped_offset->set_device(fused_node.device()); + *reshaped_offset->add_input() = offset; + *reshaped_offset->add_input() = new_shape->name(); + (*reshaped_offset->mutable_attr())["T"] = fused_node.attr().at("T"); + (*reshaped_offset->mutable_attr())["Tshape"].set_type(DT_INT32); + offset = reshaped_offset->name(); + + NodeDef* reshaped_mean = optimized_graph->add_node(); + reshaped_mean->set_name( + AddPrefixToNodeName("NCHWShapedMean", fused_node.name())); + reshaped_mean->set_op("Reshape"); + reshaped_mean->set_device(fused_node.device()); + *reshaped_mean->add_input() = mean; + *reshaped_mean->add_input() = new_shape->name(); + (*reshaped_mean->mutable_attr())["T"] = fused_node.attr().at("T"); + (*reshaped_mean->mutable_attr())["Tshape"].set_type(DT_INT32); + mean = reshaped_mean->name(); + + NodeDef* reshaped_variance = optimized_graph->add_node(); + reshaped_variance->set_name( + AddPrefixToNodeName("NCHWShapedVariance", fused_node.name())); + reshaped_variance->set_op("Reshape"); + reshaped_variance->set_device(fused_node.device()); + *reshaped_variance->add_input() = variance; + *reshaped_variance->add_input() = new_shape->name(); + (*reshaped_variance->mutable_attr())["T"] = fused_node.attr().at("T"); + (*reshaped_variance->mutable_attr())["Tshape"].set_type(DT_INT32); + variance = reshaped_variance->name(); + } float epsilon = 0.0f; if (fused_node.attr().count("epsilon")) { @@ -118,20 +179,16 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, optimizable &= (node.attr().count("is_training") == 0 || !node.attr().at("is_training").b()); if (optimizable) { - std::unordered_set<int> const_inputs; - for (const string& input : node.input()) { - int pos; - const string input_node = ParseNodeName(input, &pos); - if (properties.HasInputProperties(input_node)) { - const auto& props = properties.GetInputProperties(input_node); - if (props.size() > pos && props[pos].has_value()) { - const_inputs.insert(pos); - } + int const_inputs = 0; + const auto& props = properties.GetInputProperties(node.name()); + for (const auto& prop : props) { + if (prop.has_value()) { + const_inputs += 1; } } // TODO(bsteiner): use the cost model to compare the cost of fused batch // norm against that of the optimized form. - optimizable = (const_inputs.size() >= 4); + optimizable = (const_inputs >= 4); } if (optimizable) { for (GraphView::Edge edge : graph.GetFanoutEdges(node, false)) { @@ -143,6 +200,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, } } if (optimizable) { + std::cout << "Optimizing fused batch norm node " << node.DebugString() + << std::endl; AddBatchNormNodes(optimized_graph, node); continue; } |