aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_ordering_test.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-01-24 12:12:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-24 12:15:50 -0800
commit02c2214b8da916a4a9e9eb3fde2711e49e7b4703 (patch)
tree965383e084bec7465fe7bc2421466c9be0b51300 /tensorflow/compiler/xla/service/hlo_ordering_test.cc
parent94faa1e921d647f321ca1b3204a25aa0dbede856 (diff)
[XLA] Fix crash in HloOrdering::ToString().
HloOrdering::ToString() was crashing when given a computation containing a fusion instruction. PiperOrigin-RevId: 183121645
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 33bafd05c1..aba66114de 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -310,5 +311,56 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
*dataflow));
}
+// Regression test for HloOrdering::ToString() crashing when fed a computation
+// containing a fusion node.
+TEST_F(HloOrderingTest, ToStringDoesNotCrash) {
+ const char* module_str = R"(
+HloModule test_module
+
+body.v8 {
+ prev.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
+ get-tuple-element.4 = s32[] get-tuple-element(prev.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.4, constant.1)
+ get-tuple-element.5 = f32[3]{0} get-tuple-element(prev.1), index=3
+ get-tuple-element.6 = f32[3]{0} get-tuple-element(prev.1), index=1
+ get-tuple-element.7 = f32[3]{0} get-tuple-element(prev.1), index=2
+ ROOT tuple = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(add, get-tuple-element.5, get-tuple-element.6, get-tuple-element.7)
+}
+
+condition.v4 {
+ constant.2 = s32[] constant(2)
+ prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
+ get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
+ ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8)
+}
+
+fused_computation {
+ get-tuple-element.5.param_1 = f32[3]{0} parameter(1)
+ get-tuple-element.6.param_2 = f32[3]{0} parameter(2)
+ add.4 = f32[3]{0} add(get-tuple-element.5.param_1, get-tuple-element.6.param_2)
+ get-tuple-element.7.param_1.1 = f32[3]{0} parameter(0)
+ ROOT add.5 = f32[3]{0} add(add.4, get-tuple-element.7.param_1.1)
+}
+
+ENTRY while.v11 {
+ constant.5 = s32[] constant(0)
+ constant.6 = f32[3]{0} constant({1, 1, 1})
+ constant.7 = f32[3]{0} constant({2, 2, 2})
+ constant.8 = f32[3]{0} constant({3, 3, 3})
+ tuple.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(constant.5, constant.6, constant.7, constant.8)
+ while = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) while(tuple.1), condition=condition.v4, body=body.v8
+ get-tuple-element.9 = f32[3]{0} get-tuple-element(while), index=3
+ get-tuple-element.10 = f32[3]{0} get-tuple-element(while), index=1
+ get-tuple-element.11 = f32[3]{0} get-tuple-element(while), index=2
+ ROOT fusion = f32[3]{0} fusion(get-tuple-element.9, get-tuple-element.10, get-tuple-element.11), kind=kLoop, calls=fused_computation
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(module_str));
+ DependencyHloOrdering ordering(module.get());
+ ordering.ToString(); // Shouldn't crash.
+}
+
} // namespace
} // namespace xla