aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eli Bendersky <eliben@google.com>2017-05-26 15:35:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-26 15:38:54 -0700
commitd4756a5cf768408c6e94fc79dbbe0de5d8e00fb9 (patch)
tree5ef187d391ad17c0082c5a05e78af97932163f29
parent2be3ae8a1268c2745c4761665d8e2444294bca8b (diff)
Move layout-forcing methods onto HloTestBase where all HLO tests can use them
PiperOrigin-RevId: 157271246
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc13
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h28
2 files changed, 30 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 1e5f00c30d..6b28d9b032 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -184,10 +184,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
auto hlo_module = MakeUnique<HloModule>("test_module");
hlo_module->AddEntryComputation(std::move(computation));
- *hlo_module->mutable_entry_computation_layout()->mutable_result_layout() =
- ShapeLayout(ShapeUtil::MakeShapeWithLayout(
- constant->shape().element_type(),
- AsInt64Slice(constant->shape().dimensions()), {1, 2, 0}));
+ ForceResultLayout(hlo_module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
std::unique_ptr<Literal> result =
ExecuteAndTransfer(std::move(hlo_module), {});
@@ -222,13 +219,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(
auto hlo_module = MakeUnique<HloModule>("test_module");
hlo_module->AddEntryComputation(std::move(computation));
- *hlo_module->mutable_entry_computation_layout()->mutable_result_layout() =
- ShapeLayout(ShapeUtil::MakeShapeWithLayout(
- constant->shape().element_type(),
- AsInt64Slice(constant->shape().dimensions()), ({
- std::vector<int64> p(permutation.rbegin(), permutation.rend());
- p;
- })));
+ ForceResultLayout(hlo_module.get(), LayoutUtil::MakeLayout(permutation));
std::unique_ptr<Literal> result =
ExecuteAndTransfer(std::move(hlo_module), {});
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index a401e13228..dcf7bad2f9 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -66,6 +66,34 @@ class HloTestBase : public ::testing::Test {
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments);
+ // Convenience method to force the layout of a given parameter in a module.
+ // The layout of parameter number 'param_no' in the 'module' is set to
+ // 'layout'.
+ void ForceParameterLayout(HloModule* module, int64 param_no,
+ const Layout& layout) {
+ ASSERT_LT(param_no,
+ module->mutable_entry_computation_layout()->parameter_count());
+ module->mutable_entry_computation_layout()
+ ->mutable_parameter_layout(param_no)
+ ->ResetLayout(layout);
+ }
+
+ // Convenience method to force the layout of the computation result in a
+ // module. The result layout of 'module' is set to 'layout'.
+ void ForceResultLayout(HloModule* module, const Layout& layout) {
+ module->mutable_entry_computation_layout()
+ ->mutable_result_layout()
+ ->ResetLayout(layout);
+ }
+
+ // Convenience method to clear the layout of the computation result in
+ // 'module'.
+ void ForceClearResultLayout(HloModule* module) {
+ module->mutable_entry_computation_layout()
+ ->mutable_result_layout()
+ ->Clear();
+ }
+
string TestName() const;
std::unique_ptr<Backend> backend_;