aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/map_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/map_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc150
1 files changed, 75 insertions, 75 deletions
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 0732e195d4..4d327a6fe9 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase {
TEST_F(MapTest, MapEachElemPlusOneR0) {
// Applies lambda (x) (+ x 1)) to an input scalar.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(42.0);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {});
ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
@@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) {
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
@@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
@@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<int32>(), {0});
ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) {
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<uint32>(), {0});
ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) {
TEST_F(MapTest, MapEachElemLongerChainR1) {
// Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOneTimesItself(), {0});
ComputeAndCompareR1<float>(
@@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
TEST_F(MapTest, MapEachElemPlusOneR2) {
// Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0, 1});
Array2D<float> expected_array(
@@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
TEST_F(MapTest, MapBinaryAdder) {
// Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
{0});
@@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param0_literal = LiteralUtil::CreateR2WithLayout(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param1_literal = LiteralUtil::CreateR2WithLayout(
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1});
@@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XLA_TEST_F(MapTest, AddR3_3x0x2) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1, 2});
@@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) {
TEST_F(MapTest, MapTernaryAdder) {
// Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param2_literal =
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
+ Literal param2_literal =
LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
std::unique_ptr<GlobalData> param2_data =
- client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param2_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
- auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2");
Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0});
ComputeAndCompareR1<float>(
@@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
Add(x, y);
auto error_add = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, error_add, {0});
StatusOr<XlaComputation> computation_status = builder.Build();
@@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) {
Pow(x, y);
auto power = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, power, {});
ComputeAndCompareR0<float>(&builder, 32.0f,
@@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
Sub(y, x); // note that this is y - x, not x - y
auto sub_opposite = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, sub_opposite, {});
ComputeAndCompareR0<float>(
@@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
Mul(x, x);
auto square = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param0}, square, {});
ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},