aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/literal_test.cc')
-rw-r--r--tensorflow/compiler/xla/literal_test.cc913
1 files changed, 448 insertions, 465 deletions
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 1a64594db8..7ad287c897 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -92,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test {
Layout layout_r3_dim0minor_;
Layout layout_r4_dim0major_;
Layout layout_r4_dim0minor_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_;
+ Literal literal_r4_2x2x3x3_dim0major_;
+ Literal literal_r4_2x2x3x3_dim0minor_;
};
TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto true_lit = LiteralUtil::CreateR0<bool>(true);
- EXPECT_EQ("true", true_lit->ToString());
+ EXPECT_EQ("true", true_lit.ToString());
auto false_lit = LiteralUtil::CreateR0<bool>(false);
- EXPECT_EQ("false", false_lit->ToString());
+ EXPECT_EQ("false", false_lit.ToString());
auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
- EXPECT_EQ("42", u32_lit->ToString());
+ EXPECT_EQ("42", u32_lit.ToString());
auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
- EXPECT_EQ("-999", s32_lit->ToString());
+ EXPECT_EQ("-999", s32_lit.ToString());
auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
- EXPECT_EQ("3.14", f32_lit->ToString());
+ EXPECT_EQ("3.14", f32_lit.ToString());
auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
- EXPECT_EQ("0.5", f16_lit->ToString());
+ EXPECT_EQ("0.5", f16_lit.ToString());
auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
- EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString());
+ EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString());
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
- EXPECT_EQ("0.5", bf16_lit->ToString());
+ EXPECT_EQ("0.5", bf16_lit.ToString());
// 3.14 will be rounded to 3.14062 in bfloat16 format.
auto bf16_lit_truncated =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
- ASSERT_EQ("3.14062", bf16_lit_truncated->ToString());
+ ASSERT_EQ("3.14062", bf16_lit_truncated.ToString());
auto bf16_lit_truncated2 =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
- EXPECT_EQ("9", bf16_lit_truncated2->ToString());
+ EXPECT_EQ("9", bf16_lit_truncated2.ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
- EXPECT_EQ("{101}", pred_vec->ToString());
+ EXPECT_EQ("{101}", pred_vec.ToString());
}
TEST_F(LiteralUtilTest, R2ToString) {
@@ -143,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) {
{ 3, 4 },
{ 5, 6 }
})";
- EXPECT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, R3ToString) {
@@ -157,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) {
{ { 5 },
{ 6 } }
})";
- EXPECT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, TupleToString) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
const string expected = R"((f32[], f32[2,2]) (
1,
f32[2,2] {
@@ -171,7 +171,7 @@ f32[2,2] {
{ 3, 4 }
}
))";
- EXPECT_EQ(expected, tuple->ToString());
+ EXPECT_EQ(expected, tuple.ToString());
}
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
@@ -187,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
// clang-format on
auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[2,3,2] {
{ { 1, 2 },
{ 3, 4 },
@@ -220,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) {
};
std::vector<int64> expected_values = {8, 9, 7, 10};
- EXPECT_EQ(literal->sparse_indices()->data(),
+ EXPECT_EQ(literal.sparse_indices()->data(),
absl::Span<const int64>(expected_indices.data(),
expected_indices.num_elements()));
- EXPECT_EQ(literal->data<int64>(), absl::Span<const int64>(expected_values));
+ EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -234,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
{2001, 2002},
}, /*projection_p=*/1, /*projection_z=*/2);
// clang-format on
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
@@ -254,9 +254,9 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
}
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
- EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(),
+ EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(),
ElementsAre(2, 2, 3, 3));
- string result = literal_r4_2x2x3x3_dim0major_->ToString();
+ string result = literal_r4_2x2x3x3_dim0major_.ToString();
const string expected = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
@@ -294,7 +294,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
});
// clang-format on
std::vector<std::tuple<int64, int64, string>> seen;
- literal->EachCellAsString(
+ literal.EachCellAsString(
[&seen](absl::Span<const int64> indices, const string& value) {
seen.emplace_back(indices[0], indices[1], value);
});
@@ -310,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) {
auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
- EXPECT_EQ(*f32_42, *f32_42);
- EXPECT_EQ(*f32_42, *f32_42_clone);
+ EXPECT_EQ(f32_42, f32_42);
+ EXPECT_EQ(f32_42, f32_42_clone);
auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*f32_42, *f32_123);
+ EXPECT_NE(f32_42, f32_123);
auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
- EXPECT_NE(*f32_42, *f64_42);
+ EXPECT_NE(f32_42, f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
@@ -330,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(*matrix, *matrix);
- EXPECT_EQ(*matrix, *matrix_clone);
- EXPECT_NE(*matrix, *matrix_different);
- EXPECT_NE(*matrix, *vector_literal);
- EXPECT_NE(*matrix, *scalar);
- EXPECT_NE(*matrix, nil);
+ EXPECT_EQ(matrix, matrix);
+ EXPECT_EQ(matrix, matrix_clone);
+ EXPECT_NE(matrix, matrix_different);
+ EXPECT_NE(matrix, vector_literal);
+ EXPECT_NE(matrix, scalar);
+ EXPECT_NE(matrix, nil);
EXPECT_EQ(nil, nil);
}
@@ -344,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) {
auto token1 = LiteralUtil::CreateToken();
auto scalar = LiteralUtil::CreateR0<float>(1.0);
- EXPECT_EQ(*token0, *token1);
- EXPECT_NE(*token0, *scalar);
+ EXPECT_EQ(token0, token1);
+ EXPECT_NE(token0, scalar);
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}),
- *LiteralUtil::MakeTuple({token0.get()}));
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({token1.get(), scalar.get()}));
- EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({scalar.get(), token1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0}),
+ LiteralUtil::MakeTuple({&token0}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&token1, &scalar}));
+ EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&scalar, &token1}));
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
- auto colmajor = absl::make_unique<Literal>(
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
- colmajor->Set<float>({0, 0}, 1.0);
- colmajor->Set<float>({0, 1}, 2.0);
- colmajor->Set<float>({1, 0}, 3.0);
- colmajor->Set<float>({1, 1}, 4.0);
+ Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
+ colmajor.Set<float>({0, 0}, 1.0);
+ colmajor.Set<float>({0, 1}, 2.0);
+ colmajor.Set<float>({1, 0}, 3.0);
+ colmajor.Set<float>({1, 1}, 4.0);
- auto rowmajor = absl::make_unique<Literal>(
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
- rowmajor->Set<float>({0, 0}, 1.0);
- rowmajor->Set<float>({0, 1}, 2.0);
- rowmajor->Set<float>({1, 0}, 3.0);
- rowmajor->Set<float>({1, 1}, 4.0);
+ Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
+ rowmajor.Set<float>({0, 0}, 1.0);
+ rowmajor.Set<float>({0, 1}, 2.0);
+ rowmajor.Set<float>({1, 0}, 3.0);
+ rowmajor.Set<float>({1, 1}, 4.0);
- EXPECT_EQ(*rowmajor, *colmajor);
+ EXPECT_EQ(rowmajor, colmajor);
}
TEST_F(LiteralUtilTest, TupleEquality) {
// Test equality with tuples.
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix});
// Tuple with the same elements. One element is shared with the original
// tuple, the other is a clone of the element in the original tuple.
auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
- auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
- EXPECT_EQ(*tuple1, *tuple2);
+ auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix});
+ EXPECT_EQ(tuple1, tuple2);
// Tuple with elements reversed.
- auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
- EXPECT_NE(*tuple1, *reversed_tuple);
+ auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar});
+ EXPECT_NE(tuple1, reversed_tuple);
// Tuple with different value.
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
- auto different_tuple =
- LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
- EXPECT_NE(*tuple1, *different_tuple);
+ auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix});
+ EXPECT_NE(tuple1, different_tuple);
}
TEST_F(LiteralUtilTest, C64Equality) {
@@ -405,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) {
// tuple, the other is a clone of the element in the original tuple.
auto vector_clone =
LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
- EXPECT_EQ(*vector, *vector_clone);
+ EXPECT_EQ(vector, vector_clone);
auto vector_reversed =
LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
- EXPECT_NE(*vector, *vector_reversed);
+ EXPECT_NE(vector, vector_reversed);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
auto element1 = LiteralUtil::CreateR0<float>(0.0);
auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
- auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
+ auto tuple = LiteralUtil::MakeTuple({&element1, &element1});
// Tuples should always return false for IsAll.
- EXPECT_FALSE(tuple->IsAll(0));
- EXPECT_FALSE(tuple->IsAll(1));
+ EXPECT_FALSE(tuple.IsAll(0));
+ EXPECT_FALSE(tuple.IsAll(1));
}
// Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
auto scalar = LiteralUtil::CreateR0<float>(0.0);
auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
- auto x = Literal::CreateFromShape(tuple->shape());
- EXPECT_EQ(*tuple, *x);
+ auto x = Literal::CreateFromShape(tuple.shape());
+ EXPECT_EQ(tuple, x);
}
TEST_F(LiteralUtilTest, IsAll) {
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false)->IsAll(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(-1));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));
// We shouldn't reinterpret int8_min as an unsigned type and then decide that
// it is equal to 255.
auto int8_min = std::numeric_limits<int8>::min();
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255)->IsAll(int8_min));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0)->IsAll(42));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001)->IsAll(42));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));
- EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100})->IsAll(100));
- EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001})->IsAll(100));
+ EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
+ EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));
- EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));
half h8(8.0f);
half h9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));
// 9.001 will be truncated to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9.0));
complex64 c8_9 = {8, 9};
- EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));
auto uint64_max = std::numeric_limits<uint64>::max();
EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
- ->IsAll(-1));
+ .IsAll(-1));
}
TEST_F(LiteralUtilTest, IsAllFloat) {
// IsAllFloat always returns false when the literal is not floating-point.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllFloat(0));
-
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.49));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
- ->IsAllFloat(.5));
+ .IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.49));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
}
TEST_F(LiteralUtilTest, IsAllComplex) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
}
TEST_F(LiteralUtilTest, IsAllFirst) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
- EXPECT_FALSE(
- LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
TEST_F(LiteralUtilTest, IsZero) {
auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
- EXPECT_TRUE(scalar_zero->IsZero({}));
- EXPECT_FALSE(scalar_one->IsZero({}));
+ EXPECT_TRUE(scalar_zero.IsZero({}));
+ EXPECT_FALSE(scalar_one.IsZero({}));
auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
- EXPECT_FALSE(array->IsZero({0, 1}));
- EXPECT_TRUE(array->IsZero({0, 2}));
- EXPECT_TRUE(array->IsZero({1, 1}));
- EXPECT_FALSE(array->IsZero({1, 2}));
+ EXPECT_FALSE(array.IsZero({0, 1}));
+ EXPECT_TRUE(array.IsZero({0, 2}));
+ EXPECT_TRUE(array.IsZero({1, 1}));
+ EXPECT_FALSE(array.IsZero({1, 2}));
auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
- EXPECT_TRUE(complex_zero->IsZero({}));
- EXPECT_FALSE(complex_nonzero->IsZero({}));
+ EXPECT_TRUE(complex_zero.IsZero({}));
+ EXPECT_FALSE(complex_nonzero.IsZero({}));
}
template <typename T>
@@ -576,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
- auto data01 = data->Relayout(layout01);
- EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
- EXPECT_EQ(*data, *data01);
+ auto data01 = data.Relayout(layout01);
+ EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
+ EXPECT_EQ(data, data01);
- auto data10 = data->Relayout(layout10);
- EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
- EXPECT_EQ(*data, *data10);
+ auto data10 = data.Relayout(layout10);
+ EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
+ EXPECT_EQ(data, data10);
}
TEST_F(LiteralUtilTest, ReshapeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4) {
@@ -606,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
@@ -626,15 +622,15 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, TransposeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Transpose(/*permutation=*/{});
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Transpose(/*permutation=*/{});
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, TransposeR4) {
@@ -646,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}});
// clang-format on
- auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
+ auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1});
- reshape->EachCell<float>([&](absl::Span<const int64> indices, float value) {
- EXPECT_EQ(value, original->Get<float>(
+ reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
+ EXPECT_EQ(value, original.Get<float>(
{indices[2], indices[3], indices[0], indices[1]}));
});
}
@@ -658,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
// Tests that using Relayout on an array is equivalent to creating it in the
// target layout in the first place.
auto dim0minor_relaid_to_dim0major =
- literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
+ literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major);
auto dim0major_relaid_to_dim0minor =
- literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
+ literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
}
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
- EXPECT_EQ(mat_dim0minor->element_count(), 6);
- EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
+ EXPECT_EQ(mat_dim0minor.element_count(), 6);
+ EXPECT_THAT(mat_dim0minor.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
// Test expected memory layout when using Relayout to row major.
- auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
- EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
+ auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
+ EXPECT_THAT(relaid_mat_to_dim0major.data<int32>(),
ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout of R2 created with dim0-major (row-major).
auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
- EXPECT_EQ(mat_dim0major->element_count(), 6);
- EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
+ EXPECT_EQ(mat_dim0major.element_count(), 6);
+ EXPECT_THAT(mat_dim0major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout when using Relayout to column major.
- auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
- EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
+ auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
+ EXPECT_THAT(relaid_mat_to_dim0minor.data<int32>(),
ElementsAre(1, 4, 2, 5, 3, 6));
}
@@ -707,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0minor_);
- EXPECT_EQ(lit_dim0minor->element_count(), 12);
+ EXPECT_EQ(lit_dim0minor.element_count(), 12);
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
- EXPECT_THAT(lit_dim0minor->data<int32>(),
+ EXPECT_THAT(lit_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
// Test expected memory layout when using Relayout to row major.
- auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
+ auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_);
std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
- EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
+ EXPECT_THAT(relaid_lit_to_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout of R3 created with dim0-major (row-major).
auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0major_);
- EXPECT_EQ(lit_dim0major->element_count(), 12);
- EXPECT_THAT(lit_dim0major->data<int32>(),
+ EXPECT_EQ(lit_dim0major.element_count(), 12);
+ EXPECT_THAT(lit_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout when using Relayout to column major.
- auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
- EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
+ auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_);
+ EXPECT_THAT(relaid_lit_to_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
}
TEST_F(LiteralUtilTest, SliceR0S32) {
auto input = LiteralUtil::CreateR0<int32>(1);
- auto result = input->Slice({}, {});
- EXPECT_EQ(*input, *result);
+ auto result = input.Slice({}, {});
+ EXPECT_EQ(input, result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
- auto result = input->Slice({3}, {4});
+ auto result = input.Slice({3}, {4});
auto expected = LiteralUtil::CreateR1<float>({4.0});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto result = input_3x4->Slice({0, 2}, {2, 4});
+ auto result = input_3x4.Slice({0, 2}, {2, 4});
auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
- EXPECT_EQ(*input_2x3x2, *result);
+ auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
+ EXPECT_EQ(input_2x3x2, result);
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
auto expected = LiteralUtil::CreateR1<int64>({77});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output(ShapeUtil::MakeShape(U64, {2}));
output.PopulateR1<uint64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
Literal output(ShapeUtil::MakeShape(C64, {1}));
output.PopulateR1<complex64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR2C64) {
@@ -785,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
@@ -793,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR0<bfloat16>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
@@ -801,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
@@ -809,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output(ShapeUtil::MakeShape(F32, {}));
output.PopulateWithValue<float>(2.5f);
auto expected = LiteralUtil::CreateR0<float>(2.5f);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {3}));
output.PopulateWithValue<int64>(-7);
auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
output.PopulateWithValue<uint64>(42);
auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
@@ -838,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
output.PopulateWithValue<complex64>({4, 2});
auto expected =
LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
@@ -846,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
half h(0.25f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR0<half>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
@@ -854,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
half h(0.5f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR1<half>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
@@ -862,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
half h(2.0f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
auto input = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto output = input->Replicate<uint32>(3);
+ auto output = input.Replicate<uint32>(3);
auto expected = LiteralUtil::CreateR3<uint32>(
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
- EXPECT_EQ(*output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, CopySliceFrom) {
@@ -889,17 +885,17 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
const int64 step[] = {1, 1, 1, 1};
uint32 seqnr = 0;
auto init_proc = [&](absl::Span<const int64> indexes) {
- source->Set(indexes, ++seqnr);
+ source.Set(indexes, ++seqnr);
return true;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
init_proc);
auto blank = Literal::CreateFromShape(shape);
const int64 src_base[] = {3, 1, 5, 7};
const int64 dest_base[] = {6, 4, 12, 2};
const int64 copy_size[] = {7, 8, 11, 9};
- TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
+ TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
@@ -911,12 +907,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
blank_indexes.begin(), std::plus<int64>());
- auto bval = blank->Get<uint32>(blank_indexes);
- matched = (bval != 0 && bval == source->Get<uint32>(source_indexes));
+ auto bval = blank.Get<uint32>(blank_indexes);
+ matched = (bval != 0 && bval == source.Get<uint32>(source_indexes));
return matched;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
check_proc);
EXPECT_TRUE(matched);
}
@@ -925,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
TEST_F(LiteralUtilTest, CopyFromScalars) {
auto zero = LiteralUtil::CreateR0<uint32>(0);
auto nine = LiteralUtil::CreateR0<uint32>(9);
- TF_EXPECT_OK(zero->CopyFrom(*nine));
- EXPECT_EQ(*zero, *nine);
+ TF_EXPECT_OK(zero.CopyFrom(nine));
+ EXPECT_EQ(zero, nine);
auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
- TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
- EXPECT_EQ(zero->Get<uint32>({}), 17);
- TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
- EXPECT_EQ(vect->Get<uint32>({4}), 17);
+ TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {}));
+ EXPECT_EQ(zero.Get<uint32>({}), 17);
+ TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {}));
+ EXPECT_EQ(vect.Get<uint32>({4}), 17);
}
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
@@ -945,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
const auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
- EXPECT_EQ(*nine, *const_nine);
+ TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0}));
+ EXPECT_EQ(nine, const_nine);
}
{
// Copy 0 element to destination with zero elements.
- const auto empty = Literal::CreateFromShape(empty_r1_shape);
+ auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
- EXPECT_EQ(*empty, *const_empty);
+ TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0}));
+ EXPECT_EQ(empty, const_empty);
}
}
@@ -969,74 +965,75 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) {
TEST_F(LiteralUtilTest, CopyFromArrays) {
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*scalar_42, *scalar_123);
- TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*scalar_42, *scalar_123);
- EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
+ EXPECT_NE(scalar_42, scalar_123);
+ TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(scalar_42, scalar_123);
+ EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);
auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
- EXPECT_NE(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
- TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
+ EXPECT_NE(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
+ TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
}
TEST_F(LiteralUtilTest, CopyFromTuples) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {matrix.get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get()});
+ Literal inner_elements[] = {LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0})};
+ Literal inner_tuple = LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal});
+ Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});
// Create a tuple the same shape as the inner tuple of nested_tuple but with
// different values..
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(-5).get(),
- LiteralUtil::CreateR1<double>({2.0, 4.0}).get(), &nil_literal});
+ Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
+ Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
+ Literal tuple =
+ LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);
// Overwrite the inner tuple element of nested_tuple with the contents of
// 'tuple'.
- TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{}));
+ TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{}));
// The matrix element should be unchanged.
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
// The tuple element should have been copied from 'tuple'.
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
- auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(-2).get(),
- LiteralUtil::CreateR0<int32>(4).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<int32>(-2),
+ LiteralUtil::CreateR0<int32>(4)};
+ Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);
// Copy from one element to the other.
- TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{0}));
+ TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{0}));
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
}
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
- Status status = matrix->CopyFrom(*vector);
+ Status status = matrix.CopyFrom(vector);
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Destination subshape incompatible"));
@@ -1046,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) {
// Verify that the internal data views are consistent and that they
// are in little endian format
// TODO - modify if we make the data format machine endianess dependent
- auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
- Literal* l1 = m1.get();
- const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
+ Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
+ const char* d1 = reinterpret_cast<const char*>(m1.data<half>().data());
EXPECT_EQ(d1[0], 0);
EXPECT_EQ(d1[1], 0);
EXPECT_EQ(d1[2], 0);
@@ -1061,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) {
half h1(1.0f);
half h2(2.0f);
auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l2 = m2.get();
- const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
+ const char* d2 = reinterpret_cast<const char*>(m2.data<half>().data());
EXPECT_EQ(d2[0], 0);
EXPECT_EQ(d2[1], 0x3C);
EXPECT_EQ(d2[2], 0);
@@ -1091,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->Populate<uint32>(generator));
+ TF_EXPECT_OK(literal.Populate<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](absl::Span<const int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1133,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->PopulateParallel<uint32>(generator));
+ TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](absl::Span<const int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1170,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
// clang-format on
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->Convert(U32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));
- EXPECT_EQ(*expected, *converted);
+ EXPECT_EQ(expected, converted);
}
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
@@ -1245,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
// clang-format on
- std::unique_ptr<Literal> conv;
+ Literal conv;
- conv = s8->Convert(U32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u32);
+ conv = s8.Convert(U32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u32);
- conv = s8->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = s8.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s8->Convert(U64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u64);
+ conv = s8.Convert(U64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u64);
- conv = s8->Convert(S64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s64);
+ conv = s8.Convert(S64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s64);
- conv = s8->Convert(PRED).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *pred);
+ conv = s8.Convert(PRED).ConsumeValueOrDie();
+ EXPECT_EQ(conv, pred);
- conv = bf16->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = bf16.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = bf16->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = bf16.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = pred->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *int32_pred);
+ conv = pred.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, int32_pred);
- conv = f32->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f32.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = f64->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f64.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s32->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = s32.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = f32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = f64->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f64.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = s32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = u32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = u32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = s32.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- conv = f16->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = f16.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- EXPECT_EQ(s32->Convert(TUPLE).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(S16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(U16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(F32).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(S32).status().code(),
+ EXPECT_EQ(s32.Convert(TUPLE).status().code(),
tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
}
TEST_F(LiteralUtilTest, BitcastConvert) {
@@ -1317,13 +1307,12 @@ TEST_F(LiteralUtilTest, BitcastConvert) {
tensorflow::bit_cast<uint32>(100.f), 0xbeef});
auto expected = LiteralUtil::CreateR1<float>(
{2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->BitcastConvert(F32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
}
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
auto literal = LiteralUtil::CreateR0<uint32>(1234);
- Status status = literal->BitcastConvert(F64).status();
+ Status status = literal.BitcastConvert(F64).status();
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(
absl::StrContains(status.error_message(), "bit widths are different"));
@@ -1341,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
p.add_preds((i % 2) == (len % 2));
}
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- ASSERT_EQ(len, literal->data<bool>().size());
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ ASSERT_EQ(len, literal.data<bool>().size());
int i = 0;
- for (bool value : literal->data<bool>()) {
+ for (bool value : literal.data<bool>()) {
EXPECT_EQ((i % 2) == (len % 2), value);
++i;
}
@@ -1358,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
half h2(2.0f);
auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l = m.get();
- EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
- EXPECT_EQ(4, l->data<half>().size());
+ EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
+ EXPECT_EQ(4, m.data<half>().size());
- LiteralProto p = l->ToProto();
+ LiteralProto p = m.ToProto();
EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
EXPECT_EQ(8, p.f16s().size());
const char* d = p.f16s().data();
@@ -1389,9 +1376,8 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
p.clear_f16s();
p.set_f16s(half_vals, 8);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- auto r = literal->data<half>();
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ auto r = literal.data<half>();
ASSERT_EQ(4, r.size());
EXPECT_EQ(h1, r[0]);
EXPECT_EQ(h2, r[1]);
@@ -1402,43 +1388,41 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
TEST_F(LiteralUtilTest, LiteralSliceTest) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
- EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
- EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+ EXPECT_EQ(LiteralSlice(scalar, {}), scalar);
+ EXPECT_EQ(LiteralSlice(matrix, {}), matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple);
EXPECT_EQ(LiteralSlice(nil, {}), nil);
- EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
- EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
+ EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 1.0f);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 1.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
1.0f);
- nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 555.0f);
+ nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 555.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
555.0f);
@@ -1447,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
EXPECT_EQ(matrix_view,
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
@@ -1497,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
}
TEST_F(LiteralUtilTest, LiteralMove) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- Literal literal(std::move(*matrix));
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal(std::move(matrix));
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1511,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) {
TEST_F(LiteralUtilTest, DecomposeTuple) {
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get(),
- &nil_literal});
-
- EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
- std::vector<Literal> elements = nested_tuple->DecomposeTuple();
- EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
+ Literal inner_elements[] = {
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}),
+ };
+ Literal tuple_elements[] = {
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
+ LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal}),
+ };
+ Literal nested_tuple = LiteralUtil::MakeTuple(
+ {&tuple_elements[0], &tuple_elements[1], &nil_literal});
+
+ EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape()));
+ std::vector<Literal> elements = nested_tuple.DecomposeTuple();
+ EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape()));
ASSERT_EQ(elements.size(), 3);
@@ -1552,13 +1539,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
TEST_F(LiteralUtilTest, MoveIntoTuple) {
std::vector<Literal> elements;
- elements.push_back(std::move(*LiteralUtil::CreateR0<float>(1.0)));
- elements.push_back(std::move(*LiteralUtil::CreateR1<int32>({4, 8})));
- elements.push_back(std::move(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get()})
-
- ));
+ elements.push_back(LiteralUtil::CreateR0<float>(1.0));
+ elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
+ std::vector<Literal> inner_elements;
+ inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
+ inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
+ elements.push_back(
+ LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));
Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
@@ -1586,9 +1573,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
Literal literal;
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- literal = std::move(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ literal = std::move(matrix);
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1599,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
}
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- const auto matrix_view = LiteralSlice(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ const auto matrix_view = LiteralSlice(matrix);
LiteralSlice matrix_view_copy(matrix_view);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
@@ -1611,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) {
}
TEST_F(LiteralUtilTest, GetSetTuple) {
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(42.0).get(),
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
- tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
-
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
- 3.0);
- tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+ Literal elements[] = {
+ LiteralUtil::CreateR0<float>(42.0),
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ };
+ auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
+ tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
+
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
+ tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
-4.0);
}
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
// Literals constructed using CreateFromShape should be zero initialized.
- std::unique_ptr<Literal> scalar_f32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
- EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
- EXPECT_TRUE(scalar_f32->IsAll(0));
-
- std::unique_ptr<Literal> vector_s32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
- EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
- EXPECT_TRUE(vector_s32->IsAll(0));
-
- std::unique_ptr<Literal> tuple =
- Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
- ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
-
- EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
- EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
- EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
- EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
+ Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
+ EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
+ EXPECT_TRUE(scalar_f32.IsAll(0));
+
+ Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
+ EXPECT_EQ(vector_s32.Get<int32>({0}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({1}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({2}), 0);
+ EXPECT_TRUE(vector_s32.IsAll(0));
+
+ Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
+ ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
+
+ EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
+ EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
+ EXPECT_EQ(tuple.Get<bool>({1}, {1}), false);
+ EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
}
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
@@ -1657,6 +1640,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+ auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
@@ -1665,25 +1649,27 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto matrix_pred =
LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
auto tuple = LiteralUtil::MakeTuple(
- {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
+ {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
+ auto nested_tuple =
+ LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});
auto to_from_proto = [](const Literal& literal) -> Literal {
- return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
+ return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
};
- EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
- EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
- EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
- EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
- EXPECT_EQ(*tuple, to_from_proto(*tuple));
- EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
+ EXPECT_EQ(one_f32, to_from_proto(one_f32));
+ EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
+ EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
+ EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
+ EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
+ EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
+ EXPECT_EQ(tuple, to_from_proto(tuple));
+ EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
- EXPECT_NE(*one_f32, *two_f32);
- EXPECT_NE(*one_f32, to_from_proto(*two_f32));
+ EXPECT_NE(one_f32, two_f32);
+ EXPECT_NE(one_f32, to_from_proto(two_f32));
}
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
@@ -1802,11 +1788,11 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
TEST_F(LiteralUtilTest, SortSparseElements) {
auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
SparseIndexArray(10, 3), {});
- literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
- literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
- literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
- literal->SortSparseElements();
- EXPECT_EQ(literal->ToString(false),
+ literal.AppendSparseElement<float>({2, 3, 4}, 2.0);
+ literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
+ literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
+ literal.SortSparseElements();
+ EXPECT_EQ(literal.ToString(false),
"f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
@@ -1816,57 +1802,54 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
EXPECT_EQ(
LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
"false");
EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(int64{2}));
EXPECT_EQ(
LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(double{2.0}));
EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
{half{1.0}, half{2.0}, half{3.0}})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(static_cast<float>(half{2.0})));
EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
dimensions, indices,
std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{0}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{0}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{1}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{1}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(9);
+ Literal literal = LiteralUtil::CreateR0<int32>(9);
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
- /*dimensions=*/{}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+ /*dimensions=*/{}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
}
} // namespace