diff options
author | Yuanzhong Xu <yuanzx@google.com> | 2018-09-06 16:40:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 16:46:11 -0700 |
commit | c4df798540b83026ccc74d69da38960e43af8f55 (patch) | |
tree | 73945b9307eabdfef9604a0a27a9ecae8fa2cc37 /tensorflow/compiler/xla/service | |
parent | 7caba396ba81a0a19efd92a01aa7a3b695e3009b (diff) |
[XLA] Handle kDomain in HloCostAnalysis.
PiperOrigin-RevId: 211891325
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis.h | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc | 26 |
3 files changed, 34 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 939b5114c3..8b4eaad82e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { + // Domain does not have any computation or data transfer. + current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 9bb3f12ee2..46b4bbeef2 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleRecvDone(const HloInstruction* recv_done) override; Status HandleConvert(const HloInstruction* convert) override; Status HandleCopy(const HloInstruction* copy) override; + Status HandleDomain(const HloInstruction* domain) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 2c854eea18..15a5f8374d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -415,7 +415,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - XlaBuilder builder("matmul"); + XlaBuilder builder("tuple"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); Tuple(&builder, {x, y}); @@ -430,6 +430,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +using DomainCostAnalysis = HloTestBase; +TEST_F(DomainCostAnalysis, DomainCost) { + HloCostAnalysis analysis(ShapeSize); + + HloComputation::Builder builder("domain"); + auto x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {123}), "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y")); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y})); + auto domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); + ASSERT_IS_OK(domain->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(*domain), 0); + EXPECT_EQ(analysis.transcendental_count(*domain), 0); + EXPECT_EQ(analysis.bytes_accessed(*domain), 0); +} + TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); auto input = Parameter( |