aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-01 02:29:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-01 02:50:43 -0800
commitec86b037893fb00be8e9c366a5a6196d89a6dd72 (patch)
treee3edafa6703f73a28583080ef23e221ba1bd720e
parent7e48bada5a6c5583ab6e9a103337780863af08cc (diff)
Extends fold_batch_norms transform to also fold the mul introduced by batch normalization after fully connected layers (MatMul).
Change: 148868461
-rw-r--r--tensorflow/tools/graph_transforms/README.md13
-rw-r--r--tensorflow/tools/graph_transforms/fold_batch_norms.cc18
-rw-r--r--tensorflow/tools/graph_transforms/fold_batch_norms_test.cc60
3 files changed, 75 insertions, 16 deletions
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index 8cba86f63e..36a5c01a0c 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -341,12 +341,13 @@ Args: None \
Prerequisites: [fold_constants](#fold_constants)
This transform tries to optimize away the Mul that's introduced after a Conv2D
-when batch normalization has been used during training. It scans the graph for
-any channel-wise multiplies immediately after convolutions, and multiplies the
-convolution's weights with the Mul instead so this can be omitted at inference
-time. You'll need to make sure you run [fold_constants](#fold_constants) first,
-since the pattern can only be spotted if the normal complex expression that's
-produced by training for the Mul input is collapsed down into a simple constant.
+(or a MatMul) when batch normalization has been used during training. It scans
+the graph for any channel-wise multiplies immediately after convolutions (or
+matrix multiplications), and multiplies the weights with the Mul instead so
+this can be omitted at inference time. You'll need to make sure you run
+[fold_constants](#fold_constants) first, since the pattern can only be spotted
+if the normal complex expression that's produced by training for the Mul input
+is collapsed down into a simple constant.
### fold_constants
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
index 9f3393f126..7daacb55a2 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
@@ -27,23 +27,24 @@ limitations under the License.
namespace tensorflow {
namespace graph_transforms {
-// Converts Conv2D ops followed by column-wise Muls into equivalent ops with the
-// Mul baked into the convolution weights, to save computation during inference.
+// Converts Conv2D or MatMul ops followed by column-wise Muls into equivalent
+// ops with the Mul baked into the weights, to save computation during
+// inference.
Status FoldBatchNorms(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
- {"Mul", // mul_node
+ {"Mul", // mul_node
{
- {"Conv2D", // conv_node
+ {"Conv2D|MatMul", // conv_node
{
- {"*"}, // input_node
- {"Const"}, // weights_node
+ {"*"}, // input_node
+ {"Const"}, // weights_node
}
},
- {"Const"}, // mul_values_node
+ {"Const"}, // mul_values_node
}
}, // clang-format on
[](const NodeMatch& match, const std::set<string>& input_nodes,
@@ -61,7 +62,8 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
// Make sure all the inputs really are vectors, with as many entries as
// there are columns in the weights.
- const int64 weights_cols = weights.shape().dim_size(3);
+ const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : 1;
+ const int64 weights_cols = weights.shape().dim_size(weights_cols_index);
if ((mul_values.shape().dims() != 1) ||
(mul_values.shape().dim_size(0) != weights_cols)) {
return errors::InvalidArgument(
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
index b9983fdd0b..ed741f002c 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
@@ -35,7 +36,7 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
class FoldBatchNormsTest : public ::testing::Test {
protected:
- void TestFoldBatchNorms() {
+ void TestFoldBatchNormsConv2D() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
@@ -85,9 +86,64 @@ class FoldBatchNormsTest : public ::testing::Test {
EXPECT_NE("Mul", node.op());
}
}
+
+ void TestFoldBatchNormsMatMul() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT, TensorShape({6, 2}));
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+ Output input_op =
+ Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+ Tensor weights_data(DT_FLOAT, TensorShape({2, 2}));
+ test::FillValues<float>(&weights_data, {1.0f, 2.0f, 0.3f, 0.4f});
+ Output weights_op =
+ Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+ Output matmul_op =
+ MatMul(root.WithOpName("matmul_op"), input_op, weights_op);
+
+ Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
+ Output mul_values_op = Const(root.WithOpName("mul_values"),
+ Input::Initializer(mul_values_data));
+
+ Output mul_op = Mul(root.WithOpName("output"), matmul_op, mul_values_op);
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(
+ FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+ for (const NodeDef& node : fused_graph_def.node()) {
+ EXPECT_NE("Mul", node.op());
+ }
+ }
};
-TEST_F(FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms(); }
+TEST_F(FoldBatchNormsTest, TestFoldBatchNormsConv2D) {
+ TestFoldBatchNormsConv2D();
+}
+TEST_F(FoldBatchNormsTest, TestFoldBatchNormsMatMul) {
+ TestFoldBatchNormsMatMul();
+}
} // namespace graph_transforms
} // namespace tensorflow