diff options
Diffstat (limited to 'tensorflow/tools/graph_transforms/fold_batch_norms_test.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_batch_norms_test.cc | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc index ed741f002c..a5d541feb6 100644 --- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc @@ -87,6 +87,64 @@ class FoldBatchNormsTest : public ::testing::Test { } } + void TestFoldBatchNormsConv2DShared() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues<float>( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues<float>(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + Tensor mul_values_data(DT_FLOAT, TensorShape({2})); + test::FillValues<float>(&mul_values_data, {2.0f, 3.0f}); + Output mul_values_op = Const(root.WithOpName("mul_values"), + Input::Initializer(mul_values_data)); + + Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op); + + Tensor mul_values_data_2(DT_FLOAT, TensorShape({2})); + test::FillValues<float>(&mul_values_data_2, {1.0f, 2.0f}); + Output mul_values_op_2 = Const(root.WithOpName("mul_values_2"), + Input::Initializer(mul_values_data)); + + Output mul_op_2 = + Mul(root.WithOpName("output_2"), conv_op, mul_values_op_2); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + std::unique_ptr<Session> original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector<Tensor> original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output", "output_2"}, {}, + &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FoldBatchNorms( + original_graph_def, {{}, {"output", "output_2"}}, &fused_graph_def)); + + std::unique_ptr<Session> fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector<Tensor> fused_outputs; + TF_ASSERT_OK( + fused_session->Run({}, {"output", "output_2"}, {}, &fused_outputs)); + + test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5); + test::ExpectTensorNear<float>(original_outputs[1], fused_outputs[1], 1e-5); + } + void TestFoldBatchNormsMatMul() { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) |