aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/batch_norm_op_test.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-26 15:14:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-26 16:18:36 -0700
commit36357e7e1127873165694a38e3a989df4e0b6ffe (patch)
treeeaf00810dfaefdc7f308a088829a3798af23e7bc /tensorflow/core/kernels/batch_norm_op_test.cc
parente1b4934bb59904ee4dd243a34cc8356ff6bd266d (diff)
Added support for half floats to the batch normalization op
Change: 123368006
Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op_test.cc')
-rw-r--r--tensorflow/core/kernels/batch_norm_op_test.cc27
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op_test.cc b/tensorflow/core/kernels/batch_norm_op_test.cc
index e70bcc5b4c..9b7bf6d149 100644
--- a/tensorflow/core/kernels/batch_norm_op_test.cc
+++ b/tensorflow/core/kernels/batch_norm_op_test.cc
@@ -59,4 +59,31 @@ TEST_F(BatchNormOpTest, Simple) {
test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
}
+TEST_F(BatchNormOpTest, Fp16) {
+ TF_EXPECT_OK(
+ NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization")
+ .Input(FakeInput(DT_HALF))
+ .Input(FakeInput(DT_HALF))
+ .Input(FakeInput(DT_HALF))
+ .Input(FakeInput(DT_HALF))
+ .Input(FakeInput(DT_HALF))
+ .Attr("scale_after_normalization", false)
+ .Attr("variance_epsilon", 0.001)
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOpWithGraphVersion(8));
+ AddInputFromList<Eigen::half>(TensorShape({1, 1, 6, 2}),
+ {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6});
+ AddInputFromList<Eigen::half>(TensorShape({2}), {10, 20});
+ AddInputFromList<Eigen::half>(TensorShape({2}), {0.25, 0.5});
+ AddInputFromList<Eigen::half>(TensorShape({2}), {0.1, 0.6});
+ AddInputFromList<Eigen::half>(TensorShape({2}), {0.0, 0.0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_HALF, TensorShape({1, 1, 6, 2}));
+ test::FillValues<Eigen::half>(
+ &expected, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86,
+ -33.31, -23.85, -34.72, -25.85, -36.13});
+ test::ExpectTensorNear<Eigen::half>(expected, *GetOutput(0), 0.1);
+}
+
} // namespace tensorflow