aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar Nick Desaulniers <ndesaulniers@google.com>2018-04-10 10:52:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 10:55:15 -0700
commitc276b8314cd3161c5626d845edcfb6697cefd043 (patch)
treefc7b2253f32700eaeb444b7f38871a2ebd903316 /tensorflow/compiler/xla
parent36a07c59954b8ace54879b8732b6a7ae2dce6450 (diff)
[TF:XLA] fix a segfault in MakeFakeArguments, and add a test case.
PiperOrigin-RevId: 192310749
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/tests/BUILD13
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc2
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc57
3 files changed, 71 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 6c43014b33..699b077d80 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1969,3 +1969,16 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
+
+xla_test(
+ name = "test_utils_test",
+ srcs = ["test_utils_test.cc"],
+ deps = [
+ ":local_client_test_base",
+ ":test_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ ],
+)
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 68f75d50cb..e30d115fae 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -165,7 +165,7 @@ enum class ConstantType { kUnknown, kZero, kOne };
// Return the constant type required by this computation, if known.
ConstantType GetInitValue(const HloComputation& computation) {
const HloInstruction* const root = computation.root_instruction();
- if (computation.num_parameters() != 2 ||
+ if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
root->operand(0)->opcode() != HloOpcode::kParameter ||
root->operand(1)->opcode() != HloOpcode::kParameter ||
root->operand(0) == root->operand(1)) {
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
new file mode 100644
index 0000000000..e8efc6e2a8
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+// A test fixture is used because we need a client for our computation builder.
+class TestUtilsTest : public LocalClientTestBase {};
+
+XLA_TEST_F(TestUtilsTest, UnusedParam) {
+ ComputationBuilder builder(local_client_, TestName());
+ // Make the reduction lambda.
+ Shape single_float = ShapeUtil::MakeShape(F32, {});
+ builder.Parameter(0, single_float, "unused");
+ builder.Parameter(1, single_float, "used");
+ auto computation_status = builder.Build();
+ TF_ASSERT_OK(computation_status.status());
+
+ // Make the reduction.
+ Shape pair_float = ShapeUtil::MakeShape(F32, {2});
+ builder.Reduce(builder.Parameter(0, pair_float, "operand"),
+ builder.Parameter(1, single_float, "init"),
+ computation_status.ValueOrDie(), {0});
+ computation_status = builder.Build();
+ TF_ASSERT_OK(computation_status.status());
+
+ auto executable_status = local_client_->Compile(
+ computation_status.ValueOrDie(), {&pair_float, &single_float},
+ ExecutableBuildOptions());
+ TF_ASSERT_OK(executable_status.status());
+ HloModule& module = const_cast<HloModule&>(
+ executable_status.ValueOrDie()->executable()->module());
+ TF_ASSERT_OK(MakeFakeArguments(&module).status());
+}
+
+} // namespace
+} // namespace xla