aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/hlo_test_base.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/hlo_test_base.h')
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 9009d67cea..66719b1460 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -200,6 +200,13 @@ class HloTestBase : public ::testing::Test {
->ResetLayout(layout);
}
+ void ForceResultLayout(HloModule* module, const Layout& layout,
+ ShapeIndexView shape_index) {
+ module->mutable_entry_computation_layout()
+ ->mutable_result_layout()
+ ->ResetLayout(layout, shape_index);
+ }
+
// Convenience method to clear the layout of the computation result in
// 'module'.
void ForceClearResultLayout(HloModule* module) {