diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-28 00:31:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-28 00:36:05 -0700 |
commit | c9435befb1bd50ad550deaebfac272eb97da7780 (patch) | |
tree | 64559a21b4d2cb00d6ea26f2fbf20853adabc58b /tensorflow/tools/graph_transforms | |
parent | 35a162a8ee61b6d3fadc6c108ce97446bbb6afd8 (diff) |
Don't fold batch norm calculations if weights are used somewhere else in the graph.
PiperOrigin-RevId: 170309345
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_batch_norms.cc | 11 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_batch_norms_test.cc | 58 |
2 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc index 2ff3bb641e..975b17380f 100644 --- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc @@ -57,6 +57,17 @@ Status FoldBatchNorms(const GraphDef& input_graph_def, const NodeDef& weights_node = match.inputs[0].inputs[1].node; const NodeDef& mul_values_node = match.inputs[1].node; + // Check that nodes that we use are not used somewhere else. + for (const auto& node : {conv_node, weights_node, mul_values_node}) { + if (output_nodes.count(node.name())) { + // Return original nodes. + new_nodes->insert(new_nodes->end(), + {mul_node, conv_node, input_node, weights_node, + mul_values_node}); + return Status::OK(); + } + } + Tensor weights = GetNodeTensorAttr(weights_node, "value"); Tensor mul_values = GetNodeTensorAttr(mul_values_node, "value"); diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc index ed741f002c..a5d541feb6 100644 --- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc @@ -87,6 +87,64 @@ class FoldBatchNormsTest : public ::testing::Test { } } + void TestFoldBatchNormsConv2DShared() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues<float>( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues<float>(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + Tensor mul_values_data(DT_FLOAT, TensorShape({2})); + test::FillValues<float>(&mul_values_data, {2.0f, 3.0f}); + Output mul_values_op = Const(root.WithOpName("mul_values"), + Input::Initializer(mul_values_data)); + + Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op); + + Tensor mul_values_data_2(DT_FLOAT, TensorShape({2})); + test::FillValues<float>(&mul_values_data_2, {1.0f, 2.0f}); + Output mul_values_op_2 = Const(root.WithOpName("mul_values_2"), + Input::Initializer(mul_values_data)); + + Output mul_op_2 = + Mul(root.WithOpName("output_2"), conv_op, mul_values_op_2); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + std::unique_ptr<Session> original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector<Tensor> original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output", "output_2"}, {}, + &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FoldBatchNorms( + original_graph_def, {{}, {"output", "output_2"}}, &fused_graph_def)); + + std::unique_ptr<Session> fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector<Tensor> fused_outputs; + TF_ASSERT_OK( + fused_session->Run({}, {"output", "output_2"}, {}, &fused_outputs)); + + test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5); + test::ExpectTensorNear<float>(original_outputs[1], fused_outputs[1], 1e-5); + } + void TestFoldBatchNormsMatMul() { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) |