diff options
-rw-r--r-- | tensorflow/compiler/xla/service/indexed_array_analysis.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/indexed_array_analysis_test.cc | 28 |
2 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 11d931cbd4..8b3fa6c157 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -689,7 +689,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, Array* operand) { auto* scalar_indexed_const = dynamic_cast<ScalarIndexedConstantArray*>(operand); - if (operand == nullptr) { + if (scalar_indexed_const == nullptr) { return nullptr; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 68f247bfc3..373556ebeb 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -472,5 +472,33 @@ ENTRY main { AssertArrayForRootExpressionIs(hlo_text, "%add"); } + +TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input = f32[100] parameter(0) + ROOT tanh = f32[100] tanh(input) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%tanh"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input0 = f32[100] parameter(0) + input1 = f32[100] parameter(1) + ROOT add = f32[100] add(input0, input1) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + } // namespace } // namespace xla |