aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/batch_normalization_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/batch_normalization_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc128
1 files changed, 60 insertions, 68 deletions
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index ac90a3adb6..bc2ba151a3 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -63,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
+ input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
- {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
- {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(
- {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+ {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
- {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({0, 0}).get(),
- LiteralUtil::CreateR1<float>({16, 20}).get()});
+ {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
+ LiteralUtil::CreateR1<float>({0, 0}),
+ LiteralUtil::CreateR1<float>({16, 20})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
struct BatchNormTestParam {
@@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
+ Parameter(&builder, 2, offset_literal.shape(), "scale");
- auto expected = LiteralUtil::MakeTuple(
- {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
- LiteralUtil::CreateR1<float>(var).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_normalized, LiteralUtil::CreateR1<float>(mean),
+ LiteralUtil::CreateR1<float>(var)});
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
BatchNormTraining(input_activations, scale_activations, offset_activations,
epsilon, feature_index);
@@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(
- &builder, *expected,
+ &builder, expected,
{input_data.get(), scale_data.get(), offset_data.get()},
ErrorSpec(0.01, 1));
}
@@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
- auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
+ Parameter(&builder, 2, offset_literal.shape(), "scale");
+ auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
auto variance_activations =
- Parameter(&builder, 4, var_literal->shape(), "variance");
+ Parameter(&builder, 4, var_literal.shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
BatchNormInference(input_activations, scale_activations, offset_activations,
mean_activations, variance_activations, epsilon,
@@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
auto grad_output_literal =
LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
- auto input_parameter =
- Parameter(&builder, 0, input_literal->shape(), "input");
- auto scale_parameter =
- Parameter(&builder, 1, scale_literal->shape(), "scale");
- auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
- auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
+ auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
+ auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
+ auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
+ auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
auto grad_output_parameter =
- Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
+ Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> var_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> grad_output_data =
- client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
+ client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
grad_output_parameter, epsilon, feature_index);
- auto expected =
- LiteralUtil::MakeTuple({expected_grad_activation.get(),
- LiteralUtil::CreateR1<float>(grad_scale).get(),
- LiteralUtil::CreateR1<float>(grad_offset).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
+ LiteralUtil::CreateR1<float>(grad_offset)});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{input_data.get(), scale_data.get(), mean_data.get(),
var_data.get(), grad_output_data.get()},
ErrorSpec(0.01, 1));