aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc28
2 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 11d931cbd4..8b3fa6c157 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -689,7 +689,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
Array* operand) {
auto* scalar_indexed_const =
dynamic_cast<ScalarIndexedConstantArray*>(operand);
- if (operand == nullptr) {
+ if (scalar_indexed_const == nullptr) {
return nullptr;
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 68f247bfc3..373556ebeb 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -472,5 +472,33 @@ ENTRY main {
AssertArrayForRootExpressionIs(hlo_text, "%add");
}
+
+TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) {
+ string hlo_text = R"(
+HloModule RegularUnaryOp
+
+ENTRY main {
+ input = f32[100] parameter(0)
+ ROOT tanh = f32[100] tanh(input)
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%tanh");
+}
+
+TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) {
+ string hlo_text = R"(
+HloModule RegularUnaryOp
+
+ENTRY main {
+ input0 = f32[100] parameter(0)
+ input1 = f32[100] parameter(1)
+ ROOT add = f32[100] add(input0, input1)
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%add");
+}
+
} // namespace
} // namespace xla