aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index df66408022..6041debc4a 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -137,6 +137,28 @@ TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) {
EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
}
+TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) {
+ auto builder = HloComputation::Builder("GE");
+ auto param_1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ auto param_2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32_, "param1"));
+ auto ge = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
+ OpMetadata metadata;
+ metadata.set_op_name("x/y");
+ metadata.set_op_type("Y");
+ ge->set_metadata(metadata);
+ TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+ GraphDef graph_def = generator_.GetGraphDef();
+ EXPECT_EQ(graph_def.node_size(), 3);
+ EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
+ EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
+ EXPECT_EQ(graph_def.node(2).input_size(), 2);
+ EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to");
+ EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
+}
+
TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
// Create computations with a diamond-shaped callgraph.
auto negate_computation = CreateNegateComputation();