diff options
author | 2018-06-07 17:19:25 -0700 | |
---|---|---|
committer | 2018-06-07 17:23:31 -0700 | |
commit | 138e790ab9cb778430168d2b5f6abac1501aa2d8 (patch) | |
tree | 2d79d5a45010e4faec8155d0cb7f14856d794056 /tensorflow | |
parent | 3bb7a913be6ba47df6fb1796dd8ce639cdbf1608 (diff) |
[XLA] Handle kSlice correctly in HloCostAnalysis
Slice doesn't read the entire input. It only reads enough to make the output.
PiperOrigin-RevId: 199722987
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc | 15 |
2 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 94c9c7eabc..b9d30ee802 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -172,7 +172,8 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { + current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 16fdda8a8b..72adf09c83 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -460,5 +460,20 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { EXPECT_EQ(analysis.flop_count(), 1472); } +TEST_F(HloCostAnalysisTest, Slice) { + // Test the analysis on a slice. + XlaBuilder builder("slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.Slice(x, {0}, {1}, {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + } // namespace } // namespace xla |