diff options
author | 2018-09-13 15:27:12 -0700 | |
---|---|---|
committer | 2018-09-13 15:31:33 -0700 | |
commit | ea52ecd836098e0b1d37325cf1b91133f908547e (patch) | |
tree | 324efc21342c114d3553422eae8a7838f4cfb370 /tensorflow | |
parent | eb7953970c8b2b8a054cddf8ed4b78e66fcd2d02 (diff) |
Fix bug in kSlice implementation in evaluator.
Slice was producing a literal with a default layout rather than the layout of
the slice HLO instruction. This resulted in errors when the produced literal
was consumed by later operations.
PiperOrigin-RevId: 212889334
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 19 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 16 |
3 files changed, 28 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 064b86493d..06b6d5b559 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1339,6 +1339,12 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) { Status HloEvaluator::Postprocess(HloInstruction* hlo) { VLOG(2) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); + // Out of convenience the literal may have been produced with a different + // layout. Relayout as indicated by the HLO instruction. + if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), + hlo->shape())) { + evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 16411eb078..01e88566a5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -2570,6 +2570,25 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); } +TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { + // Regression test for b/114735354. + const string hlo_text = R"( +HloModule SliceWithDifferentLayout + +ENTRY main { + arg = f32[2,2,2]{0,1,2} parameter(0) + ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR3WithLayout<float>( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + LayoutUtil::MakeLayout({0, 1, 2})); + Literal actual = Evaluate({&arg}); + EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 7f090a52db..8fb17a0033 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -249,12 +249,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = result.Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -265,11 +260,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = result.Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -2350,8 +2341,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get<ReturnT>(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); + Literal result(shape); TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); |