diff options
author | 2017-09-21 15:05:33 -0700 | |
---|---|---|
committer | 2017-09-21 15:08:28 -0700 | |
commit | 36647440d2e62cb494e4e6f6d5d9144ceb0b29c7 (patch) | |
tree | d97c715ef6c79b205442f254679b1ffa03be94e4 /tensorflow/compiler/xla/literal_util_test.cc | |
parent | 57498a86c11dfc98dda84dc7318a3c84c85c6791 (diff) |
Add methods to convert between Literals and ShapedBuffers to LocalClient. These "conversion" methods copy the data to/from the device into/from literals.
Also fix various issues I noticed along the way:
* Move LocalClient tests into open source.
* Add proper == operator to Literals
* Add << overload for streaming Literals to output.
* Add Literal::GetSubliteral methods.
* Remove unused AllocatBufferOnDevice methods from LocalClient and LocalService.
PiperOrigin-RevId: 169606342
Diffstat (limited to 'tensorflow/compiler/xla/literal_util_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal_util_test.cc | 134 |
1 files changed, 76 insertions, 58 deletions
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 61ceac4f9a..e7dedd0821 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -257,37 +257,37 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { } TEST_F(LiteralUtilTest, ScalarEquality) { - // Test Literal::Equal with scalars. + // Test equality with scalars. auto f32_42 = Literal::CreateR0<float>(42.0); auto f32_42_clone = Literal::CreateR0<float>(42.0); - EXPECT_TRUE(f32_42->Equal(*f32_42)); - EXPECT_TRUE(f32_42->Equal(*f32_42_clone)); + EXPECT_EQ(*f32_42, *f32_42); + EXPECT_EQ(*f32_42, *f32_42_clone); auto f32_123 = Literal::CreateR0<float>(123.0); - EXPECT_FALSE(f32_42->Equal(*f32_123)); + EXPECT_NE(*f32_42, *f32_123); auto f64_42 = Literal::CreateR0<double>(42.0); - EXPECT_FALSE(f32_42->Equal(*f64_42)); + EXPECT_NE(*f32_42, *f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { - // Test Literal::Equal with nonscalars. + // Test equality with nonscalars. auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_clone = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}}); auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0}); auto scalar = Literal::CreateR0<float>(1.0); - EXPECT_TRUE(matrix->Equal(*matrix)); - EXPECT_TRUE(matrix->Equal(*matrix_clone)); - EXPECT_FALSE(matrix->Equal(*matrix_different)); - EXPECT_FALSE(matrix->Equal(*vector_literal)); - EXPECT_FALSE(matrix->Equal(*scalar)); + EXPECT_EQ(*matrix, *matrix); + EXPECT_EQ(*matrix, *matrix_clone); + EXPECT_NE(*matrix, *matrix_different); + EXPECT_NE(*matrix, *vector_literal); + EXPECT_NE(*matrix, *scalar); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { - // Test Literal::Equal with literals which have different layouts. + // Test equality with literals which have different layouts. auto colmajor = MakeUnique<Literal>(); *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -306,11 +306,11 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) { rowmajor->Set<float>({1, 0}, 3.0); rowmajor->Set<float>({1, 1}, 4.0); - EXPECT_TRUE(rowmajor->Equal(*colmajor)); + EXPECT_EQ(*rowmajor, *colmajor); } TEST_F(LiteralUtilTest, TupleEquality) { - // Test Literal::Equal with tuples. + // Test equality with tuples. auto scalar = Literal::CreateR0<float>(1.0); auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); @@ -319,16 +319,16 @@ TEST_F(LiteralUtilTest, TupleEquality) { // tuple, the other is a clone of the element in the original tuple. auto scalar_clone = Literal::CreateR0<float>(1.0); auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_TRUE(tuple1->Equal(*tuple2)); + EXPECT_EQ(*tuple1, *tuple2); // Tuple with elements reversed. auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_FALSE(tuple1->Equal(*reversed_tuple)); + EXPECT_NE(*tuple1, *reversed_tuple); // Tuple with different value. auto scalar_42 = Literal::CreateR0<float>(42.0); auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_FALSE(tuple1->Equal(*different_tuple)); + EXPECT_NE(*tuple1, *different_tuple); } TEST_F(LiteralUtilTest, IsAllTuple) { @@ -348,7 +348,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeTuple) { auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto x = Literal::CreateFromShape(tuple->shape()); - EXPECT_TRUE(tuple->Equal(*x)); + EXPECT_EQ(*tuple, *x); } TEST_F(LiteralUtilTest, IsAll) { @@ -439,17 +439,17 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { auto data01 = data->Relayout(layout01); EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_TRUE(data->Equal(*data01)); + EXPECT_EQ(*data, *data01); auto data10 = data->Relayout(layout10); EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_TRUE(data->Equal(*data10)); + EXPECT_EQ(*data, *data10); } TEST_F(LiteralUtilTest, ReshapeR0) { auto original = Literal::CreateR0<float>(1.7f); auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); - EXPECT_TRUE(original->Equal(*reshape)); + EXPECT_EQ(*original, *reshape); } TEST_F(LiteralUtilTest, ReshapeR4) { @@ -469,7 +469,7 @@ TEST_F(LiteralUtilTest, ReshapeR4) { // clang-format on auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(expected->Equal(*reshape)); + EXPECT_EQ(*expected, *reshape); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { @@ -489,13 +489,13 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { // clang-format on auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(expected->Equal(*reshape)); + EXPECT_EQ(*expected, *reshape); } TEST_F(LiteralUtilTest, TransposeR0) { auto original = Literal::CreateR0<float>(1.7f); auto reshape = original->Transpose(/*permutation=*/{}); - EXPECT_TRUE(original->Equal(*reshape)); + EXPECT_EQ(*original, *reshape); } TEST_F(LiteralUtilTest, TransposeR4) { @@ -521,13 +521,11 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // target layout in the first place. auto dim0minor_relaid_to_dim0major = literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); - EXPECT_TRUE( - literal_r4_2x2x3x3_dim0major_->Equal(*dim0minor_relaid_to_dim0major)); + EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major); auto dim0major_relaid_to_dim0minor = literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); - EXPECT_TRUE( - literal_r4_2x2x3x3_dim0minor_->Equal(*dim0major_relaid_to_dim0minor)); + EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { @@ -596,14 +594,14 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { TEST_F(LiteralUtilTest, SliceR0S32) { auto input = Literal::CreateR0<int32>(1); auto result = input->Slice({}, {}); - EXPECT_TRUE(input->Equal(*result)); + EXPECT_EQ(*input, *result); } TEST_F(LiteralUtilTest, SliceR1F32) { auto input = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0}); auto result = input->Slice({3}, {4}); auto expected = Literal::CreateR1<float>({4.0}); - EXPECT_TRUE(expected->Equal(*result)); + EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR2U32) { @@ -611,49 +609,49 @@ TEST_F(LiteralUtilTest, SliceR2U32) { Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); auto result = input_3x4->Slice({0, 2}, {2, 4}); auto expected = Literal::CreateR2<uint32>({{3, 4}, {7, 8}}); - EXPECT_TRUE(expected->Equal(*result)); + EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { auto input_2x3x2 = Literal::CreateR3<uint32>( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); - EXPECT_TRUE(input_2x3x2->Equal(*result)); + EXPECT_EQ(*input_2x3x2, *result); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output; output.PopulateR1<int64>({77}); auto expected = Literal::CreateR1<int64>({77}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR2U64) { Literal output; output.PopulateR1<uint64>({{77, 88}}); auto expected = Literal::CreateR1<uint64>({{77, 88}}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue<float>(2.5f, {}); auto expected = Literal::CreateR0<float>(2.5f); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output; output.PopulateWithValue<int64>(-7, {3}); auto expected = Literal::CreateR1<int64>({-7, -7, -7}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output; output.PopulateWithValue<uint64>(42, {2, 2}); auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { @@ -661,7 +659,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { half h(0.25f); output.PopulateWithValue<half>(h, {}); auto expected = Literal::CreateR0<half>(h); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { @@ -669,7 +667,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { half h(0.5f); output.PopulateWithValue<half>(h, {3}); auto expected = Literal::CreateR1<half>({h, h, h}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { @@ -677,7 +675,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { half h(2.0f); output.PopulateWithValue<half>(h, {2, 2}); auto expected = Literal::CreateR2<half>({{h, h}, {h, h}}); - EXPECT_TRUE(output.Equal(*expected)); + EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { @@ -688,7 +686,7 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) { {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_TRUE(output->Equal(*expected)); + EXPECT_EQ(*output, *expected); } TEST_F(LiteralUtilTest, Copy) { @@ -741,7 +739,7 @@ TEST_F(LiteralUtilTest, CopyScalars) { auto zero = Literal::CreateR0<uint32>(0); auto nine = Literal::CreateR0<uint32>(9); TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {})); - EXPECT_TRUE(zero->Equal(*nine)); + EXPECT_EQ(*zero, *nine); auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21}); TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {})); @@ -761,7 +759,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { auto nine = Literal::CreateR1<float>({9}); TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0})); - EXPECT_TRUE(nine->Equal(*const_nine)); + EXPECT_EQ(*nine, *const_nine); } { @@ -770,7 +768,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { auto nine = Literal::CreateR1<float>({9}); TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0})); - EXPECT_TRUE(empty->Equal(*const_empty)); + EXPECT_EQ(*empty, *const_empty); } } @@ -863,7 +861,7 @@ TEST_F(LiteralUtilTest, ConvertR4) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted, original->Convert(U32)); - EXPECT_TRUE(expected->Equal(*converted)); + EXPECT_EQ(*expected, *converted); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { @@ -925,43 +923,43 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { std::unique_ptr<Literal> conv; conv = s8->Convert(U32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*u32)); + EXPECT_EQ(*conv, *u32); conv = s8->Convert(S32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*s32)); + EXPECT_EQ(*conv, *s32); conv = s8->Convert(U64).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*u64)); + EXPECT_EQ(*conv, *u64); conv = s8->Convert(S64).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*s64)); + EXPECT_EQ(*conv, *s64); conv = s8->Convert(PRED).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*pred)); + EXPECT_EQ(*conv, *pred); conv = pred->Convert(S32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*int32_pred)); + EXPECT_EQ(*conv, *int32_pred); conv = f32->Convert(S32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*s32)); + EXPECT_EQ(*conv, *s32); conv = f64->Convert(S32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*s32)); + EXPECT_EQ(*conv, *s32); conv = s32->Convert(F32).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f32)); + EXPECT_EQ(*conv, *f32); conv = f32->Convert(F16).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f16)); + EXPECT_EQ(*conv, *f16); conv = f64->Convert(F16).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f16)); + EXPECT_EQ(*conv, *f16); conv = s32->Convert(F16).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f16)); + EXPECT_EQ(*conv, *f16); conv = u32->Convert(F16).ConsumeValueOrDie(); - EXPECT_TRUE(conv->Equal(*f16)); + EXPECT_EQ(*conv, *f16); EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); @@ -1045,5 +1043,25 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } +TEST_F(LiteralUtilTest, Subliterals) { + auto scalar = Literal::CreateR0<float>(1.0); + auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + + EXPECT_EQ(&scalar->GetSubliteral(/*index=*/{}), scalar.get()); + EXPECT_EQ(&matrix->GetSubliteral(/*index=*/{}), matrix.get()); + EXPECT_EQ(&tuple->GetSubliteral(/*index=*/{}), tuple.get()); + EXPECT_EQ(&nested_tuple->GetSubliteral(/*index=*/{}), nested_tuple.get()); + + EXPECT_EQ(tuple->GetSubliteral(/*index=*/{0}), *scalar); + EXPECT_EQ(tuple->GetSubliteral(/*index=*/{1}), *matrix); + + EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0}), *tuple); + EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 0}), *scalar); + EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 1}), *matrix); + EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{1}), *scalar); +} + } // namespace } // namespace xla |