aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib/util_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/util_test.cc')
-rw-r--r--tensorflow/compiler/tf2xla/lib/util_test.cc24
1 files changed, 8 insertions, 16 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index 2a332c933f..442fe92c34 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -70,8 +70,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) {
auto a_data = CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &a);
auto x_data = CreateR0Parameter<int>(2, 1, "x", &builder, &x);
auto y_data = CreateR0Parameter<int>(1, 2, "y", &builder, &y);
- auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1});
- TF_ASSERT_OK(result.status());
+ DynamicSliceInMinorDims(a, {x, y}, {1, 1});
ComputeAndCompareR2<float>(&builder, {{10}},
{a_data.get(), x_data.get(), y_data.get()},
@@ -86,10 +85,8 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
- TF_ASSERT_OK(
- DynamicSliceInMinorDims(
- &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, 4})
- .status());
+ DynamicSliceInMinorDims(a, {index, xla::ConstantR0<int32>(&builder, 0)},
+ {1, 4});
ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
{a_data.get(), index_data.get()});
@@ -104,8 +101,7 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
auto x_data = CreateR0Parameter<int>(2, 2, "x", &builder, &x);
auto y_data = CreateR0Parameter<int>(1, 3, "y", &builder, &y);
- auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y});
- TF_ASSERT_OK(result.status());
+ DynamicUpdateSliceInMinorDims(a, b, {x, y});
xla::Array2D<float> expected(
{{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}});
@@ -128,13 +124,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) {
// Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
- TF_ASSERT_OK_AND_ASSIGN(
- auto l_index,
- DynamicSliceInMinorDims(
- &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n}));
- TF_ASSERT_OK(BatchDot(&builder, l_index, row,
- /*transpose_x=*/false, /*transpose_y=*/true)
- .status());
+ auto l_index = DynamicSliceInMinorDims(
+ a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n});
+ BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true);
ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
{a_data.get(), row_data.get(), index_data.get()});