path: root/tensorflow/tools/graph_transforms
author     Mingxing Tan <tanmingxing@google.com>  2017-09-25 20:20:16 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-09-25 20:24:19 -0700
commit     c4ee3bc1929ac672f76ed44bd25eeb7a5400fca5 (patch)
tree       00dd2c39c6b4df6147a413229aad90b29510dacc /tensorflow/tools/graph_transforms
parent     89ffbeaca0dcc69186b90d22d3282fc28db143c3 (diff)
Enable folding batch norm when inputs are concat of Conv2D.
PiperOrigin-RevId: 170001077
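
For context, the transform relies on the usual batch-norm folding identity: for a frozen batch norm with per-channel constants gamma, variance, mean, beta and epsilon, BN(Conv(x, W)) equals a Conv2D whose per-output-channel weights are multiplied by scale[c], followed by a BiasAdd of offset[c]. A minimal standalone sketch of that per-channel computation (the helper name and signature are illustrative, not part of this patch):

  #include <cmath>
  #include <vector>

  // Sketch of the scale/offset math that GetScaleAndOffsetValues performs
  // (scale_after_normalization case), detached from the NodeDef plumbing.
  void ComputeScaleAndOffset(const std::vector<float>& gamma,
                             const std::vector<float>& variance,
                             const std::vector<float>& mean,
                             const std::vector<float>& beta, float epsilon,
                             std::vector<float>* scale,
                             std::vector<float>* offset) {
    scale->resize(gamma.size());
    offset->resize(gamma.size());
    for (size_t c = 0; c < gamma.size(); ++c) {
      // scale[c] is folded into the conv weights for output channel c;
      // offset[c] becomes the bias added after the convolution.
      (*scale)[c] = gamma[c] / std::sqrt(variance[c] + epsilon);
      (*offset)[c] = beta[c] - mean[c] * (*scale)[c];
    }
  }

When the batch norm input is a concat of two Conv2Ds along the channel axis, this change splits the scale and offset vectors in half and fuses each half into the corresponding convolution's weights.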
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--  tensorflow/tools/graph_transforms/fold_old_batch_norms.cc       371
-rw-r--r--  tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc  109
2 files changed, 360 insertions, 120 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
index 0978c336b4..d89afe85c7 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
@@ -32,12 +32,216 @@ Status ErrorIfNotVector(const Tensor& input, const string& input_name,
int expected_width) {
if ((input.shape().dims() != 1) ||
(input.shape().dim_size(0) != expected_width)) {
- return errors::InvalidArgument(input_name,
- " input to batch norm has bad shape: ",
- input.shape().DebugString());
+ return errors::InvalidArgument(
+ input_name,
+ " input to batch norm has bad shape: ", input.shape().DebugString());
}
return Status::OK();
}
+
+Status GetScaleAndOffsetValues(const NodeMatch& match,
+ std::vector<float>* scale_values,
+ std::vector<float>* offset_values) {
+ // Find all the nodes we expect in the subgraph.
+ const NodeDef& batch_norm_node = match.node;
+ // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ
+ // by input order and attribute names.
+ CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" ||
+ batch_norm_node.op() == "FusedBatchNorm");
+ const bool is_fused = batch_norm_node.op() == "FusedBatchNorm";
+ const int mean_idx = is_fused ? 3 : 1;
+ const int var_idx = is_fused ? 4 : 2;
+ const int beta_idx = is_fused ? 2 : 3;
+ const int gamma_idx = is_fused ? 1 : 4;
+ const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon";
+ // FusedBatchNorm always scales after normalization.
+ const bool scale_after_normalization =
+ is_fused || batch_norm_node.attr().at("scale_after_normalization").b();
+
+ const NodeDef& mean_node = match.inputs[mean_idx].node;
+ CHECK_EQ("Const", mean_node.op());
+ const NodeDef& variance_node = match.inputs[var_idx].node;
+ CHECK_EQ("Const", variance_node.op());
+ const NodeDef& beta_node = match.inputs[beta_idx].node;
+ CHECK_EQ("Const", beta_node.op());
+ const NodeDef& gamma_node = match.inputs[gamma_idx].node;
+ CHECK_EQ("Const", gamma_node.op());
+
+ // We have a set of vectors that we want to combine into per-channel
+ // scale and offset vectors.
+ Tensor mean = GetNodeTensorAttr(mean_node, "value");
+ Tensor variance = GetNodeTensorAttr(variance_node, "value");
+ Tensor beta = GetNodeTensorAttr(beta_node, "value");
+ Tensor gamma = GetNodeTensorAttr(gamma_node, "value");
+ const float variance_epsilon = batch_norm_node.attr().at(epsilon_attr).f();
+
+ // Make sure all the inputs really are vectors with the same shape.
+ const int64 num_cols = mean.shape().dim_size(0);
+ TF_RETURN_IF_ERROR(ErrorIfNotVector(variance, "Variance", num_cols));
+ TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", num_cols));
+ TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "gamma", num_cols));
+
+ scale_values->resize(num_cols);
+ offset_values->resize(num_cols);
+
+ // Calculate the scale and offset values to apply.
+ if (scale_after_normalization) {
+ for (int i = 0; i < num_cols; ++i) {
+ (*scale_values)[i] =
+ (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) *
+ gamma.flat<float>()(i);
+ }
+ } else {
+ for (int i = 0; i < num_cols; ++i) {
+ (*scale_values)[i] =
+ (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon));
+ }
+ }
+ for (int i = 0; i < num_cols; ++i) {
+ (*offset_values)[i] =
+ (-mean.flat<float>()(i) * (*scale_values)[i]) + beta.flat<float>()(i);
+ }
+ return Status::OK();
+}
+
+Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values,
+ const std::vector<float>& offset_values,
+ const NodeMatch& conv_node_match,
+ const string& conv_output_name,
+ std::vector<NodeDef>* new_nodes) {
+ const NodeDef& conv_node = conv_node_match.node;
+ CHECK_EQ("Conv2D", conv_node.op());
+ const NodeDef& input_node = conv_node_match.inputs[0].node;
+ const NodeDef& weights_node = conv_node_match.inputs[1].node;
+ CHECK_EQ("Const", weights_node.op());
+
+ Tensor weights = GetNodeTensorAttr(weights_node, "value");
+ const int64 weights_cols = weights.shape().dim_size(3);
+ CHECK_EQ(weights_cols, scale_values.size());
+
+ // Multiply the original weights by the scale vector.
+ auto weights_matrix = weights.flat_inner_dims<float>();
+ Tensor scaled_weights(DT_FLOAT, weights.shape());
+ auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
+ for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
+ for (int64 col = 0; col < weights_cols; ++col) {
+ scaled_weights_matrix(row, col) =
+ weights_matrix(row, col) * scale_values[col];
+ }
+ }
+ // Figure out the remaining bias to add on.
+ Tensor bias_offset(DT_FLOAT, {weights_cols});
+ auto bias_offset_vector = bias_offset.flat<float>();
+ for (int64 col = 0; col < weights_cols; ++col) {
+ bias_offset_vector(col) = offset_values[col];
+ }
+
+ // Construct the new nodes.
+ NodeDef scaled_weights_node;
+ scaled_weights_node.set_op("Const");
+ scaled_weights_node.set_name(weights_node.name());
+ SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
+ SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);
+ new_nodes->push_back(scaled_weights_node);
+
+ // The input and convolution can be copied straight over, since the
+ // name of the scaled weights constant is the same as the original.
+ new_nodes->push_back(input_node);
+ new_nodes->push_back(conv_node);
+
+ NodeDef bias_offset_node;
+ bias_offset_node.set_op("Const");
+ bias_offset_node.set_name(conv_node.name() + "_bn_offset");
+ SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node);
+ SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node);
+ new_nodes->push_back(bias_offset_node);
+
+ NodeDef bias_add_node;
+ bias_add_node.set_op("BiasAdd");
+ bias_add_node.set_name(conv_output_name);
+ CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
+ AddNodeInput(conv_node.name(), &bias_add_node);
+ AddNodeInput(bias_offset_node.name(), &bias_add_node);
+ new_nodes->push_back(bias_add_node);
+ return Status::OK();
+}
+
+Status FuseBatchNormWithConv(const NodeMatch& match,
+ std::vector<NodeDef>* new_nodes) {
+ // Calculate the scale and offset values to apply.
+ std::vector<float> scale_values;
+ std::vector<float> offset_values;
+ TF_RETURN_IF_ERROR(
+ GetScaleAndOffsetValues(match, &scale_values, &offset_values));
+
+ // Fuse the conv weights, and name the final output node after batch_norm_node.
+ const NodeDef& batch_norm_node = match.node;
+ TF_RETURN_IF_ERROR(
+ FuseScaleOffsetToConvWeights(scale_values, offset_values, match.inputs[0],
+ batch_norm_node.name(), new_nodes));
+ return Status::OK();
+}
+
+Status FuseBatchNormWithConvConcat(const NodeMatch& match,
+ std::vector<NodeDef>* new_nodes) {
+ // Calculate the scale and offset values to apply.
+ std::vector<float> scale_values;
+ std::vector<float> offset_values;
+ TF_RETURN_IF_ERROR(
+ GetScaleAndOffsetValues(match, &scale_values, &offset_values));
+
+ // Find all the nodes we expect in the subgraph.
+ const NodeDef& batch_norm_node = match.node;
+ const NodeMatch& concat_node_match = match.inputs[0];
+ NodeDef concat_node = concat_node_match.node;
+ CHECK_EQ("ConcatV2", concat_node.op());
+
+ // First process the axis.
+ NodeDef axis_node = concat_node_match.inputs[2].node;
+ CHECK_EQ("Const", axis_node.op());
+ Tensor axis = GetNodeTensorAttr(axis_node, "value");
+ int32 axis_scalar = (axis.scalar<int32>())();
+
+ // By default, both conv0 and conv1 use the same scale and offset values.
+ std::vector<float> scale0(scale_values);
+ std::vector<float> offset0(offset_values);
+ std::vector<float> scale1(scale_values);
+ std::vector<float> offset1(offset_values);
+ if (axis_scalar == 3) {
+ // If axis is 3, the scale and offset are split into two halves.
+ const NodeDef& weights0_node = concat_node_match.inputs[0].inputs[1].node;
+ Tensor weights0 = GetNodeTensorAttr(weights0_node, "value");
+ const int64 split_cols = weights0.shape().dim_size(3);
+ // Only keep the first half for scale0/offset0.
+ scale0.erase(scale0.begin() + split_cols, scale0.end());
+ offset0.erase(offset0.begin() + split_cols, offset0.end());
+ // Only keep the second half for scale1/offset1.
+ scale1.erase(scale1.begin(), scale1.begin() + split_cols);
+ offset1.erase(offset1.begin(), offset1.begin() + split_cols);
+ }
+
+ // Fuse the scale and offset into the weights of the first Conv2D.
+ const string concat0_output_name = concat_node.name() + "_bn_in0";
+ TF_RETURN_IF_ERROR(
+ FuseScaleOffsetToConvWeights(scale0, offset0, concat_node_match.inputs[0],
+ concat0_output_name, new_nodes));
+
+ // Fuse the scale and offset into the weights of the second Conv2D.
+ const string concat1_output_name = concat_node.name() + "_bn_in1";
+ TF_RETURN_IF_ERROR(
+ FuseScaleOffsetToConvWeights(scale1, offset1, concat_node_match.inputs[1],
+ concat1_output_name, new_nodes));
+
+ // Push the axis (concat dim) node.
+ new_nodes->push_back(concat_node_match.inputs[2].node);
+
+ // Set the final output op name to batch_norm_node.
+ concat_node.set_name(batch_norm_node.name());
+ concat_node.set_input(0, concat0_output_name);
+ concat_node.set_input(1, concat1_output_name);
+ new_nodes->push_back(concat_node);
+ return Status::OK();
+}
} // namespace
// Finds monolithic batch norm ops (as used in early versions of TensorFlow) and
@@ -72,130 +276,57 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def,
const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
- // Find all the nodes we expect in the subgraph.
- const NodeDef& batch_norm_node = match.node;
- // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ
- // by input order and attribute names.
- CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" ||
- batch_norm_node.op() == "FusedBatchNorm");
- const bool is_fused = batch_norm_node.op() == "FusedBatchNorm";
- const int mean_idx = is_fused ? 3 : 1;
- const int var_idx = is_fused ? 4 : 2;
- const int beta_idx = is_fused ? 2 : 3;
- const int gamma_idx = is_fused ? 1 : 4;
- const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon";
- // FusedBatchNorm always scales after normalization.
- const bool scale_after_normalization =
- is_fused ||
- batch_norm_node.attr().at("scale_after_normalization").b();
-
- const NodeDef& conv_node = match.inputs[0].node;
- CHECK_EQ("Conv2D", conv_node.op());
- const NodeDef& input_node = match.inputs[0].inputs[0].node;
- const NodeDef& weights_node = match.inputs[0].inputs[1].node;
- CHECK_EQ("Const", weights_node.op());
- const NodeDef& mean_node = match.inputs[mean_idx].node;
- CHECK_EQ("Const", mean_node.op());
- const NodeDef& variance_node = match.inputs[var_idx].node;
- CHECK_EQ("Const", variance_node.op());
- const NodeDef& beta_node = match.inputs[beta_idx].node;
- CHECK_EQ("Const", beta_node.op());
- const NodeDef& gamma_node = match.inputs[gamma_idx].node;
- CHECK_EQ("Const", gamma_node.op());
-
- // We have a set of vectors that we want to combine into a vector of
- // scale values to apply column-wise to the weight input to the conv,
- // and an offset vector that we'll apply to the output of the conv.
- Tensor weights = GetNodeTensorAttr(weights_node, "value");
- Tensor mean = GetNodeTensorAttr(mean_node, "value");
- Tensor variance = GetNodeTensorAttr(variance_node, "value");
- Tensor beta = GetNodeTensorAttr(beta_node, "value");
- Tensor gamma = GetNodeTensorAttr(gamma_node, "value");
- const float variance_epsilon =
- batch_norm_node.attr().at(epsilon_attr).f();
-
- // Make sure all the inputs really are vectors, with as many entries
- // as there are columns in the weights.
- const int64 weights_cols = weights.shape().dim_size(3);
- TF_RETURN_IF_ERROR(ErrorIfNotVector(mean, "Mean", weights_cols));
- TF_RETURN_IF_ERROR(
- ErrorIfNotVector(variance, "Variance", weights_cols));
- TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", weights_cols));
- TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "gamma", weights_cols));
-
- // Calculate the scale and offset values to apply.
- std::vector<float> scale_values(weights_cols);
- std::vector<float> offset_values(weights_cols);
- if (scale_after_normalization) {
- for (int i = 0; i < weights_cols; ++i) {
- scale_values[i] =
- (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) *
- gamma.flat<float>()(i);
- }
- } else {
- for (int i = 0; i < weights_cols; ++i) {
- scale_values[i] =
- (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon));
- }
- }
- for (int i = 0; i < weights_cols; ++i) {
- offset_values[i] = (-mean.flat<float>()(i) * scale_values[i]) +
- beta.flat<float>()(i);
- }
-
- // Multiply the original weights by the scale vector.
- auto weights_matrix = weights.flat_inner_dims<float>();
- Tensor scaled_weights(DT_FLOAT, weights.shape());
- auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
- for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
- for (int64 col = 0; col < weights_cols; ++col) {
- scaled_weights_matrix(row, col) =
- weights_matrix(row, col) * scale_values[col];
- }
- }
- // Figure out the remaining bias to add on.
- Tensor bias_offset(DT_FLOAT, {weights_cols});
- auto bias_offset_vector = bias_offset.flat<float>();
- for (int64 col = 0; col < weights_cols; ++col) {
- bias_offset_vector(col) = offset_values[col];
- }
-
- // Construct the new nodes.
- NodeDef scaled_weights_node;
- scaled_weights_node.set_op("Const");
- scaled_weights_node.set_name(weights_node.name());
- SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
- SetNodeTensorAttr<float>("value", scaled_weights,
- &scaled_weights_node);
- new_nodes->push_back(scaled_weights_node);
-
- // The input and convolution can be copied straight over, since the
- // name of the scaled weights constant is the same as the original.
- new_nodes->push_back(input_node);
- new_nodes->push_back(conv_node);
-
- NodeDef bias_offset_node;
- bias_offset_node.set_op("Const");
- bias_offset_node.set_name(conv_node.name() + "_bn_offset");
- SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node);
- SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node);
- new_nodes->push_back(bias_offset_node);
-
- NodeDef bias_add_node;
- bias_add_node.set_op("BiasAdd");
- bias_add_node.set_name(batch_norm_node.name());
- CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
- AddNodeInput(conv_node.name(), &bias_add_node);
- AddNodeInput(bias_offset_node.name(), &bias_add_node);
- new_nodes->push_back(bias_add_node);
-
+ TF_RETURN_IF_ERROR(FuseBatchNormWithConv(match, new_nodes));
did_graph_change = true;
+ return Status::OK();
+ },
+ {}, &replaced_graph_def));
+ current_graph_def = replaced_graph_def;
+ } while (did_graph_change);
+ do {
+ did_graph_change = false;
+ GraphDef replaced_graph_def;
+ // Replace batch norm nodes whose input is a concat of two Conv2Ds.
+ TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
+ current_graph_def, // clang-format off
+ {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
+ {
+ {"ConcatV2|Concat", // concat two conv2d.
+ {
+ {"Conv2D", // conv_node
+ {
+ {"*"}, // input_node
+ {"Const"}, // weights_node
+ }
+ },
+ {"Conv2D", // conv_node
+ {
+ {"*"}, // input_node
+ {"Const"}, // weights_node
+ }
+ },
+ {"Const"}, // axis
+ },
+ },
+ {"Const"}, // mean_node
+ {"Const"}, // variance_node
+ {"Const"}, // beta_node
+ {"Const"}, // gamma_node
+ }
+ }, // clang-format on
+ [&did_graph_change](const NodeMatch& match,
+ const std::set<string>& input_nodes,
+ const std::set<string>& output_nodes,
+ std::vector<NodeDef>* new_nodes) {
+ TF_RETURN_IF_ERROR(FuseBatchNormWithConvConcat(match, new_nodes));
+ did_graph_change = true;
return Status::OK();
},
{}, &replaced_graph_def));
current_graph_def = replaced_graph_def;
} while (did_graph_change);
+
*output_graph_def = current_graph_def;
return Status::OK();
}
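
The test below exercises the transform by calling FoldOldBatchNorms directly on a frozen GraphDef. A minimal sketch of that programmatic use (the wrapper function and the "output" node name are illustrative assumptions; the FoldOldBatchNorms signature matches the declaration used by the test):

  #include "tensorflow/core/framework/graph.pb.h"
  #include "tensorflow/core/lib/core/status.h"
  #include "tensorflow/tools/graph_transforms/transform_utils.h"

  namespace tensorflow {
  namespace graph_transforms {

  // Defined in fold_old_batch_norms.cc and registered as the
  // "fold_old_batch_norms" transform; declared here to keep the sketch
  // self-contained.
  Status FoldOldBatchNorms(const GraphDef& input_graph_def,
                           const TransformFuncContext& context,
                           GraphDef* output_graph_def);

  // Hypothetical helper: fold frozen batch norms ahead of inference.
  Status FoldBatchNormsForInference(const GraphDef& frozen_graph_def,
                                    GraphDef* folded_graph_def) {
    TransformFuncContext context;
    context.output_names = {"output"};  // illustrative output node name
    return FoldOldBatchNorms(frozen_graph_def, context, folded_graph_def);
  }

  }  // namespace graph_transforms
  }  // namespace tensorflow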
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
index 3be9110b47..b30ba9ac8b 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
@@ -196,6 +196,106 @@ class FoldOldBatchNormsTest : public ::testing::Test {
EXPECT_NE("FusedBatchNorm", node.op());
}
}
+
+ void TestFoldFusedBatchNormsWithConcat(const bool split) {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ // If split is true, concat two inputs at dim=3; otherwise, concat at dim 2.
+ auto input_shape =
+ split ? TensorShape({1, 1, 6, 2}) : TensorShape({1, 1, 12, 1});
+ Tensor input_data(DT_FLOAT, input_shape);
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+
+ Output input0_op =
+ Const(root.WithOpName("input_op0"), Input::Initializer(input_data));
+ // If split is true, concat two inputs at dim=3; otherwise, concat at dim 2.
+ // The final output shape of concat is always {1, 2, 2, 2}.
+ auto weight_shape =
+ split ? TensorShape({1, 2, 2, 1}) : TensorShape({1, 2, 1, 2});
+ Tensor weights0_data(DT_FLOAT, weight_shape);
+ test::FillValues<float>(&weights0_data, {1.0f, 2.0f, 3.0f, 4.0f});
+ Output weights0_op = Const(root.WithOpName("weights1_op"),
+ Input::Initializer(weights0_data));
+ Output conv0_op = Conv2D(root.WithOpName("conv1_op"), input0_op,
+ weights0_op, {1, 1, 1, 1}, "VALID");
+
+ Output input1_op =
+ Const(root.WithOpName("input1_op"), Input::Initializer(input_data));
+ Tensor weights1_data(DT_FLOAT, weight_shape);
+ test::FillValues<float>(&weights1_data, {1.0f, 2.0f, 3.0f, 4.0f});
+ Output weights1_op = Const(root.WithOpName("weights1_op"),
+ Input::Initializer(weights1_data));
+ Output conv1_op = Conv2D(root.WithOpName("conv1_op"), input1_op,
+ weights1_op, {1, 1, 1, 1}, "VALID");
+
+ Tensor shape_tensor(DT_INT32, TensorShape({}));
+ // Concat at dim 3 if split; otherwise, concat at dim 2.
+ int32 concat_axis = split ? 3 : 2;
+ test::FillValues<int32>(&shape_tensor, {concat_axis});
+ Output shape_op =
+ Const(root.WithOpName("shape_op"), Input::Initializer(shape_tensor));
+ Output concat_op =
+ Concat(root.WithOpName("concat_op"), {conv0_op, conv1_op}, shape_op);
+
+ Tensor mean_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&mean_data, {10.0f, 20.0f});
+ Output mean_op =
+ Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
+
+ Tensor variance_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&variance_data, {0.25f, 0.5f});
+ Output variance_op = Const(root.WithOpName("variance_op"),
+ Input::Initializer(variance_data));
+
+ Tensor beta_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&beta_data, {0.1f, 0.6f});
+ Output beta_op =
+ Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
+
+ Tensor gamma_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
+ Output gamma_op =
+ Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ NodeDef batch_norm_node;
+ batch_norm_node.set_op("FusedBatchNorm");
+ batch_norm_node.set_name("output");
+ AddNodeInput("concat_op", &batch_norm_node);
+ AddNodeInput("gamma_op", &batch_norm_node);
+ AddNodeInput("beta_op", &batch_norm_node);
+ AddNodeInput("mean_op", &batch_norm_node);
+ AddNodeInput("variance_op", &batch_norm_node);
+ SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
+ SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
+ SetNodeAttr("is_training", false, &batch_norm_node);
+ *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
+ &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+ for (const NodeDef& node : fused_graph_def.node()) {
+ EXPECT_NE("FusedBatchNorm", node.op());
+ }
+ }
};
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) {
@@ -206,5 +306,14 @@ TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNorms) {
TestFoldFusedBatchNorms();
}
+TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) {
+ // Test concat at axis 3: the batch norm scale and offset are split in two
+ // before being fused with each Conv2D's weights.
+ TestFoldFusedBatchNormsWithConcat(/*split=*/true);
+ // Test concat at axis 2: the full scale and offset are fused with the
+ // weights of each Conv2D.
+ TestFoldFusedBatchNormsWithConcat(/*split=*/false);
+}
+
} // namespace graph_transforms
} // namespace tensorflow