diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/hlo_test_base.h')
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_test_base.h | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 9009d67cea..66719b1460 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -200,6 +200,13 @@ class HloTestBase : public ::testing::Test { ->ResetLayout(layout); } + void ForceResultLayout(HloModule* module, const Layout& layout, + ShapeIndexView shape_index) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout, shape_index); + } + // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { |