/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { class SelectTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001}; }; TEST_F(SelectTest, SelectScalarF32True) { XlaBuilder builder(TestName()); auto pred = ConstantR0(&builder, true); auto on_true = ConstantR0(&builder, 123.0f); auto on_false = ConstantR0(&builder, 42.0f); Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); } TEST_F(SelectTest, SelectScalarS32True) { XlaBuilder builder(TestName()); auto pred = ConstantR0(&builder, true); auto on_true = ConstantR0(&builder, -42); auto on_false = ConstantR0(&builder, 42); Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, -42, {}); } TEST_F(SelectTest, SelectScalarF32False) { XlaBuilder builder(TestName()); auto pred = ConstantR0(&builder, false); auto on_true = ConstantR0(&builder, 123.0f); auto on_false = ConstantR0(&builder, 42.0f); Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); } XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { XlaBuilder builder(TestName()); auto pred = ConstantR1(&builder, {}); auto on_true = ConstantR1(&builder, {}); auto on_false = ConstantR1(&builder, {}); Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { XlaBuilder builder(TestName()); auto pred = ConstantR1(&builder, {false, true, false, true, false}); auto on_true = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto on_false = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); } XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector // is not a constant, but rather the result of comparing two other vectors. XlaBuilder builder(TestName()); auto v1 = ConstantR1(&builder, {}); auto v2 = ConstantR1(&builder, {}); auto cmp = Eq(v1, v2); auto on_true = ConstantR1(&builder, {}); auto on_false = ConstantR1(&builder, {}); Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is // not a constant, but rather the result of comparing two other vectors. XlaBuilder builder(TestName()); auto v1 = ConstantR1(&builder, {1, 2, 3, 4, 5}); auto v2 = ConstantR1(&builder, {9, 2, 9, 4, 9}); auto cmp = Eq(v1, v2); auto on_true = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto on_false = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s. XlaBuilder builder(TestName()); auto v1 = ConstantR1(&builder, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); auto v2 = ConstantR1(&builder, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); auto cmp = Gt(v1, v2); auto on_true = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto on_false = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { // Selects among two R1F32s, which come from parameters. v1 and v2 are // compared, and selection between them happens based on a gt-comparison mask. XlaBuilder builder(TestName()); XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter( {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); std::unique_ptr param1_data = CreateR1Parameter( {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); auto cmp = Gt(v1, v2); Select(cmp, v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the // data size passed in and out is large. XlaBuilder builder(TestName()); // Number of floats in the data passed into and out of the computation. constexpr int datalen = 15 * 1000; // The inputs are initialized with a special pattern where in the first third // of the data v1[i] > v2[i] and elsewhere it's vice versa. std::vector v1vec; std::vector v2vec; std::vector expected_vec; for (int i = 0; i < datalen; ++i) { float smaller = i; float larger = i * 2; if (i < datalen / 3) { v1vec.push_back(larger); v2vec.push_back(smaller); } else { v1vec.push_back(smaller); v2vec.push_back(larger); } expected_vec.push_back(larger); } XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); std::unique_ptr param1_data = CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); auto cmp = Gt(v1, v2); Select(cmp, v1, v2); ComputeAndCompareR1(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to // select between two R1F32s. XlaBuilder builder(TestName()); auto v = ConstantR1(&builder, {1, -1, 2, -2}); auto s = ConstantR0(&builder, 0); auto cmp = Gt(v, s); auto on_true = ConstantR1(&builder, {11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = ConstantR1(&builder, {-111.0f, -222.0f, -333.0f, -444.0f}); Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to // select between two R1F32s. XlaBuilder builder(TestName()); auto v = ConstantR1(&builder, {1.0f, 2.0f, 3.0f, 4.0f}); auto s = ConstantR0(&builder, 2.5f); auto cmp = Gt(v, s); auto on_true = ConstantR1(&builder, {11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = ConstantR1(&builder, {-111.0f, -222.0f, -333.0f, -444.0f}); Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {}, error_spec_); } XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { for (bool which : {false, true}) { XlaBuilder builder(TestName()); auto pred = ConstantR0(&builder, which); auto on_true = ConstantR1(&builder, {}); auto on_false = ConstantR1(&builder, {}); Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } } TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { XlaBuilder builder(TestName()); auto pred = ConstantR0(&builder, true); auto on_true = ConstantR1(&builder, {-2.5f, 25.5f}); auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { XlaBuilder builder(TestName()); auto pred = ConstantR0(&builder, false); auto on_true = ConstantR1(&builder, {-2.5f, 25.5f}); auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 5.0f}, {}, error_spec_); } } // namespace } // namespace xla