diff options
author | 2016-05-26 15:14:00 -0800 | |
---|---|---|
committer | 2016-05-26 16:18:36 -0700 | |
commit | 36357e7e1127873165694a38e3a989df4e0b6ffe (patch) | |
tree | eaf00810dfaefdc7f308a088829a3798af23e7bc /tensorflow/core/kernels/batch_norm_op_test.cc | |
parent | e1b4934bb59904ee4dd243a34cc8356ff6bd266d (diff) |
Added support for half floats to the batch normalization op
Change: 123368006
Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/batch_norm_op_test.cc | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op_test.cc b/tensorflow/core/kernels/batch_norm_op_test.cc index e70bcc5b4c..9b7bf6d149 100644 --- a/tensorflow/core/kernels/batch_norm_op_test.cc +++ b/tensorflow/core/kernels/batch_norm_op_test.cc @@ -59,4 +59,31 @@ TEST_F(BatchNormOpTest, Simple) { test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01); } +TEST_F(BatchNormOpTest, Fp16) { + TF_EXPECT_OK( + NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") + .Input(FakeInput(DT_HALF)) + .Input(FakeInput(DT_HALF)) + .Input(FakeInput(DT_HALF)) + .Input(FakeInput(DT_HALF)) + .Input(FakeInput(DT_HALF)) + .Attr("scale_after_normalization", false) + .Attr("variance_epsilon", 0.001) + .Finalize(node_def())); + TF_EXPECT_OK(InitOpWithGraphVersion(8)); + AddInputFromList<Eigen::half>(TensorShape({1, 1, 6, 2}), + {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); + AddInputFromList<Eigen::half>(TensorShape({2}), {10, 20}); + AddInputFromList<Eigen::half>(TensorShape({2}), {0.25, 0.5}); + AddInputFromList<Eigen::half>(TensorShape({2}), {0.1, 0.6}); + AddInputFromList<Eigen::half>(TensorShape({2}), {0.0, 0.0}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_HALF, TensorShape({1, 1, 6, 2})); + test::FillValues<Eigen::half>( + &expected, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86, + -33.31, -23.85, -34.72, -25.85, -36.13}); + test::ExpectTensorNear<Eigen::half>(expected, *GetOutput(0), 0.1); +} + } // namespace tensorflow |