aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-20 07:02:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-20 08:25:03 -0700
commit88b81ac944dff7e5c1fa820cc442717bbc75f62a (patch)
tree44a97a364590e63349847f7334f6dddeb05d9307
parent3a95e41426dd0745a772914815e23b997419848e (diff)
[XLA:HLO] Also clone metadata when cloning instructions e.g. in fusion.
Without the metadata, it's hard to correlate HLO instructions to TF ops after fusion. Change: 153709552
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc20
3 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 7ca2832ec7..1970b213c9 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -178,6 +178,7 @@ cc_test(
deps = [
":hlo",
"//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index d15b8236bb..1ede4e963f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -435,6 +435,7 @@ HloInstruction::CreateSelectAndScatter(
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
instruction->fusion_kind_ = fusion_kind;
instruction->set_parent(fused_root->parent());
+ instruction->set_metadata(fused_root->metadata());
instruction->CloneAndFuseInternal(fused_root);
instruction->CheckFusionInstruction();
return instruction;
@@ -858,6 +859,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix) {
CloneWithNewOperands(shape_, operands_);
clone->name_ = name() + "." + suffix;
clone->set_parent(parent());
+ clone->set_metadata(metadata_);
return clone;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 050fceca9c..eeabc61ec8 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -609,6 +610,25 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
UnorderedElementsAre(fusion.get(), exp1.get()));
}
+TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
+ // Create a chain of fused unary ops.
+ auto constant =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto exp1 =
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
+ auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
+ OpMetadata metadata;
+ metadata.set_op_name("tf_op");
+ exp1->set_metadata(metadata);
+ exp2->set_metadata(metadata);
+
+ auto fusion = HloInstruction::CreateFusion(
+ r0f32_, HloInstruction::FusionKind::kLoop, exp2.get());
+ auto* fused = fusion->FuseInstruction(exp1.get());
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata()));
+}
+
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});