aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-09-06 16:40:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 16:46:11 -0700
commitc4df798540b83026ccc74d69da38960e43af8f55 (patch)
tree73945b9307eabdfef9604a0a27a9ecae8fa2cc37 /tensorflow/compiler/xla/service
parent7caba396ba81a0a19efd92a01aa7a3b695e3009b (diff)
[XLA] Handle kDomain in HloCostAnalysis.
PiperOrigin-RevId: 211891325
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc26
3 files changed, 34 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 939b5114c3..8b4eaad82e 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
+ // Domain does not have any computation or data transfer.
+ current_should_compute_bottleneck_time_ = false;
+ current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& rhs_shape = dot->operand(1)->shape();
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 9bb3f12ee2..46b4bbeef2 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleRecvDone(const HloInstruction* recv_done) override;
Status HandleConvert(const HloInstruction* convert) override;
Status HandleCopy(const HloInstruction* copy) override;
+ Status HandleDomain(const HloInstruction* domain) override;
Status HandleDot(const HloInstruction* dot) override;
Status HandleConvolution(const HloInstruction* convolution) override;
Status HandleFft(const HloInstruction* fft) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 2c854eea18..15a5f8374d 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -415,7 +415,7 @@ TEST_F(FusionCostAnalysis, NoLayout) {
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{
- XlaBuilder builder("matmul");
+ XlaBuilder builder("tuple");
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
Tuple(&builder, {x, y});
@@ -430,6 +430,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}
+using DomainCostAnalysis = HloTestBase;
+TEST_F(DomainCostAnalysis, DomainCost) {
+ HloCostAnalysis analysis(ShapeSize);
+
+ HloComputation::Builder builder("domain");
+ auto x = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {123}), "x"));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
+ auto domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
+ ASSERT_IS_OK(domain->Accept(&analysis));
+
+ EXPECT_EQ(analysis.flop_count(*domain), 0);
+ EXPECT_EQ(analysis.transcendental_count(*domain), 0);
+ EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
+}
+
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
XlaBuilder builder("BaseDilatedConvolution");
auto input = Parameter(