aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/convolution_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/convolution_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc68
1 files changed, 34 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 0f6d54d042..a8b8f74ca9 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -434,15 +434,15 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
- auto input_r1 = Literal::CreateR1<float>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
- auto filter_r1 = Literal::CreateR1<float>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
- auto expected_r1 = Literal::CreateR1<float>(
+ auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
@@ -497,15 +497,15 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
- auto input_r1 = Literal::CreateR1<T>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
- auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
- auto expected_r1 = Literal::CreateR1<T>(
+ auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
@@ -561,8 +561,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
expected_result.Fill(0);
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(param0)),
- std::move(*Literal::CreateFromArray(param1))},
+ {std::move(*LiteralUtil::CreateFromArray(param0)),
+ std::move(*LiteralUtil::CreateFromArray(param1))},
error_spec_);
}
@@ -617,18 +617,18 @@ class Convolve1D1WindowTestBase
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
static_cast<T>(1.0f));
- auto input_r1 = Literal::CreateR1<T>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
static_cast<T>(1.0f));
- auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
std::vector<T> expect_elems(batch * output_feature * num_windows,
static_cast<T>(window_size * input_feature));
- auto expected_r1 = Literal::CreateR1<T>(expect_elems);
+ auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
auto expected_r3 =
expected_r1->Reshape({batch, num_windows, output_feature})
.ConsumeValueOrDie();
@@ -737,8 +737,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
@@ -761,8 +761,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
filter_data.FillIota(10);
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))});
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))});
}
} // namespace