diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction_test.cc | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 050fceca9c..eeabc61ec8 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -609,6 +610,25 @@ TEST_F(HloInstructionTest, ChainFusionOp) { UnorderedElementsAre(fusion.get(), exp1.get())); } +TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { + // Create a chain of fused unary ops. + auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)); + auto exp1 = + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); + auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); + OpMetadata metadata; + metadata.set_op_name("tf_op"); + exp1->set_metadata(metadata); + exp2->set_metadata(metadata); + + auto fusion = HloInstruction::CreateFusion( + r0f32_, HloInstruction::FusionKind::kLoop, exp2.get()); + auto* fused = fusion->FuseInstruction(exp1.get()); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata())); +} + TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); |