diff options
-rw-r--r-- | tensorflow/compiler/xla/tests/copy_test.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_test_base.h | 28 |
2 files changed, 30 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 1e5f00c30d..6b28d9b032 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -184,10 +184,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { auto hlo_module = MakeUnique<HloModule>("test_module"); hlo_module->AddEntryComputation(std::move(computation)); - *hlo_module->mutable_entry_computation_layout()->mutable_result_layout() = - ShapeLayout(ShapeUtil::MakeShapeWithLayout( - constant->shape().element_type(), - AsInt64Slice(constant->shape().dimensions()), {1, 2, 0})); + ForceResultLayout(hlo_module.get(), LayoutUtil::MakeLayout({1, 2, 0})); std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -222,13 +219,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( auto hlo_module = MakeUnique<HloModule>("test_module"); hlo_module->AddEntryComputation(std::move(computation)); - *hlo_module->mutable_entry_computation_layout()->mutable_result_layout() = - ShapeLayout(ShapeUtil::MakeShapeWithLayout( - constant->shape().element_type(), - AsInt64Slice(constant->shape().dimensions()), ({ - std::vector<int64> p(permutation.rbegin(), permutation.rend()); - p; - }))); + ForceResultLayout(hlo_module.get(), LayoutUtil::MakeLayout(permutation)); std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(hlo_module), {}); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index a401e13228..dcf7bad2f9 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -66,6 +66,34 @@ class HloTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> arguments); + // Convenience method to force the layout of a given parameter in a module. + // The layout of parameter number 'param_no' in the 'module' is set to + // 'layout'. + void ForceParameterLayout(HloModule* module, int64 param_no, + const Layout& layout) { + ASSERT_LT(param_no, + module->mutable_entry_computation_layout()->parameter_count()); + module->mutable_entry_computation_layout() + ->mutable_parameter_layout(param_no) + ->ResetLayout(layout); + } + + // Convenience method to force the layout of the computation result in a + // module. The result layout of 'module' is set to 'layout'. + void ForceResultLayout(HloModule* module, const Layout& layout) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout); + } + + // Convenience method to clear the layout of the computation result in + // 'module'. + void ForceClearResultLayout(HloModule* module) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->Clear(); + } + string TestName() const; std::unique_ptr<Backend> backend_; |