aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/batch_normalization_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/batch_normalization_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc109
1 files changed, 56 insertions, 53 deletions
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index d9d7ba1362..033382708a 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -20,10 +20,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -62,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_));
+ input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@@ -118,7 +119,7 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
XlaBuilder builder("square_tesseract_elementwise");
auto x = ConstantLiteral(&builder, input_literal_);
- SquareF32(x);
+ Square(x);
using tensorflow::MathUtil;
@@ -150,7 +151,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
auto activation_deviations = Sub(input_activations, set_means,
/*broadcast_dimensions=*/{1});
XlaComputation add = CreateScalarAddComputation(F32, &builder);
- auto dev_squares = SquareF32(activation_deviations);
+ auto dev_squares = Square(activation_deviations);
Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});
std::vector<float> expected = {18, 0.06};
@@ -160,7 +161,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
XlaBuilder builder("variance_to_stddev");
auto variance = ConstantR1<float>(&builder, {6.f, .02f});
- SqrtF32(variance);
+ Sqrt(variance);
std::vector<float> expected = {2.44948974f, 0.14142136f};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
@@ -195,20 +196,20 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon});
auto activation_deviations = Sub(input_activations, set_means,
/*broadcast_dimensions=*/{1});
- auto dev_squares = SquareF32(activation_deviations);
+ auto dev_squares = Square(activation_deviations);
auto sum_of_squares =
CheckShape(&builder,
Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add,
/*dimensions_to_reduce=*/{0, 2, 3}),
TwoElementVectorF32);
auto variance = Div(sum_of_squares, count);
- auto standard_deviation = SqrtF32(variance);
+ auto standard_deviation = Sqrt(variance);
auto standard_deviation_above_epsilon =
CheckShape(&builder, Gt(standard_deviation, epsilon),
ShapeUtil::MakeShape(PRED, {2}));
auto gt_eps =
Select(standard_deviation_above_epsilon, standard_deviation, epsilon2);
- auto normalization_factors = ReciprocalF32(gt_eps);
+ auto normalization_factors = Reciprocal(gt_eps);
auto normalized_input_activations =
Mul(activation_deviations, normalization_factors,
/*broadcast_dimensions=*/{1});
@@ -241,12 +242,12 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
- {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
+ {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
.get(),
- Literal::CreateR1<float>({4, 5}).get(),
- Literal::CreateR1<float>({5, 5}).get()});
+ LiteralUtil::CreateR1<float>({4, 5}).get(),
+ LiteralUtil::CreateR1<float>({5, 5}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
@@ -266,12 +267,12 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
- {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
+ {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
.get(),
- Literal::CreateR1<float>({4, 5}).get(),
- Literal::CreateR1<float>({5, 5}).get()});
+ LiteralUtil::CreateR1<float>({4, 5}).get(),
+ LiteralUtil::CreateR1<float>({5, 5}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
@@ -297,11 +298,11 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
.get(),
- Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
- Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
ComputeAndCompareTuple(&builder, *expected,
{operand.get(), scale.get(), offset.get()},
@@ -330,11 +331,12 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR3FromArray3D<float>(
+ {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
.get(),
- Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
- Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
ComputeAndCompareTuple(&builder, *expected,
{operand.get(), scale.get(), offset.get()},
@@ -361,12 +363,12 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
- {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
+ {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
.get(),
- Literal::CreateR1<float>({0, 0}).get(),
- Literal::CreateR1<float>({16, 20}).get()});
+ LiteralUtil::CreateR1<float>({0, 0}).get(),
+ LiteralUtil::CreateR1<float>({16, 20}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
@@ -512,11 +514,12 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);
- auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
+ auto expected_normalized =
+ LiteralUtil::CreateR4FromArray4D<float>(normalized);
- auto offset_literal = Literal::CreateR1<float>(offset);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+ auto offset_literal = LiteralUtil::CreateR1<float>(offset);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
Parameter(&builder, 0, input_literal->shape(), "input");
@@ -525,9 +528,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto offset_activations =
Parameter(&builder, 2, offset_literal->shape(), "scale");
- auto expected = Literal::MakeTuple({expected_normalized.get(),
- Literal::CreateR1<float>(mean).get(),
- Literal::CreateR1<float>(var).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
+ LiteralUtil::CreateR1<float>(var).get()});
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -612,11 +615,11 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);
- auto offset_literal = Literal::CreateR1<float>(offset);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto mean_literal = Literal::CreateR1<float>(mean);
- auto var_literal = Literal::CreateR1<float>(var);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+ auto offset_literal = LiteralUtil::CreateR1<float>(offset);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto mean_literal = LiteralUtil::CreateR1<float>(mean);
+ auto var_literal = LiteralUtil::CreateR1<float>(var);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
Parameter(&builder, 0, input_literal->shape(), "input");
@@ -799,14 +802,14 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
});
auto expected_grad_activation =
- Literal::CreateR4FromArray4D<float>(grad_activation);
+ LiteralUtil::CreateR4FromArray4D<float>(grad_activation);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto mean_literal = Literal::CreateR1<float>(mean);
- auto var_literal = Literal::CreateR1<float>(var);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto mean_literal = LiteralUtil::CreateR1<float>(mean);
+ auto var_literal = LiteralUtil::CreateR1<float>(var);
auto grad_output_literal =
- Literal::CreateR4FromArray4D<float>(grad_output_array);
+ LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
auto input_parameter =
Parameter(&builder, 0, input_literal->shape(), "input");
@@ -832,9 +835,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
grad_output_parameter, epsilon, feature_index);
auto expected =
- Literal::MakeTuple({expected_grad_activation.get(),
- Literal::CreateR1<float>(grad_scale).get(),
- Literal::CreateR1<float>(grad_offset).get()});
+ LiteralUtil::MakeTuple({expected_grad_activation.get(),
+ LiteralUtil::CreateR1<float>(grad_scale).get(),
+ LiteralUtil::CreateR1<float>(grad_offset).get()});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor