diff options
author | Mingxing Tan <tanmingxing@google.com> | 2017-09-25 20:20:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-25 20:24:19 -0700 |
commit | c4ee3bc1929ac672f76ed44bd25eeb7a5400fca5 (patch) | |
tree | 00dd2c39c6b4df6147a413229aad90b29510dacc /tensorflow/tools/graph_transforms | |
parent | 89ffbeaca0dcc69186b90d22d3282fc28db143c3 (diff) |
Enable folding batch norm when inputs are concat of Conv2D.
PiperOrigin-RevId: 170001077
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_old_batch_norms.cc | 371 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc | 109 |
2 files changed, 360 insertions, 120 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc index 0978c336b4..d89afe85c7 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc @@ -32,12 +32,216 @@ Status ErrorIfNotVector(const Tensor& input, const string& input_name, int expected_width) { if ((input.shape().dims() != 1) || (input.shape().dim_size(0) != expected_width)) { - return errors::InvalidArgument(input_name, - " input to batch norm has bad shape: ", - input.shape().DebugString()); + return errors::InvalidArgument( + input_name, + " input to batch norm has bad shape: ", input.shape().DebugString()); } return Status::OK(); } + +Status GetScaleAndOffsetValues(const NodeMatch& match, + std::vector<float>* scale_values, + std::vector<float>* offset_values) { + // Find all the nodes we expect in the subgraph. + const NodeDef& batch_norm_node = match.node; + // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ + // by input order and attribute names. + CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" || + batch_norm_node.op() == "FusedBatchNorm"); + const bool is_fused = batch_norm_node.op() == "FusedBatchNorm"; + const int mean_idx = is_fused ? 3 : 1; + const int var_idx = is_fused ? 4 : 2; + const int beta_idx = is_fused ? 2 : 3; + const int gamma_idx = is_fused ? 1 : 4; + const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon"; + // FusedBatchNorm always scales after normalization. + const bool scale_after_normalization = + is_fused || batch_norm_node.attr().at("scale_after_normalization").b(); + + const NodeDef& mean_node = match.inputs[mean_idx].node; + CHECK_EQ("Const", mean_node.op()); + const NodeDef& variance_node = match.inputs[var_idx].node; + CHECK_EQ("Const", variance_node.op()); + const NodeDef& beta_node = match.inputs[beta_idx].node; + CHECK_EQ("Const", beta_node.op()); + const NodeDef& gamma_node = match.inputs[gamma_idx].node; + CHECK_EQ("Const", gamma_node.op()); + + // We have a set of vectors that we want to combine into a vector of + // scale values and offset values. + Tensor mean = GetNodeTensorAttr(mean_node, "value"); + Tensor variance = GetNodeTensorAttr(variance_node, "value"); + Tensor beta = GetNodeTensorAttr(beta_node, "value"); + Tensor gamma = GetNodeTensorAttr(gamma_node, "value"); + const float variance_epsilon = batch_norm_node.attr().at(epsilon_attr).f(); + + // Make sure all the inputs really are vectors with the same shape. + const int64 num_cols = mean.shape().dim_size(0); + TF_RETURN_IF_ERROR(ErrorIfNotVector(variance, "Variance", num_cols)); + TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", num_cols)); + TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "gamma", num_cols)); + + scale_values->resize(num_cols); + offset_values->resize(num_cols); + + // Calculate the scale and offset values to apply. + if (scale_after_normalization) { + for (int i = 0; i < num_cols; ++i) { + (*scale_values)[i] = + (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) * + gamma.flat<float>()(i); + } + } else { + for (int i = 0; i < num_cols; ++i) { + (*scale_values)[i] = + (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)); + } + } + for (int i = 0; i < num_cols; ++i) { + (*offset_values)[i] = + (-mean.flat<float>()(i) * (*scale_values)[i]) + beta.flat<float>()(i); + } + return Status::OK(); +} + +Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values, + const std::vector<float>& offset_values, + const NodeMatch& conv_node_match, + const string& conv_output_name, + std::vector<NodeDef>* new_nodes) { + const NodeDef& conv_node = conv_node_match.node; + CHECK_EQ("Conv2D", conv_node.op()); + const NodeDef& input_node = conv_node_match.inputs[0].node; + const NodeDef& weights_node = conv_node_match.inputs[1].node; + CHECK_EQ("Const", weights_node.op()); + + Tensor weights = GetNodeTensorAttr(weights_node, "value"); + const int64 weights_cols = weights.shape().dim_size(3); + CHECK_EQ(weights_cols, scale_values.size()); + + // Multiply the original weights by the scale vector. + auto weights_matrix = weights.flat_inner_dims<float>(); + Tensor scaled_weights(DT_FLOAT, weights.shape()); + auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>(); + for (int64 row = 0; row < weights_matrix.dimension(0); ++row) { + for (int64 col = 0; col < weights_cols; ++col) { + scaled_weights_matrix(row, col) = + weights_matrix(row, col) * scale_values[col]; + } + } + // Figure out the remaining bias to add on. + Tensor bias_offset(DT_FLOAT, {weights_cols}); + auto bias_offset_vector = bias_offset.flat<float>(); + for (int64 col = 0; col < weights_cols; ++col) { + bias_offset_vector(col) = offset_values[col]; + } + + // Construct the new nodes. + NodeDef scaled_weights_node; + scaled_weights_node.set_op("Const"); + scaled_weights_node.set_name(weights_node.name()); + SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node); + SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node); + new_nodes->push_back(scaled_weights_node); + + // The input and convolution can be copied straight over, since the + // name of the scaled weights constant is the same as the original. + new_nodes->push_back(input_node); + new_nodes->push_back(conv_node); + + NodeDef bias_offset_node; + bias_offset_node.set_op("Const"); + bias_offset_node.set_name(conv_node.name() + "_bn_offset"); + SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node); + SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node); + new_nodes->push_back(bias_offset_node); + + NodeDef bias_add_node; + bias_add_node.set_op("BiasAdd"); + bias_add_node.set_name(conv_output_name); + CopyNodeAttr(conv_node, "T", "T", &bias_add_node); + AddNodeInput(conv_node.name(), &bias_add_node); + AddNodeInput(bias_offset_node.name(), &bias_add_node); + new_nodes->push_back(bias_add_node); + return Status::OK(); +} + +Status FuseBatchNormWithConv(const NodeMatch& match, + std::vector<NodeDef>* new_nodes) { + // Calculate the scale and offset values to apply. + std::vector<float> scale_values; + std::vector<float> offset_values; + TF_RETURN_IF_ERROR( + GetScaleAndOffsetValues(match, &scale_values, &offset_values)); + + // Fuse conv weights, and set the final output node name as batch_norm_node. + const NodeDef& batch_norm_node = match.node; + TF_RETURN_IF_ERROR( + FuseScaleOffsetToConvWeights(scale_values, offset_values, match.inputs[0], + batch_norm_node.name(), new_nodes)); + return Status::OK(); +} + +Status FuseBatchNormWithConvConcat(const NodeMatch& match, + std::vector<NodeDef>* new_nodes) { + // Calculate the scale and offset values to apply. + std::vector<float> scale_values; + std::vector<float> offset_values; + TF_RETURN_IF_ERROR( + GetScaleAndOffsetValues(match, &scale_values, &offset_values)); + + // Find all the nodes we expect in the subgraph. + const NodeDef& batch_norm_node = match.node; + const NodeMatch& concat_node_match = match.inputs[0]; + NodeDef concat_node = concat_node_match.node; + CHECK_EQ("ConcatV2", concat_node.op()); + + // First process the axis. + NodeDef axis_node = concat_node_match.inputs[2].node; + CHECK_EQ("Const", axis_node.op()); + Tensor axis = GetNodeTensorAttr(axis_node, "value"); + int32 axis_scalar = (axis.scalar<int32>())(); + + // Set both conv0 and conv1 have the same scale and offset in default. + std::vector<float> scale0(scale_values); + std::vector<float> offset0(offset_values); + std::vector<float> scale1(scale_values); + std::vector<float> offset1(offset_values); + if (axis_scalar == 3) { + // If axis is 3, then scale and offset will be split into two halfs. + const NodeDef& weights0_node = concat_node_match.inputs[0].inputs[1].node; + Tensor weights0 = GetNodeTensorAttr(weights0_node, "value"); + const int64 split_cols = weights0.shape().dim_size(3); + // Only keep the first half for scale0/offset0. + scale0.erase(scale0.begin() + split_cols, scale0.end()); + offset0.erase(offset0.begin() + split_cols, offset0.end()); + // Only keep the second half for scale1/offset1. + scale1.erase(scale1.begin(), scale1.begin() + split_cols); + offset1.erase(offset1.begin(), offset1.begin() + split_cols); + } + + // Fuse the weights for input0 of conv2d. + const string concat0_output_name = concat_node.name() + "_bn_in0"; + TF_RETURN_IF_ERROR( + FuseScaleOffsetToConvWeights(scale0, offset0, concat_node_match.inputs[0], + concat0_output_name, new_nodes)); + + // Fuse the weights for input1 of conv2d. + const string concat1_output_name = concat_node.name() + "_bn_in1"; + TF_RETURN_IF_ERROR( + FuseScaleOffsetToConvWeights(scale1, offset1, concat_node_match.inputs[1], + concat1_output_name, new_nodes)); + + // Push the shape node. + new_nodes->push_back(concat_node_match.inputs[2].node); + + // Set the final output op name to batch_normal_node. + concat_node.set_name(batch_norm_node.name()); + concat_node.set_input(0, concat0_output_name); + concat_node.set_input(1, concat1_output_name); + new_nodes->push_back(concat_node); + return Status::OK(); +} } // namespace // Finds monolithic batch norm ops (as used in early versions of TensorFlow) and @@ -72,130 +276,57 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, const std::set<string>& input_nodes, const std::set<string>& output_nodes, std::vector<NodeDef>* new_nodes) { - // Find all the nodes we expect in the subgraph. - const NodeDef& batch_norm_node = match.node; - // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ - // by input order and attribute names. - CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" || - batch_norm_node.op() == "FusedBatchNorm"); - const bool is_fused = batch_norm_node.op() == "FusedBatchNorm"; - const int mean_idx = is_fused ? 3 : 1; - const int var_idx = is_fused ? 4 : 2; - const int beta_idx = is_fused ? 2 : 3; - const int gamma_idx = is_fused ? 1 : 4; - const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon"; - // FusedBatchNorm always scales after normalization. - const bool scale_after_normalization = - is_fused || - batch_norm_node.attr().at("scale_after_normalization").b(); - - const NodeDef& conv_node = match.inputs[0].node; - CHECK_EQ("Conv2D", conv_node.op()); - const NodeDef& input_node = match.inputs[0].inputs[0].node; - const NodeDef& weights_node = match.inputs[0].inputs[1].node; - CHECK_EQ("Const", weights_node.op()); - const NodeDef& mean_node = match.inputs[mean_idx].node; - CHECK_EQ("Const", mean_node.op()); - const NodeDef& variance_node = match.inputs[var_idx].node; - CHECK_EQ("Const", variance_node.op()); - const NodeDef& beta_node = match.inputs[beta_idx].node; - CHECK_EQ("Const", beta_node.op()); - const NodeDef& gamma_node = match.inputs[gamma_idx].node; - CHECK_EQ("Const", gamma_node.op()); - - // We have a set of vectors that we want to combine into a vector of - // scale values to apply column-wise to the weight input to the conv, - // and an offset vector that we'll apply to the output of the conv. - Tensor weights = GetNodeTensorAttr(weights_node, "value"); - Tensor mean = GetNodeTensorAttr(mean_node, "value"); - Tensor variance = GetNodeTensorAttr(variance_node, "value"); - Tensor beta = GetNodeTensorAttr(beta_node, "value"); - Tensor gamma = GetNodeTensorAttr(gamma_node, "value"); - const float variance_epsilon = - batch_norm_node.attr().at(epsilon_attr).f(); - - // Make sure all the inputs really are vectors, with as many entries - // as there are columns in the weights. - const int64 weights_cols = weights.shape().dim_size(3); - TF_RETURN_IF_ERROR(ErrorIfNotVector(mean, "Mean", weights_cols)); - TF_RETURN_IF_ERROR( - ErrorIfNotVector(variance, "Variance", weights_cols)); - TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", weights_cols)); - TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "gamma", weights_cols)); - - // Calculate the scale and offset values to apply. - std::vector<float> scale_values(weights_cols); - std::vector<float> offset_values(weights_cols); - if (scale_after_normalization) { - for (int i = 0; i < weights_cols; ++i) { - scale_values[i] = - (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) * - gamma.flat<float>()(i); - } - } else { - for (int i = 0; i < weights_cols; ++i) { - scale_values[i] = - (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)); - } - } - for (int i = 0; i < weights_cols; ++i) { - offset_values[i] = (-mean.flat<float>()(i) * scale_values[i]) + - beta.flat<float>()(i); - } - - // Multiply the original weights by the scale vector. - auto weights_matrix = weights.flat_inner_dims<float>(); - Tensor scaled_weights(DT_FLOAT, weights.shape()); - auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>(); - for (int64 row = 0; row < weights_matrix.dimension(0); ++row) { - for (int64 col = 0; col < weights_cols; ++col) { - scaled_weights_matrix(row, col) = - weights_matrix(row, col) * scale_values[col]; - } - } - // Figure out the remaining bias to add on. - Tensor bias_offset(DT_FLOAT, {weights_cols}); - auto bias_offset_vector = bias_offset.flat<float>(); - for (int64 col = 0; col < weights_cols; ++col) { - bias_offset_vector(col) = offset_values[col]; - } - - // Construct the new nodes. - NodeDef scaled_weights_node; - scaled_weights_node.set_op("Const"); - scaled_weights_node.set_name(weights_node.name()); - SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node); - SetNodeTensorAttr<float>("value", scaled_weights, - &scaled_weights_node); - new_nodes->push_back(scaled_weights_node); - - // The input and convolution can be copied straight over, since the - // name of the scaled weights constant is the same as the original. - new_nodes->push_back(input_node); - new_nodes->push_back(conv_node); - - NodeDef bias_offset_node; - bias_offset_node.set_op("Const"); - bias_offset_node.set_name(conv_node.name() + "_bn_offset"); - SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node); - SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node); - new_nodes->push_back(bias_offset_node); - - NodeDef bias_add_node; - bias_add_node.set_op("BiasAdd"); - bias_add_node.set_name(batch_norm_node.name()); - CopyNodeAttr(conv_node, "T", "T", &bias_add_node); - AddNodeInput(conv_node.name(), &bias_add_node); - AddNodeInput(bias_offset_node.name(), &bias_add_node); - new_nodes->push_back(bias_add_node); - + TF_RETURN_IF_ERROR(FuseBatchNormWithConv(match, new_nodes)); did_graph_change = true; + return Status::OK(); + }, + {}, &replaced_graph_def)); + current_graph_def = replaced_graph_def; + } while (did_graph_change); + do { + did_graph_change = false; + GraphDef replaced_graph_def; + // Replace BatchNorm with concat as input. + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + current_graph_def, // clang-format off + {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node + { + {"ConcatV2|Concat", // concat two conv2d. + { + {"Conv2D", // conv_node + { + {"*"}, // input_node + {"Const"}, // weights_node + } + }, + {"Conv2D", // conv_node + { + {"*"}, // input_node + {"Const"}, // weights_node + } + }, + {"Const"}, // axis + }, + }, + {"Const"}, // mean_node + {"Const"}, // variance_node + {"Const"}, // beta_node + {"Const"}, // gamma_node + } + }, // clang-format on + [&did_graph_change](const NodeMatch& match, + const std::set<string>& input_nodes, + const std::set<string>& output_nodes, + std::vector<NodeDef>* new_nodes) { + TF_RETURN_IF_ERROR(FuseBatchNormWithConvConcat(match, new_nodes)); + did_graph_change = true; return Status::OK(); }, {}, &replaced_graph_def)); current_graph_def = replaced_graph_def; } while (did_graph_change); + *output_graph_def = current_graph_def; return Status::OK(); } diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc index 3be9110b47..b30ba9ac8b 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -196,6 +196,106 @@ class FoldOldBatchNormsTest : public ::testing::Test { EXPECT_NE("FusedBatchNorm", node.op()); } } + + void TestFoldFusedBatchNormsWithConcat(const bool split) { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + // If split is true, concat two inputs at dim=3; otherwise, concat at dim 2. + auto input_shape = + split ? TensorShape({1, 1, 6, 2}) : TensorShape({1, 1, 12, 1}); + Tensor input_data(DT_FLOAT, input_shape); + test::FillValues<float>( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + + Output input0_op = + Const(root.WithOpName("input_op0"), Input::Initializer(input_data)); + // If split is true, concat two inputs at dim=3; otherwise, concat at dim 2. + // The final output shape of concat is always {1, 2, 2, 2}. + auto weight_shape = + split ? TensorShape({1, 2, 2, 1}) : TensorShape({1, 2, 1, 2}); + Tensor weights0_data(DT_FLOAT, weight_shape); + test::FillValues<float>(&weights0_data, {1.0f, 2.0f, 3.0f, 4.0f}); + Output weights0_op = Const(root.WithOpName("weights1_op"), + Input::Initializer(weights0_data)); + Output conv0_op = Conv2D(root.WithOpName("conv1_op"), input0_op, + weights0_op, {1, 1, 1, 1}, "VALID"); + + Output input1_op = + Const(root.WithOpName("input1_op"), Input::Initializer(input_data)); + Tensor weights1_data(DT_FLOAT, weight_shape); + test::FillValues<float>(&weights1_data, {1.0f, 2.0f, 3.0f, 4.0f}); + Output weights1_op = Const(root.WithOpName("weights1_op"), + Input::Initializer(weights1_data)); + Output conv1_op = Conv2D(root.WithOpName("conv1_op"), input1_op, + weights1_op, {1, 1, 1, 1}, "VALID"); + + Tensor shape_tensor(DT_INT32, TensorShape({})); + // Concat at dim 3 if split; otherwise, concat at dim 2. + int32 concat_axis = split ? 3 : 2; + test::FillValues<int32>(&shape_tensor, {concat_axis}); + Output shape_op = + Const(root.WithOpName("shape_op"), Input::Initializer(shape_tensor)); + Output concat_op = + Concat(root.WithOpName("concat_op"), {conv0_op, conv1_op}, shape_op); + + Tensor mean_data(DT_FLOAT, TensorShape({2})); + test::FillValues<float>(&mean_data, {10.0f, 20.0f}); + Output mean_op = + Const(root.WithOpName("mean_op"), Input::Initializer(mean_data)); + + Tensor variance_data(DT_FLOAT, TensorShape({2})); + test::FillValues<float>(&variance_data, {0.25f, 0.5f}); + Output variance_op = Const(root.WithOpName("variance_op"), + Input::Initializer(variance_data)); + + Tensor beta_data(DT_FLOAT, TensorShape({2})); + test::FillValues<float>(&beta_data, {0.1f, 0.6f}); + Output beta_op = + Const(root.WithOpName("beta_op"), Input::Initializer(beta_data)); + + Tensor gamma_data(DT_FLOAT, TensorShape({2})); + test::FillValues<float>(&gamma_data, {1.0f, 2.0f}); + Output gamma_op = + Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data)); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + NodeDef batch_norm_node; + batch_norm_node.set_op("FusedBatchNorm"); + batch_norm_node.set_name("output"); + AddNodeInput("concat_op", &batch_norm_node); + AddNodeInput("gamma_op", &batch_norm_node); + AddNodeInput("beta_op", &batch_norm_node); + AddNodeInput("mean_op", &batch_norm_node); + AddNodeInput("variance_op", &batch_norm_node); + SetNodeAttr("T", DT_FLOAT, &batch_norm_node); + SetNodeAttr("epsilon", 0.00001f, &batch_norm_node); + SetNodeAttr("is_training", false, &batch_norm_node); + *(original_graph_def.mutable_node()->Add()) = batch_norm_node; + + std::unique_ptr<Session> original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector<Tensor> original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}}, + &fused_graph_def)); + + std::unique_ptr<Session> fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector<Tensor> fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("FusedBatchNorm", node.op()); + } + } }; TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) { @@ -206,5 +306,14 @@ TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNorms) { TestFoldFusedBatchNorms(); } +TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) { + // Test axis is not 3, so all weigths and offsets are fused to each of inputs + // of conv2d. + TestFoldFusedBatchNormsWithConcat(/*split=*/true); + // Test axis = 3, BatchNorm weights and offsets will be split before fused + // with conv2d weights. + TestFoldFusedBatchNormsWithConcat(/*split=*/false); +} + } // namespace graph_transforms } // namespace tensorflow |