diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/batch_normalization_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/batch_normalization_test.cc | 210 |
1 file changed, 210 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc new file mode 100644 index 0000000000..598fd69909 --- /dev/null +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cmath> +#include <memory> +#include <vector> + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class BatchNormalizationTest : public ClientLibraryTestBase { + protected: + 
BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) { + Array2D<float> pz({ + // z0 z1 + {-1.0f, 4.1f}, // p0 + {2.0f, 4.1f}, // p1 + {5.0f, 4.4f}, // p2 + }); + input_array_.FillWithPZ(pz); + input_literal_ = *LiteralUtil::CreateR4FromArray4D(input_array_); + CHECK_EQ(kSamples, input_array_.planes()); + CHECK_EQ(kZ, input_array_.depth()); + CHECK_EQ(kY, input_array_.height()); + CHECK_EQ(kY, input_array_.width()); + } + + static constexpr int64 kSamples = 3; + static constexpr int64 kX = 1; + static constexpr int64 kY = 1; + static constexpr int64 kZ = 2; + + Array4D<float> input_array_; + Literal input_literal_; + const ErrorSpec error_spec_{0.001, 0.001}; +}; + +TEST_F(BatchNormalizationTest, SubtractInZ) { + ComputationBuilder builder(client_, "subtract_in_z_one_sample"); + auto x = builder.ConstantLiteral(input_literal_); + auto y = builder.ConstantR1<float>({3.14, 4.25}); + builder.Sub(x, y, /*broadcast_dimensions=*/{1}); + + Array4D<float> expected(kSamples, kZ, kY, kX); + Array2D<float> pz({ + {-1.0f - 3.14f, 4.1f - 4.25f}, // p0 + {2.0f - 3.14f, 4.1f - 4.25f}, // p1 + {5.0f - 3.14f, 4.4f - 4.25f}, // p2 + }); + expected.FillWithPZ(pz); + ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); +} + +TEST_F(BatchNormalizationTest, SquareTesseractElementwise) { + ComputationBuilder builder(client_, "square_tesseract_elementwise"); + auto x = builder.ConstantLiteral(input_literal_); + builder.SquareF32(x); + + Array4D<float> expected(kSamples, kZ, kY, kX); + Array2D<float> expected_pz({ + {std::pow(-1.0f, 2.0f), std::pow(4.1f, 2.0f)}, + {std::pow(2.0f, 2.0f), std::pow(4.1f, 2.0f)}, + {std::pow(5.0f, 2.0f), std::pow(4.4f, 2.0f)}, + }); + expected.FillWithPZ(expected_pz); + ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); +} + +TEST_F(BatchNormalizationTest, SumToZ) { + ComputationBuilder builder(client_, "sum_to_z"); + auto input_activations = builder.ConstantLiteral(input_literal_); + Computation add = 
CreateScalarAddComputation(F32, &builder); + // Reduce all but the Z dimension. + builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, + {0, 2, 3}); + + std::vector<float> expected = {6, 12.6}; + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} + +TEST_F(BatchNormalizationTest, SquareAndReduce) { + ComputationBuilder builder(client_, "square_and_reduce"); + auto input_activations = builder.ConstantLiteral(input_literal_); + auto set_means = builder.ConstantR1<float>({2.f, 4.2f}); + auto activation_deviations = builder.Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); + Computation add = CreateScalarAddComputation(F32, &builder); + auto dev_squares = builder.SquareF32(activation_deviations); + auto sum_of_squares = builder.Reduce( + dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3}); + + std::vector<float> expected = {18, 0.06}; + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} + +TEST_F(BatchNormalizationTest, VarianceToStddev) { + ComputationBuilder builder(client_, "variance_to_stddev"); + auto variance = builder.ConstantR1<float>({6.f, .02f}); + auto sqrt = builder.SqrtF32(variance); + + std::vector<float> expected = {2.44948974f, 0.14142136f}; + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} + +// Compare against a forward batch normalization example in the NN spec +// reference. +TEST_F(BatchNormalizationTest, SpecComparisonForward) { + ComputationBuilder builder(client_, "batch_normalize_per_spec"); + auto input_activations = + builder.CheckShape(builder.ConstantLiteral(input_literal_), + ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); + auto gamma = builder.ConstantR1<float>({1.0, 1.0}); + auto beta = builder.ConstantR1<float>({0.0, 0.0}); + Computation add = CreateScalarAddComputation(F32, &builder); + // Reduce all dimensions except dimension 1. 
+ Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); + auto sum = builder.CheckShape( + builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), + TwoElementVectorF32); + auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); + auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); + auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) / + ShapeUtil::ElementsIn(*sum_shape)); + auto set_means = builder.Div(sum, count); + + const float kEpsilon = 1e-9f; + auto epsilon = builder.ConstantR0<float>(kEpsilon); + auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon}); + auto activation_deviations = builder.Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); + auto dev_squares = builder.SquareF32(activation_deviations); + auto sum_of_squares = builder.CheckShape( + builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), + TwoElementVectorF32); + auto variance = builder.Div(sum_of_squares, count); + auto standard_deviation = builder.SqrtF32(variance); + auto standard_deviation_above_epsilon = builder.CheckShape( + builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); + auto gt_eps = builder.Select(standard_deviation_above_epsilon, + standard_deviation, epsilon2); + auto normalization_factors = builder.ReciprocalF32(gt_eps); + auto normalized_input_activations = + builder.Mul(activation_deviations, normalization_factors, + /*broadcast_dimensions=*/{1}); + /* auto output_activations = */ builder.Add( + builder.Mul(normalized_input_activations, gamma, + /*broadcast_dimensions=*/{1}), + beta, /*broadcast_dimensions=*/{1}); + + Array4D<float> expected(kSamples, kZ, kY, kX); + Array2D<float> pz({ + {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)}, + {0.f, -.1f / std::sqrt(.02f)}, + {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)}, + }); + expected.FillWithPZ(pz); + + 
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector<tensorflow::Flag> flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} |