aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/gather_operation_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/gather_operation_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc199
1 files changed, 118 insertions, 81 deletions
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index b8404826b1..2008d69237 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -22,9 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-// NB! TODO(b/74360564): These tests do not test out of bounds behavior since
-// that hasn't been specced yet.
-
namespace xla {
namespace {
@@ -63,8 +61,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -84,8 +83,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -105,9 +105,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -127,9 +127,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -149,9 +149,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -171,11 +171,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -195,11 +195,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -219,8 +219,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -240,9 +241,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -261,18 +262,15 @@ ENTRY main {
window_bounds={1, 0}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
// Out of bounds indices must not crash, and the indices in range should
// produce the same values across all backends.
- //
- // TODO(b/74360564): Once we have a well defined semantics for OOB accesses,
- // we should get rid of the mask and check that backends produce the same
- // value for OOB indices too.
const string hlo_text = R"(
HloModule BatchDynamicSlice
@@ -286,29 +284,45 @@ ENTRY main {
gather_dims_to_operand_dims={0,1},
index_vector_dim=1,
window_bounds={1,1}
- gather_reshaped = s32[6]{0} reshape(gather)
- in_bounds_mask = s32[6]{0} parameter(2)
- ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask)
+ ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR2<int32>(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- std::unique_ptr<Literal> in_bounds_mask =
- Literal::CreateR1<int32>({0, 1, 1, 0, 0, 1});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
+ // Out of bounds indices must not crash, and the indices in range should
+ // produce the same values across all backends.
- RunTest(hlo_text,
- {operand.get(), gather_indices.get(), in_bounds_mask.get()});
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = u32[6,2]{1,0} parameter(1)
+ gather = s32[6,1,1]{2,1,0} gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1}
+ ROOT result = s32[6]{0} reshape(gather)
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
// Negative indices must not crash, and the indices in range should produce
// the same values across all backends.
- //
- // TODO(b/74360564): Once we have a well defined semantics for negative
- // accesses, we should get rid of the mask and check that backends produce the
- // same value for negative indices too.
const string hlo_text = R"(
HloModule BatchDynamicSlice
@@ -322,20 +336,40 @@ ENTRY main {
gather_dims_to_operand_dims={0,1},
index_vector_dim=1,
window_bounds={1,1}
- gather_reshaped = s32[6]{0} reshape(gather)
- in_bounds_mask = s32[6]{0} parameter(2)
- ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask)
+ ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR2<int32>(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- std::unique_ptr<Literal> in_bounds_mask =
- Literal::CreateR1<int32>({0, 1, 1, 0, 0, 1});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
+ // Negative indices must not crash, and the indices in range should produce
+ // the same values across all backends.
- RunTest(hlo_text,
- {operand.get(), gather_indices.get(), in_bounds_mask.get()});
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+ENTRY main {
+ operand = u32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ gather = u32[6,1,1]{2,1,0} gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1}
+ ROOT result = u32[6]{0} reshape(gather)
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -353,9 +387,9 @@ ENTRY main {
window_bounds={1,3,2}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR3<int32>(
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -374,8 +408,8 @@ ENTRY main {
window_bounds={1}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -395,8 +429,8 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -419,8 +453,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -443,9 +478,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -468,9 +503,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -493,11 +528,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -521,11 +556,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -548,8 +583,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -572,9 +608,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -609,12 +645,13 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
Gather(operand, indices, dim_numbers, {1, 3});
std::vector<int32> expected = {};
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> operand_arg,
- client_->TransferToServer(*Literal::CreateR2<int32>(
- {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> operand_arg,
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> indices_arg,
- client_->TransferToServer(*Literal::CreateR1<int32>({0, 2})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2})));
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();