aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc')
-rw-r--r--tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc97
1 files changed, 96 insertions, 1 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
index b30ba9ac8b..7651a03fe5 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -298,6 +299,96 @@ class FoldOldBatchNormsTest : public ::testing::Test {
}
};
+void TestFoldFusedBatchNormsWithBatchToSpace() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT, TensorShape({2, 1, 3, 2}));
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+ Output input_op =
+ Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+ Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
+ test::FillValues<float>(&weights_data,
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
+ Output weights_op =
+ Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+ Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
+ {1, 1, 1, 1}, "VALID");
+
+ Tensor block_shape_data(DT_INT32, TensorShape({2}));
+ test::FillValues<int32>(&block_shape_data, {1, 2});
+ Output block_shape_op =
+ Const(root.WithOpName("block_shape_op"), Input::Initializer(block_shape_data));
+
+ Tensor crops_data(DT_INT32, TensorShape({2, 2}));
+ test::FillValues<int32>(&crops_data, {0, 0, 0, 1});
+ Output crops_op =
+ Const(root.WithOpName("crops_op"), Input::Initializer(crops_data));
+
+ Output batch_to_space_op = BatchToSpaceND(root.WithOpName("batch_to_space_op"),
+ conv_op, block_shape_op, crops_data);
+
+ Tensor mean_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&mean_data, {10.0f, 20.0f});
+ Output mean_op =
+ Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
+
+ Tensor variance_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&variance_data, {0.25f, 0.5f});
+ Output variance_op = Const(root.WithOpName("variance_op"),
+ Input::Initializer(variance_data));
+
+ Tensor beta_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&beta_data, {0.1f, 0.6f});
+ Output beta_op =
+ Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
+
+ Tensor gamma_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
+ Output gamma_op =
+ Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ NodeDef batch_norm_node;
+ batch_norm_node.set_op("FusedBatchNorm");
+ batch_norm_node.set_name("output");
+ AddNodeInput("batch_to_space_op", &batch_norm_node);
+ AddNodeInput("gamma_op", &batch_norm_node);
+ AddNodeInput("beta_op", &batch_norm_node);
+ AddNodeInput("mean_op", &batch_norm_node);
+ AddNodeInput("variance_op", &batch_norm_node);
+ SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
+ SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
+ SetNodeAttr("is_training", false, &batch_norm_node);
+ *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
+ &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+ for (const NodeDef& node : fused_graph_def.node()) {
+ EXPECT_NE("FusedBatchNormWithBatchToSpace", node.op());
+ }
+}
+
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) {
TestFoldOldBatchNorms();
}
@@ -307,7 +398,7 @@ TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNorms) {
}
TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) {
- // Test axis is not 3, so all weigths and offsets are fused to each of inputs
+ // Test axis is not 3, so all weights and offsets are fused to each of inputs
// of conv2d.
TestFoldFusedBatchNormsWithConcat(/*split=*/true);
// Test axis = 3, BatchNorm weights and offsets will be split before fused
@@ -315,5 +406,9 @@ TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) {
TestFoldFusedBatchNormsWithConcat(/*split=*/false);
}
+TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithBatchToSpace) {
+ TestFoldFusedBatchNormsWithBatchToSpace();
+}
+
} // namespace graph_transforms
} // namespace tensorflow