Diffstat (limited to 'tensorflow/compiler/xla/tests/batch_normalization_test.cc')
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc  128
1 file changed, 60 insertions, 68 deletions
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index ac90a3adb6..bc2ba151a3 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -63,7 +63,7 @@ class BatchNormalizationTest
         {5.0f, 4.4f},  // p2
     });
     input_array_.FillWithPZ(pz);
-    input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
+    input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
     CHECK_EQ(kSamples, input_array_.planes());
     CHECK_EQ(kZ, input_array_.depth());
     CHECK_EQ(kY, input_array_.height());
@@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
   BatchNormTraining(operand, scale, offset,
                     /*epsilon=*/0.001, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
-                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
-           .get(),
-       LiteralUtil::CreateR1<float>({4, 5}).get(),
-       LiteralUtil::CreateR1<float>({5, 5}).get()});
+                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
+       LiteralUtil::CreateR1<float>({4, 5}),
+       LiteralUtil::CreateR1<float>({5, 5})});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
 XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
   BatchNormTraining(operand, scale, offset,
                     /*epsilon=*/0.001, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
-                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
-           .get(),
-       LiteralUtil::CreateR1<float>({4, 5}).get(),
-       LiteralUtil::CreateR1<float>({5, 5}).get()});
+                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
+       LiteralUtil::CreateR1<float>({4, 5}),
+       LiteralUtil::CreateR1<float>({5, 5})});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
 XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
   BatchNormTraining(h0, h1, h2,
                     /*epsilon=*/1, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
-           .get(),
-       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
-       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
+       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
+       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
 
-  ComputeAndCompareTuple(&builder, *expected,
+  ComputeAndCompareTuple(&builder, expected,
                          {operand.get(), scale.get(), offset.get()},
                          ErrorSpec(0.1));
 }
@@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
   BatchNormTraining(h0, h1, h2,
                     /*epsilon=*/-100, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR3FromArray3D<float>(
-           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
-           .get(),
-       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
-       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
+       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
+       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
 
-  ComputeAndCompareTuple(&builder, *expected,
+  ComputeAndCompareTuple(&builder, expected,
                          {operand.get(), scale.get(), offset.get()},
                          ErrorSpec(0.1));
 }
@@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
   BatchNormGrad(operand, scale, mean, var, grad_output,
                 /*epsilon=*/0.0, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
-                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
-           .get(),
-       LiteralUtil::CreateR1<float>({0, 0}).get(),
-       LiteralUtil::CreateR1<float>({16, 20}).get()});
+                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
+       LiteralUtil::CreateR1<float>({0, 0}),
+       LiteralUtil::CreateR1<float>({16, 20})});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
 struct BatchNormTestParam {
@@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
   auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
 
   auto input_activations =
-      Parameter(&builder, 0, input_literal->shape(), "input");
+      Parameter(&builder, 0, input_literal.shape(), "input");
   auto scale_activations =
-      Parameter(&builder, 1, scale_literal->shape(), "offset");
+      Parameter(&builder, 1, scale_literal.shape(), "offset");
   auto offset_activations =
-      Parameter(&builder, 2, offset_literal->shape(), "scale");
+      Parameter(&builder, 2, offset_literal.shape(), "scale");
 
-  auto expected = LiteralUtil::MakeTuple(
-      {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
-       LiteralUtil::CreateR1<float>(var).get()});
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {expected_normalized, LiteralUtil::CreateR1<float>(mean),
+       LiteralUtil::CreateR1<float>(var)});
 
   std::unique_ptr<GlobalData> input_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> scale_data =
-      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> offset_data =
-      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+      client_->TransferToServer(offset_literal).ConsumeValueOrDie();
 
   BatchNormTraining(input_activations, scale_activations, offset_activations,
                     epsilon, feature_index);
@@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
   // testcase.
   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
   ComputeAndCompareTuple(
-      &builder, *expected,
+      &builder, expected,
       {input_data.get(), scale_data.get(), offset_data.get()},
       ErrorSpec(0.01, 1));
 }
@@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
   auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
 
   auto input_activations =
-      Parameter(&builder, 0, input_literal->shape(), "input");
+      Parameter(&builder, 0, input_literal.shape(), "input");
   auto scale_activations =
-      Parameter(&builder, 1, scale_literal->shape(), "offset");
+      Parameter(&builder, 1, scale_literal.shape(), "offset");
   auto offset_activations =
-      Parameter(&builder, 2, offset_literal->shape(), "scale");
-  auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
+      Parameter(&builder, 2, offset_literal.shape(), "scale");
+  auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
   auto variance_activations =
-      Parameter(&builder, 4, var_literal->shape(), "variance");
+      Parameter(&builder, 4, var_literal.shape(), "variance");
 
   Array4D<float> expected = normalized;
 
   std::unique_ptr<GlobalData> input_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> scale_data =
-      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> offset_data =
-      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+      client_->TransferToServer(offset_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> mean_data =
-      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> variance_data =
-      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+      client_->TransferToServer(var_literal).ConsumeValueOrDie();
 
   BatchNormInference(input_activations, scale_activations, offset_activations,
                      mean_activations, variance_activations, epsilon,
@@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
   auto grad_output_literal =
       LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
 
-  auto input_parameter =
-      Parameter(&builder, 0, input_literal->shape(), "input");
-  auto scale_parameter =
-      Parameter(&builder, 1, scale_literal->shape(), "scale");
-  auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
-  auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
+  auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
+  auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
+  auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
+  auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
   auto grad_output_parameter =
-      Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
+      Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
 
   std::unique_ptr<GlobalData> input_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> scale_data =
-      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> mean_data =
-      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> var_data =
-      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+      client_->TransferToServer(var_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> grad_output_data =
-      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
+      client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
 
   BatchNormGrad(input_parameter, scale_parameter, mean_parameter,
                 var_parameter, grad_output_parameter, epsilon, feature_index);
 
-  auto expected =
-      LiteralUtil::MakeTuple({expected_grad_activation.get(),
-                              LiteralUtil::CreateR1<float>(grad_scale).get(),
-                              LiteralUtil::CreateR1<float>(grad_offset).get()});
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
+       LiteralUtil::CreateR1<float>(grad_offset)});
 
   // Run all HLO passes during this test. In particular, ClientLibraryTestBase
   // disables constant folding, but we want it enabled for our zero-sized tensor
   // testcase.
   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
 
-  ComputeAndCompareTuple(&builder, *expected,
+  ComputeAndCompareTuple(&builder, expected,
                          {input_data.get(), scale_data.get(), mean_data.get(),
                          var_data.get(), grad_output_data.get()},
                          ErrorSpec(0.01, 1));
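
For context, this change migrates the tests from pointer-based Literal handling to the value-semantics Literal API. Below is a minimal standalone sketch of the new tuple-construction pattern, assuming the post-refactor XLA client headers; the helper name and the example values are illustrative, not part of the actual test:

    // literal_tuple_sketch.cc -- hypothetical example, not from the diff.
    #include "tensorflow/compiler/xla/literal.h"
    #include "tensorflow/compiler/xla/literal_util.h"

    namespace xla {

    Literal MakeExpectedTuple() {
      // Old API: factories returned std::unique_ptr<Literal>, so tuples were
      // assembled from raw pointers:
      //   auto expected = LiteralUtil::MakeTuple(
      //       {LiteralUtil::CreateR1<float>({4, 5}).get(),
      //        LiteralUtil::CreateR1<float>({5, 5}).get()});
      //
      // New API: factories return Literal by value. MakeTupleFromSlices takes
      // non-owning LiteralSlice views (implicitly constructed from the
      // temporaries) and copies their data into the returned tuple, so no
      // .get() or std::move is needed.
      return LiteralUtil::MakeTupleFromSlices(
          {LiteralUtil::CreateR1<float>({4, 5}),
           LiteralUtil::CreateR1<float>({5, 5})});
    }

    }  // namespace xla

The same shift to value semantics is what turns input_literal->shape() into input_literal.shape() and client_->TransferToServer(*input_literal) into client_->TransferToServer(input_literal) throughout the diff.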