aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/graph_transforms/fold_batch_norms_test.cc')
-rw-r--r--tensorflow/tools/graph_transforms/fold_batch_norms_test.cc58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
index ed741f002c..a5d541feb6 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
@@ -87,6 +87,64 @@ class FoldBatchNormsTest : public ::testing::Test {
}
}
+ void TestFoldBatchNormsConv2DShared() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+ Output input_op =
+ Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+ Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
+ test::FillValues<float>(&weights_data,
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
+ Output weights_op =
+ Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+ Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
+ {1, 1, 1, 1}, "VALID");
+
+ Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
+ Output mul_values_op = Const(root.WithOpName("mul_values"),
+ Input::Initializer(mul_values_data));
+
+ Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
+
+ Tensor mul_values_data_2(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&mul_values_data_2, {1.0f, 2.0f});
+ Output mul_values_op_2 = Const(root.WithOpName("mul_values_2"),
+ Input::Initializer(mul_values_data));
+
+ Output mul_op_2 =
+ Mul(root.WithOpName("output_2"), conv_op, mul_values_op_2);
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output", "output_2"}, {},
+ &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(FoldBatchNorms(
+ original_graph_def, {{}, {"output", "output_2"}}, &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(
+ fused_session->Run({}, {"output", "output_2"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+ test::ExpectTensorNear<float>(original_outputs[1], fused_outputs[1], 1e-5);
+ }
+
void TestFoldBatchNormsMatMul() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)