aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 050fceca9c..eeabc61ec8 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -609,6 +610,25 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
UnorderedElementsAre(fusion.get(), exp1.get()));
}
+TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
+ // Create a chain of fused unary ops.
+ auto constant =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto exp1 =
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
+ auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
+ OpMetadata metadata;
+ metadata.set_op_name("tf_op");
+ exp1->set_metadata(metadata);
+ exp2->set_metadata(metadata);
+
+ auto fusion = HloInstruction::CreateFusion(
+ r0f32_, HloInstruction::FusionKind::kLoop, exp2.get());
+ auto* fused = fusion->FuseInstruction(exp1.get());
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata()));
+}
+
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});