about summary refs log tree commit diff homepage
path: root/tensorflow/core/grappler/optimizers/remapper.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/remapper.cc')
-rw-r--r-- tensorflow/core/grappler/optimizers/remapper.cc | 87
1 files changed, 73 insertions, 14 deletions
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 2a62871293..efd870b118 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -28,10 +28,71 @@ namespace grappler {
void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
const string& x = fused_node.input(0);
- const string& scale = fused_node.input(1);
- const string& offset = fused_node.input(2);
- const string& mean = fused_node.input(3);
- const string& variance = fused_node.input(4);
+ string scale = fused_node.input(1);
+ string offset = fused_node.input(2);
+ string mean = fused_node.input(3);
+ string variance = fused_node.input(4);
+
+ if (fused_node.attr().at("data_format").s() == "NCHW") {
+ // Need to reshape the last 4 inputs
+ NodeDef* new_shape = optimized_graph->add_node();
+ new_shape->set_name(AddPrefixToNodeName("NCHWShape", fused_node.name()));
+ new_shape->set_op("Const");
+ new_shape->set_device(fused_node.device());
+ *new_shape->add_input() = AsControlDependency(scale);
+ (*new_shape->mutable_attr())["dtype"].set_type(DT_INT32);
+ Tensor t(DT_INT32, {4});
+ t.flat<int32>()(0) = 1;
+ t.flat<int32>()(1) = -1;
+ t.flat<int32>()(2) = 1;
+ t.flat<int32>()(3) = 1;
+ t.AsProtoTensorContent(
+ (*new_shape->mutable_attr())["value"].mutable_tensor());
+
+ NodeDef* reshaped_scale = optimized_graph->add_node();
+ reshaped_scale->set_name(
+ AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
+ reshaped_scale->set_op("Reshape");
+ reshaped_scale->set_device(fused_node.device());
+ *reshaped_scale->add_input() = scale;
+ *reshaped_scale->add_input() = new_shape->name();
+ (*reshaped_scale->mutable_attr())["T"] = fused_node.attr().at("T");
+ (*reshaped_scale->mutable_attr())["Tshape"].set_type(DT_INT32);
+ scale = reshaped_scale->name();
+
+ NodeDef* reshaped_offset = optimized_graph->add_node();
+ reshaped_offset->set_name(
+ AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
+ reshaped_offset->set_op("Reshape");
+ reshaped_offset->set_device(fused_node.device());
+ *reshaped_offset->add_input() = offset;
+ *reshaped_offset->add_input() = new_shape->name();
+ (*reshaped_offset->mutable_attr())["T"] = fused_node.attr().at("T");
+ (*reshaped_offset->mutable_attr())["Tshape"].set_type(DT_INT32);
+ offset = reshaped_offset->name();
+
+ NodeDef* reshaped_mean = optimized_graph->add_node();
+ reshaped_mean->set_name(
+ AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
+ reshaped_mean->set_op("Reshape");
+ reshaped_mean->set_device(fused_node.device());
+ *reshaped_mean->add_input() = mean;
+ *reshaped_mean->add_input() = new_shape->name();
+ (*reshaped_mean->mutable_attr())["T"] = fused_node.attr().at("T");
+ (*reshaped_mean->mutable_attr())["Tshape"].set_type(DT_INT32);
+ mean = reshaped_mean->name();
+
+ NodeDef* reshaped_variance = optimized_graph->add_node();
+ reshaped_variance->set_name(
+ AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
+ reshaped_variance->set_op("Reshape");
+ reshaped_variance->set_device(fused_node.device());
+ *reshaped_variance->add_input() = variance;
+ *reshaped_variance->add_input() = new_shape->name();
+ (*reshaped_variance->mutable_attr())["T"] = fused_node.attr().at("T");
+ (*reshaped_variance->mutable_attr())["Tshape"].set_type(DT_INT32);
+ variance = reshaped_variance->name();
+ }
float epsilon = 0.0f;
if (fused_node.attr().count("epsilon")) {
@@ -118,20 +179,16 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
optimizable &= (node.attr().count("is_training") == 0 ||
!node.attr().at("is_training").b());
if (optimizable) {
- std::unordered_set<int> const_inputs;
- for (const string& input : node.input()) {
- int pos;
- const string input_node = ParseNodeName(input, &pos);
- if (properties.HasInputProperties(input_node)) {
- const auto& props = properties.GetInputProperties(input_node);
- if (props.size() > pos && props[pos].has_value()) {
- const_inputs.insert(pos);
- }
+ int const_inputs = 0;
+ const auto& props = properties.GetInputProperties(node.name());
+ for (const auto& prop : props) {
+ if (prop.has_value()) {
+ const_inputs += 1;
}
}
// TODO(bsteiner): use the cost model to compare the cost of fused batch
// norm against that of the optimized form.
- optimizable = (const_inputs.size() >= 4);
+ optimizable = (const_inputs >= 4);
}
if (optimizable) {
for (GraphView::Edge edge : graph.GetFanoutEdges(node, false)) {
@@ -143,6 +200,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
}
}
if (optimizable) {
+ std::cout << "Optimizing fused batch norm node " << node.DebugString()
+ << std::endl;
AddBatchNormNodes(optimized_graph, node);
continue;
}