aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/select_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/select_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/select_test.cc276
1 files changed, 276 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc
new file mode 100644
index 0000000000..5ec9ac95fa
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/select_test.cc
@@ -0,0 +1,276 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class SelectTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+TEST_F(SelectTest, SelectScalarF32True) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR0<float>(123.0f);
+ auto on_false = builder.ConstantR0<float>(42.0f);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectScalarS32True) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR0<int32>(-42);
+ auto on_false = builder.ConstantR0<int32>(42);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<int32>(&builder, -42, {});
+}
+
+TEST_F(SelectTest, SelectScalarF32False) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(false);
+ auto on_true = builder.ConstantR0<float>(123.0f);
+ auto on_false = builder.ConstantR0<float>(42.0f);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR1<bool>({});
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR1<bool>({false, true, false, true, false});
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
+ // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector
+ // is not a constant, but rather the result of comparing two other vectors.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<int32>({});
+ auto v2 = builder.ConstantR1<int32>({});
+ auto cmp = builder.Eq(v1, v2);
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
+ // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is
+ // not a constant, but rather the result of comparing two other vectors.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<int32>({1, 2, 3, 4, 5});
+ auto v2 = builder.ConstantR1<int32>({9, 2, 9, 4, 9});
+ auto cmp = builder.Eq(v1, v2);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
+ // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ auto v2 = builder.ConstantR1<float>({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
+ auto cmp = builder.Gt(v1, v2);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
+ // Selects among two R1F32s, which come from parameters. v1 and v2 are
+ // compared, and selection between them happens based on a gt-comparison mask.
+ ComputationBuilder builder(client_, TestName());
+
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
+ {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
+ {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto cmp = builder.Gt(v1, v2);
+ auto select = builder.Select(cmp, v1, v2);
+ ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
+ // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
+ // data size passed in and out is large.
+ ComputationBuilder builder(client_, TestName());
+
+ // Number of floats in the data passed into and out of the computation.
+ constexpr int datalen = 15 * 1000;
+
+ // The inputs are initialized with a special pattern where in the first third
+ // of the data v1[i] > v2[i] and elsewhere it's vice versa.
+ std::vector<float> v1vec;
+ std::vector<float> v2vec;
+ std::vector<float> expected_vec;
+ for (int i = 0; i < datalen; ++i) {
+ float smaller = i;
+ float larger = i * 2;
+ if (i < datalen / 3) {
+ v1vec.push_back(larger);
+ v2vec.push_back(smaller);
+ } else {
+ v1vec.push_back(smaller);
+ v2vec.push_back(larger);
+ }
+ expected_vec.push_back(larger);
+ }
+
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data =
+ CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data =
+ CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto cmp = builder.Gt(v1, v2);
+ auto select = builder.Select(cmp, v1, v2);
+ ComputeAndCompareR1<float>(&builder, expected_vec,
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
+ // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to
+ // select between two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, -1, 2, -2});
+ auto s = builder.ConstantR0<int32>(0);
+ auto cmp = builder.Gt(v, s);
+
+ auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
+ auto on_false =
+ builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
+ // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to
+ // select between two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
+ auto s = builder.ConstantR0<float>(2.5f);
+ auto cmp = builder.Gt(v, s);
+
+ auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
+ auto on_false =
+ builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
+ for (bool which : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(which);
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+ }
+}
+
+TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(false);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
+}
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}