From d4756a5cf768408c6e94fc79dbbe0de5d8e00fb9 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Fri, 26 May 2017 15:35:03 -0700 Subject: Move layout-forcing methods onto HloTestBase where all HLO tests can use them PiperOrigin-RevId: 157271246 --- tensorflow/compiler/xla/tests/copy_test.cc | 13 ++----------- tensorflow/compiler/xla/tests/hlo_test_base.h | 28 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 1e5f00c30d..6b28d9b032 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -184,10 +184,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { auto hlo_module = MakeUnique("test_module"); hlo_module->AddEntryComputation(std::move(computation)); - *hlo_module->mutable_entry_computation_layout()->mutable_result_layout() = - ShapeLayout(ShapeUtil::MakeShapeWithLayout( - constant->shape().element_type(), - AsInt64Slice(constant->shape().dimensions()), {1, 2, 0})); + ForceResultLayout(hlo_module.get(), LayoutUtil::MakeLayout({1, 2, 0})); std::unique_ptr result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -222,13 +219,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( auto hlo_module = MakeUnique("test_module"); hlo_module->AddEntryComputation(std::move(computation)); - *hlo_module->mutable_entry_computation_layout()->mutable_result_layout() = - ShapeLayout(ShapeUtil::MakeShapeWithLayout( - constant->shape().element_type(), - AsInt64Slice(constant->shape().dimensions()), ({ - std::vector p(permutation.rbegin(), permutation.rend()); - p; - }))); + ForceResultLayout(hlo_module.get(), LayoutUtil::MakeLayout(permutation)); std::unique_ptr result = ExecuteAndTransfer(std::move(hlo_module), {}); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index a401e13228..dcf7bad2f9 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -66,6 +66,34 @@ class HloTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); + // Convenience method to force the layout of a given parameter in a module. + // The layout of parameter number 'param_no' in the 'module' is set to + // 'layout'. + void ForceParameterLayout(HloModule* module, int64 param_no, + const Layout& layout) { + ASSERT_LT(param_no, + module->mutable_entry_computation_layout()->parameter_count()); + module->mutable_entry_computation_layout() + ->mutable_parameter_layout(param_no) + ->ResetLayout(layout); + } + + // Convenience method to force the layout of the computation result in a + // module. The result layout of 'module' is set to 'layout'. + void ForceResultLayout(HloModule* module, const Layout& layout) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout); + } + + // Convenience method to clear the layout of the computation result in + // 'module'. + void ForceClearResultLayout(HloModule* module) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->Clear(); + } + string TestName() const; std::unique_ptr backend_; -- cgit v1.2.3