author     A. Unique TensorFlower <gardener@tensorflow.org>   2017-03-01 02:29:36 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-03-01 02:50:43 -0800
commit     ec86b037893fb00be8e9c366a5a6196d89a6dd72 (patch)
tree       e3edafa6703f73a28583080ef23e221ba1bd720e
parent     7e48bada5a6c5583ab6e9a103337780863af08cc (diff)
Extends fold_batch_norms transform to also fold the mul introduced by batch normalization after fully connected layers (MatMul).
Change: 148868461
-rw-r--r--  tensorflow/tools/graph_transforms/README.md                 | 13
-rw-r--r--  tensorflow/tools/graph_transforms/fold_batch_norms.cc       | 18
-rw-r--r--  tensorflow/tools/graph_transforms/fold_batch_norms_test.cc  | 60
3 files changed, 75 insertions, 16 deletions
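
The arithmetic this commit relies on is the same identity that already justified the Conv2D case: a column-wise (per-output-channel) Mul after a MatMul can be folded into the weights ahead of time, since scaling (x * W) per output column by s gives the same result as multiplying x by W with each column j pre-scaled by s[j]. Below is a minimal standalone sketch of that identity in plain C++ (not the TensorFlow API), reusing the weight and scale values from the new MatMul test:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // Values borrowed from the new TestFoldBatchNormsMatMul test case.
  const double W[2][2] = {{1.0, 2.0}, {0.3, 0.4}};  // weights, [in, out]
  const double s[2] = {2.0, 3.0};                   // per-column batch-norm scale
  const double x[2] = {1.0, 4.0};                   // a single input row

  // Path 1: compute y = x * W, then apply the Mul per output column.
  double y_mul[2] = {0.0, 0.0};
  for (int j = 0; j < 2; ++j) {
    for (int k = 0; k < 2; ++k) y_mul[j] += x[k] * W[k][j];
    y_mul[j] *= s[j];
  }

  // Path 2: fold the scale into the weights first, then y = x * W_folded.
  double W_folded[2][2];
  for (int k = 0; k < 2; ++k)
    for (int j = 0; j < 2; ++j) W_folded[k][j] = W[k][j] * s[j];
  double y_folded[2] = {0.0, 0.0};
  for (int j = 0; j < 2; ++j)
    for (int k = 0; k < 2; ++k) y_folded[j] += x[k] * W_folded[k][j];

  // The two paths agree (up to rounding), which is why the Mul node can be
  // dropped from the inference graph once its second input is a constant.
  for (int j = 0; j < 2; ++j) assert(std::fabs(y_mul[j] - y_folded[j]) < 1e-9);
  std::printf("folded output: %g %g\n", y_folded[0], y_folded[1]);
  return 0;
}
```

The same folding already worked for Conv2D because the batch-norm scale is applied per output channel, which is simply a different dimension of the weights tensor.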
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index 8cba86f63e..36a5c01a0c 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -341,12 +341,13 @@ Args: None \
 Prerequisites: [fold_constants](#fold_constants)
 
 This transform tries to optimize away the Mul that's introduced after a Conv2D
-when batch normalization has been used during training. It scans the graph for
-any channel-wise multiplies immediately after convolutions, and multiplies the
-convolution's weights with the Mul instead so this can be omitted at inference
-time. You'll need to make sure you run [fold_constants](#fold_constants) first,
-since the pattern can only be spotted if the normal complex expression that's
-produced by training for the Mul input is collapsed down into a simple constant.
+(or a MatMul) when batch normalization has been used during training. It scans
+the graph for any channel-wise multiplies immediately after convolutions, and
+multiplies the convolution's (or matrix multiplication's) weights with the Mul
+instead so this can be omitted at inference time. You'll need to make sure you
+run [fold_constants](#fold_constants) first, since the pattern can only be
+spotted if the normal complex expression that's produced by training for the Mul
+input is collapsed down into a simple constant.
 
 ### fold_constants
 
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
index 9f3393f126..7daacb55a2 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
@@ -27,23 +27,24 @@ limitations under the License.
 namespace tensorflow {
 namespace graph_transforms {
 
-// Converts Conv2D ops followed by column-wise Muls into equivalent ops with the
-// Mul baked into the convolution weights, to save computation during inference.
+// Converts Conv2D or MatMul ops followed by column-wise Muls into equivalent
+// ops with the Mul baked into the convolution weights, to save computation
+// during inference.
 Status FoldBatchNorms(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
   GraphDef replaced_graph_def;
   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
       input_graph_def,  // clang-format off
-      {"Mul",                // mul_node
+      {"Mul",                    // mul_node
        {
-          {"Conv2D",            // conv_node
+          {"Conv2D|MatMul",      // conv_node
            {
-              {"*"},            // input_node
-              {"Const"},        // weights_node
+              {"*"},             // input_node
+              {"Const"},         // weights_node
            }
          },
-          {"Const"},            // mul_values_node
+          {"Const"},             // mul_values_node
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
@@ -61,7 +62,8 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
 
         // Make sure all the inputs really are vectors, with as many entries as
         // there are columns in the weights.
-        const int64 weights_cols = weights.shape().dim_size(3);
+        const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : 1;
+        const int64 weights_cols = weights.shape().dim_size(weights_cols_index);
         if ((mul_values.shape().dims() != 1) ||
             (mul_values.shape().dim_size(0) != weights_cols)) {
           return errors::InvalidArgument(
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
index b9983fdd0b..ed741f002c 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/math_ops.h"
 #include "tensorflow/cc/ops/nn_ops.h"
 #include "tensorflow/cc/ops/sendrecv_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
@@ -35,7 +36,7 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
 
 class FoldBatchNormsTest : public ::testing::Test {
  protected:
-  void TestFoldBatchNorms() {
+  void TestFoldBatchNormsConv2D() {
     auto root = tensorflow::Scope::NewRootScope();
     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
 
@@ -85,9 +86,64 @@ class FoldBatchNormsTest : public ::testing::Test {
       EXPECT_NE("Mul", node.op());
     }
   }
+
+  void TestFoldBatchNormsMatMul() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+
+    Tensor input_data(DT_FLOAT, TensorShape({6, 2}));
+    test::FillValues<float>(
+        &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+                      -5.0f, -3.0f, -6.0f});
+    Output input_op =
+        Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+    Tensor weights_data(DT_FLOAT, TensorShape({2, 2}));
+    test::FillValues<float>(&weights_data, {1.0f, 2.0f, 0.3f, 0.4f});
+    Output weights_op =
+        Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+    Output matmul_op =
+        MatMul(root.WithOpName("matmul_op"), input_op, weights_op);
+
+    Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
+    test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
+    Output mul_values_op = Const(root.WithOpName("mul_values"),
+                                 Input::Initializer(mul_values_data));
+
+    Output mul_op = Mul(root.WithOpName("output"), matmul_op, mul_values_op);
+
+    GraphDef original_graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+    std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(original_session->Create(original_graph_def));
+    std::vector<Tensor> original_outputs;
+    TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+    GraphDef fused_graph_def;
+    TF_ASSERT_OK(
+        FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
+
+    std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+    std::vector<Tensor> fused_outputs;
+    TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+    test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+    for (const NodeDef& node : fused_graph_def.node()) {
+      EXPECT_NE("Mul", node.op());
+    }
+  }
 };
 
-TEST_F(FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms(); }
+TEST_F(FoldBatchNormsTest, TestFoldBatchNormsConv2D) {
+  TestFoldBatchNormsConv2D();
+}
+TEST_F(FoldBatchNormsTest, TestFoldBatchNormsMatMul) {
+  TestFoldBatchNormsMatMul();
+}
 
 }  // namespace graph_transforms
 }  // namespace tensorflow
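
The only non-mechanical change in fold_batch_norms.cc is choosing which weights dimension has to match the length of the Mul constant: a Conv2D filter is laid out [filter_height, filter_width, in_channels, out_channels], so the per-channel scale corresponds to dimension 3, while MatMul weights are [in_units, out_units], so it corresponds to dimension 1. A hedged sketch of that selection, using a hypothetical helper in plain C++ rather than the TensorFlow types:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the shape check in FoldBatchNorms: `op` is the
// matched node's op ("Conv2D" or "MatMul") and `weight_dims` holds the sizes
// of the weights tensor's dimensions.
int64_t OutputColumnCount(const std::string& op,
                          const std::vector<int64_t>& weight_dims) {
  // Conv2D filters: [height, width, in_channels, out_channels] -> dimension 3.
  // MatMul weights: [in_units, out_units]                      -> dimension 1.
  const int weights_cols_index = (op == "Conv2D") ? 3 : 1;
  return weight_dims[weights_cols_index];
}
```

As before, the pattern only matches once the Mul's second input has been collapsed to a single Const, so fold_constants still needs to run ahead of fold_batch_norms when the transforms are applied.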