aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-06-07 17:19:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 17:23:31 -0700
commit138e790ab9cb778430168d2b5f6abac1501aa2d8 (patch)
tree2d79d5a45010e4faec8155d0cb7f14856d794056 /tensorflow
parent3bb7a913be6ba47df6fb1796dd8ce639cdbf1608 (diff)
[XLA] Handle kSlice correctly in HloCostAnalysis
Slice doesn't read the entire input. It only reads enough to make the output. PiperOrigin-RevId: 199722987
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc15
2 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 94c9c7eabc..b9d30ee802 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -172,7 +172,8 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleSlice(const HloInstruction*) {
+Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
+ current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2;
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 16fdda8a8b..72adf09c83 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -460,5 +460,20 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
EXPECT_EQ(analysis.flop_count(), 1472);
}
+TEST_F(HloCostAnalysisTest, Slice) {
+ // Test the analysis on a slice.
+ XlaBuilder builder("slice");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x");
+ auto slice = builder.Slice(x, {0}, {1}, {1});
+ auto hlo_module = BuildHloGraph(&builder);
+
+ // Run HLO cost analysis.
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ EXPECT_EQ(analysis.bytes_accessed(), 8);
+}
+
} // namespace
} // namespace xla