diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation_test.cc | 84 |
1 files changed, 44 insertions, 40 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index a8f3f0e9c2..e4c5470331 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include <set> -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -118,7 +118,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { // Test GetInstructionPostOrder for a computation with one instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); @@ -129,7 +129,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { // instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( @@ -144,7 +144,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { // Test GetInstructionPostOrder for a computation with a trace instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto trace = @@ -163,13 +163,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -181,11 +181,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -205,11 +205,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { // computation has multiple roots (dead code). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); // Add three disconnected add expressions. builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -256,7 +256,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { // Test that DeepCopyInstruction properly copies an array. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1<float>({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -268,9 +268,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { // Test that DeepCopyInstruction properly copies a tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1<float>({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -289,7 +289,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { // copy are specified. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1<float>({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}))); auto computation = builder.Build(); { @@ -314,9 +314,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { // specified by the given indices. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1<float>({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto computation = builder.Build(); @@ -375,7 +375,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { // Test that DeepCopyInstruction properly handles tokens which should not be // copied. auto builder = HloComputation::Builder(TestName()); - auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); @@ -388,9 +388,9 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { // Test that DeepCopyInstruction properly handles tokens which should not be // copied. auto builder = HloComputation::Builder(TestName()); - auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); auto module = CreateNewModule(); @@ -407,7 +407,7 @@ TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( @@ -433,7 +433,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { // twice. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto dead_negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -456,9 +456,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { TEST_F(HloComputationTest, CloneWithControlDependency) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -502,9 +502,9 @@ TEST_F(HloComputationTest, Reachability) { // There is a control dependency from 'add' to 'exp'. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto negate = builder.AddInstruction( @@ -607,13 +607,14 @@ TEST_F(HloComputationTest, Stringification) { auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(computation->ToString(options), - R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation = + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation); } TEST_F(HloComputationTest, StringificationIndent) { @@ -639,13 +640,14 @@ TEST_F(HloComputationTest, StringificationIndent) { auto options = HloPrintOptions().set_print_metadata(false).set_indent_amount(2); - EXPECT_EQ(computation->ToString(options), - R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation = + R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} - })"); + })"; + EXPECT_EQ(computation->ToString(options), expected_computation); } TEST_F(HloComputationTest, StringificationCanonical) { @@ -670,21 +672,23 @@ TEST_F(HloComputationTest, StringificationCanonical) { auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(computation->ToString(options), - R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation1 = + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation1); options = HloPrintOptions().Canonical(); - EXPECT_EQ(computation->ToString(options), R"(TransposeDot { + const string expected_computation2 = R"(TransposeDot { tmp_0 = f32[5,10]{1,0} parameter(0) tmp_1 = f32[20,10]{1,0} parameter(1) tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation2); } } // namespace |