diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index df66408022..6041debc4a 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -137,6 +137,28 @@ TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); } +TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) { + auto builder = HloComputation::Builder("GE"); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "param1")); + auto ge = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); + OpMetadata metadata; + metadata.set_op_name("x/y"); + metadata.set_op_type("Y"); + ge->set_metadata(metadata); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 3); + EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); + EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); + EXPECT_EQ(graph_def.node(2).input_size(), 2); + EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to"); + EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); +} + TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. auto negate_computation = CreateNegateComputation(); |