aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/dynamic_ops_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc67
1 files changed, 44 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index f3c258a4d4..7f6f203a1b 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
@@ -124,11 +124,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// vector<bool> is special so that it cannot be an ArraySlice<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
- std::move(*Literal::CreateR1(input_values_int)
+ std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR1(expected_values_int)
+ std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -150,11 +150,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR2FromArray2D(input_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR2FromArray2D(expected_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -176,11 +176,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR3FromArray3D(input_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR3FromArray3D(expected_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -202,18 +202,28 @@ XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64, float>(); }
+XLA_TEST_F(DynamicSliceTest, UInt32R1OOB) {
+ RunR1<uint32, int32>({0, 1, 2, 3, 4}, {2147483648u}, {2}, {3, 4});
+}
XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2<int32, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
+XLA_TEST_F(DynamicSliceTest, UInt32R2OOB) {
+ RunR2<uint32, int32>({{0, 1}, {2, 3}}, {2147483648u, 0}, {1, 1}, {{2}});
+}
XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3<int32, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32, float>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB<int32, float>(); }
XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64, float>(); }
+XLA_TEST_F(DynamicSliceTest, UInt32R3OOB) {
+ RunR3<uint32, int32>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}},
+ {2147483648u, 0, 2147483648u}, {1, 1, 1}, {{{5}}});
+}
XLA_TEST_F(DynamicSliceTest, Int32R1Pred) {
// Slice at dimension start.
@@ -349,15 +359,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
- std::move(*Literal::CreateR0(input_value_int)
+ std::move(*LiteralUtil::CreateR0(input_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_value =
- std::move(*Literal::CreateR0(update_value_int)
+ std::move(*LiteralUtil::CreateR0(update_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_value =
- std::move(*Literal::CreateR0(expected_value_int)
+ std::move(*LiteralUtil::CreateR0(expected_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -380,15 +390,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
tensorflow::gtl::ArraySlice<int> expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR1(input_values_int)
+ std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR1(update_values_int)
+ std::move(*LiteralUtil::CreateR1(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR1(expected_values_int)
+ std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -411,15 +421,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR2FromArray2D(input_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR2FromArray2D(update_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR2FromArray2D(expected_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -442,15 +452,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR3FromArray3D(input_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR3FromArray3D(update_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR3FromArray3D(expected_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -520,7 +530,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
std::unique_ptr<Literal> literal =
- Literal::CreateR3FromArray3D<NativeT>(values);
+ LiteralUtil::CreateR3FromArray3D<NativeT>(values);
LOG(INFO) << name << ":" << literal->ToString();
}
};
@@ -530,21 +540,32 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0<int64, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0<uint64, float>(); }
-// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64, float>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) {
+ RunR1<uint32, int32>({0, 1, 2, 3, 4}, {5, 6}, {2147483648u}, {0, 1, 2, 5, 6});
+}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) {
+ RunR2<uint32, int32>({{0, 1}, {2, 3}}, {{4}}, {2147483648u, 0},
+ {{0, 1}, {4, 3}});
+}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64, uint64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) {
+ RunR3<uint32, int32>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, {{{8}}},
+ {2147483648u, 0, 2147483648u},
+ {{{0, 1}, {2, 3}}, {{4, 8}, {6, 7}}});
+}
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB<int32, float>(); }
@@ -695,7 +716,7 @@ void BM_DynamicSlice(int num_iters) {
XlaBuilder builder("DynamicSlice");
// Create input as a constant: shape [1, 2, 3, 4]
- auto input_literal = Literal::CreateR4(
+ auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
auto input = ConstantLiteral(&builder, *input_literal);
@@ -715,7 +736,7 @@ void BM_DynamicSlice(int num_iters) {
start_indices_shape, &allocator, /*device_ordinal=*/0)
.ConsumeValueOrDie();
- auto start_indices_literal = Literal::CreateR1<int32>({0, 1, 2, 3});
+ auto start_indices_literal = LiteralUtil::CreateR1<int32>({0, 1, 2, 3});
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(