aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal_util_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-09-21 15:05:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-21 15:08:28 -0700
commit36647440d2e62cb494e4e6f6d5d9144ceb0b29c7 (patch)
treed97c715ef6c79b205442f254679b1ffa03be94e4 /tensorflow/compiler/xla/literal_util_test.cc
parent57498a86c11dfc98dda84dc7318a3c84c85c6791 (diff)
Add methods to convert between Literals and ShapedBuffers to LocalClient. These "conversion" methods copy the data to/from the device into/from literals.
Also fix various issues I noticed along the way: * Move LocalClient tests into open source. * Add proper == operator to Literals * Add << overload for streaming Literals to output. * Add Literal::GetSubliteral methods. * Remove unused AllocatBufferOnDevice methods from LocalClient and LocalService. PiperOrigin-RevId: 169606342
Diffstat (limited to 'tensorflow/compiler/xla/literal_util_test.cc')
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc134
1 files changed, 76 insertions, 58 deletions
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 61ceac4f9a..e7dedd0821 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -257,37 +257,37 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
}
TEST_F(LiteralUtilTest, ScalarEquality) {
- // Test Literal::Equal with scalars.
+ // Test equality with scalars.
auto f32_42 = Literal::CreateR0<float>(42.0);
auto f32_42_clone = Literal::CreateR0<float>(42.0);
- EXPECT_TRUE(f32_42->Equal(*f32_42));
- EXPECT_TRUE(f32_42->Equal(*f32_42_clone));
+ EXPECT_EQ(*f32_42, *f32_42);
+ EXPECT_EQ(*f32_42, *f32_42_clone);
auto f32_123 = Literal::CreateR0<float>(123.0);
- EXPECT_FALSE(f32_42->Equal(*f32_123));
+ EXPECT_NE(*f32_42, *f32_123);
auto f64_42 = Literal::CreateR0<double>(42.0);
- EXPECT_FALSE(f32_42->Equal(*f64_42));
+ EXPECT_NE(*f32_42, *f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
- // Test Literal::Equal with nonscalars.
+ // Test equality with nonscalars.
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_clone = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
auto scalar = Literal::CreateR0<float>(1.0);
- EXPECT_TRUE(matrix->Equal(*matrix));
- EXPECT_TRUE(matrix->Equal(*matrix_clone));
- EXPECT_FALSE(matrix->Equal(*matrix_different));
- EXPECT_FALSE(matrix->Equal(*vector_literal));
- EXPECT_FALSE(matrix->Equal(*scalar));
+ EXPECT_EQ(*matrix, *matrix);
+ EXPECT_EQ(*matrix, *matrix_clone);
+ EXPECT_NE(*matrix, *matrix_different);
+ EXPECT_NE(*matrix, *vector_literal);
+ EXPECT_NE(*matrix, *scalar);
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
- // Test Literal::Equal with literals which have different layouts.
+ // Test equality with literals which have different layouts.
auto colmajor = MakeUnique<Literal>();
*colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
*colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
@@ -306,11 +306,11 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
rowmajor->Set<float>({1, 0}, 3.0);
rowmajor->Set<float>({1, 1}, 4.0);
- EXPECT_TRUE(rowmajor->Equal(*colmajor));
+ EXPECT_EQ(*rowmajor, *colmajor);
}
TEST_F(LiteralUtilTest, TupleEquality) {
- // Test Literal::Equal with tuples.
+ // Test equality with tuples.
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()});
@@ -319,16 +319,16 @@ TEST_F(LiteralUtilTest, TupleEquality) {
// tuple, the other is a clone of the element in the original tuple.
auto scalar_clone = Literal::CreateR0<float>(1.0);
auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()});
- EXPECT_TRUE(tuple1->Equal(*tuple2));
+ EXPECT_EQ(*tuple1, *tuple2);
// Tuple with elements reversed.
auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()});
- EXPECT_FALSE(tuple1->Equal(*reversed_tuple));
+ EXPECT_NE(*tuple1, *reversed_tuple);
// Tuple with different value.
auto scalar_42 = Literal::CreateR0<float>(42.0);
auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()});
- EXPECT_FALSE(tuple1->Equal(*different_tuple));
+ EXPECT_NE(*tuple1, *different_tuple);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
@@ -348,7 +348,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto x = Literal::CreateFromShape(tuple->shape());
- EXPECT_TRUE(tuple->Equal(*x));
+ EXPECT_EQ(*tuple, *x);
}
TEST_F(LiteralUtilTest, IsAll) {
@@ -439,17 +439,17 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
auto data01 = data->Relayout(layout01);
EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
- EXPECT_TRUE(data->Equal(*data01));
+ EXPECT_EQ(*data, *data01);
auto data10 = data->Relayout(layout10);
EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
- EXPECT_TRUE(data->Equal(*data10));
+ EXPECT_EQ(*data, *data10);
}
TEST_F(LiteralUtilTest, ReshapeR0) {
auto original = Literal::CreateR0<float>(1.7f);
auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie();
- EXPECT_TRUE(original->Equal(*reshape));
+ EXPECT_EQ(*original, *reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4) {
@@ -469,7 +469,7 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
// clang-format on
auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_TRUE(expected->Equal(*reshape));
+ EXPECT_EQ(*expected, *reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
@@ -489,13 +489,13 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
// clang-format on
auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_TRUE(expected->Equal(*reshape));
+ EXPECT_EQ(*expected, *reshape);
}
TEST_F(LiteralUtilTest, TransposeR0) {
auto original = Literal::CreateR0<float>(1.7f);
auto reshape = original->Transpose(/*permutation=*/{});
- EXPECT_TRUE(original->Equal(*reshape));
+ EXPECT_EQ(*original, *reshape);
}
TEST_F(LiteralUtilTest, TransposeR4) {
@@ -521,13 +521,11 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
// target layout in the first place.
auto dim0minor_relaid_to_dim0major =
literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
- EXPECT_TRUE(
- literal_r4_2x2x3x3_dim0major_->Equal(*dim0minor_relaid_to_dim0major));
+ EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
auto dim0major_relaid_to_dim0minor =
literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
- EXPECT_TRUE(
- literal_r4_2x2x3x3_dim0minor_->Equal(*dim0major_relaid_to_dim0minor));
+ EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
}
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
@@ -596,14 +594,14 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
TEST_F(LiteralUtilTest, SliceR0S32) {
auto input = Literal::CreateR0<int32>(1);
auto result = input->Slice({}, {});
- EXPECT_TRUE(input->Equal(*result));
+ EXPECT_EQ(*input, *result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
auto input = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
auto result = input->Slice({3}, {4});
auto expected = Literal::CreateR1<float>({4.0});
- EXPECT_TRUE(expected->Equal(*result));
+ EXPECT_EQ(*expected, *result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
@@ -611,49 +609,49 @@ TEST_F(LiteralUtilTest, SliceR2U32) {
Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
auto result = input_3x4->Slice({0, 2}, {2, 4});
auto expected = Literal::CreateR2<uint32>({{3, 4}, {7, 8}});
- EXPECT_TRUE(expected->Equal(*result));
+ EXPECT_EQ(*expected, *result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
auto input_2x3x2 = Literal::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
- EXPECT_TRUE(input_2x3x2->Equal(*result));
+ EXPECT_EQ(*input_2x3x2, *result);
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output;
output.PopulateR1<int64>({77});
auto expected = Literal::CreateR1<int64>({77});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR2U64) {
Literal output;
output.PopulateR1<uint64>({{77, 88}});
auto expected = Literal::CreateR1<uint64>({{77, 88}});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output;
output.PopulateWithValue<float>(2.5f, {});
auto expected = Literal::CreateR0<float>(2.5f);
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output;
output.PopulateWithValue<int64>(-7, {3});
auto expected = Literal::CreateR1<int64>({-7, -7, -7});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output;
output.PopulateWithValue<uint64>(42, {2, 2});
auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
@@ -661,7 +659,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
half h(0.25f);
output.PopulateWithValue<half>(h, {});
auto expected = Literal::CreateR0<half>(h);
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
@@ -669,7 +667,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
half h(0.5f);
output.PopulateWithValue<half>(h, {3});
auto expected = Literal::CreateR1<half>({h, h, h});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
@@ -677,7 +675,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
half h(2.0f);
output.PopulateWithValue<half>(h, {2, 2});
auto expected = Literal::CreateR2<half>({{h, h}, {h, h}});
- EXPECT_TRUE(output.Equal(*expected));
+ EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
@@ -688,7 +686,7 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) {
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
- EXPECT_TRUE(output->Equal(*expected));
+ EXPECT_EQ(*output, *expected);
}
TEST_F(LiteralUtilTest, Copy) {
@@ -741,7 +739,7 @@ TEST_F(LiteralUtilTest, CopyScalars) {
auto zero = Literal::CreateR0<uint32>(0);
auto nine = Literal::CreateR0<uint32>(9);
TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {}));
- EXPECT_TRUE(zero->Equal(*nine));
+ EXPECT_EQ(*zero, *nine);
auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {}));
@@ -761,7 +759,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
auto nine = Literal::CreateR1<float>({9});
TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0}));
- EXPECT_TRUE(nine->Equal(*const_nine));
+ EXPECT_EQ(*nine, *const_nine);
}
{
@@ -770,7 +768,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
auto nine = Literal::CreateR1<float>({9});
TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0}));
- EXPECT_TRUE(empty->Equal(*const_empty));
+ EXPECT_EQ(*empty, *const_empty);
}
}
@@ -863,7 +861,7 @@ TEST_F(LiteralUtilTest, ConvertR4) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
original->Convert(U32));
- EXPECT_TRUE(expected->Equal(*converted));
+ EXPECT_EQ(*expected, *converted);
}
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
@@ -925,43 +923,43 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
std::unique_ptr<Literal> conv;
conv = s8->Convert(U32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*u32));
+ EXPECT_EQ(*conv, *u32);
conv = s8->Convert(S32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*s32));
+ EXPECT_EQ(*conv, *s32);
conv = s8->Convert(U64).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*u64));
+ EXPECT_EQ(*conv, *u64);
conv = s8->Convert(S64).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*s64));
+ EXPECT_EQ(*conv, *s64);
conv = s8->Convert(PRED).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*pred));
+ EXPECT_EQ(*conv, *pred);
conv = pred->Convert(S32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*int32_pred));
+ EXPECT_EQ(*conv, *int32_pred);
conv = f32->Convert(S32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*s32));
+ EXPECT_EQ(*conv, *s32);
conv = f64->Convert(S32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*s32));
+ EXPECT_EQ(*conv, *s32);
conv = s32->Convert(F32).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f32));
+ EXPECT_EQ(*conv, *f32);
conv = f32->Convert(F16).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f16));
+ EXPECT_EQ(*conv, *f16);
conv = f64->Convert(F16).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f16));
+ EXPECT_EQ(*conv, *f16);
conv = s32->Convert(F16).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f16));
+ EXPECT_EQ(*conv, *f16);
conv = u32->Convert(F16).ConsumeValueOrDie();
- EXPECT_TRUE(conv->Equal(*f16));
+ EXPECT_EQ(*conv, *f16);
EXPECT_EQ(s32->Convert(TUPLE).status().code(),
tensorflow::error::INVALID_ARGUMENT);
@@ -1045,5 +1043,25 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
ASSERT_EQ(h1, r[3]);
}
+TEST_F(LiteralUtilTest, Subliterals) {
+ auto scalar = Literal::CreateR0<float>(1.0);
+ auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+
+ EXPECT_EQ(&scalar->GetSubliteral(/*index=*/{}), scalar.get());
+ EXPECT_EQ(&matrix->GetSubliteral(/*index=*/{}), matrix.get());
+ EXPECT_EQ(&tuple->GetSubliteral(/*index=*/{}), tuple.get());
+ EXPECT_EQ(&nested_tuple->GetSubliteral(/*index=*/{}), nested_tuple.get());
+
+ EXPECT_EQ(tuple->GetSubliteral(/*index=*/{0}), *scalar);
+ EXPECT_EQ(tuple->GetSubliteral(/*index=*/{1}), *matrix);
+
+ EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0}), *tuple);
+ EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 0}), *scalar);
+ EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 1}), *matrix);
+ EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{1}), *scalar);
+}
+
} // namespace
} // namespace xla