Diffstat:
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc          |  3 ++-
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_pipeline.cc       |  4 ++--
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                        |  1 +
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc |  3 ++-
-rw-r--r--  tensorflow/compiler/xla/tests/params_test.cc               | 79 ++++++++++++
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc                | 92 ++++++++++++++
6 files changed, 178 insertions(+), 4 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 8519012f7a..f833d7fe09 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2419,7 +2419,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
HloComputation* condition = xla_while->while_condition();
TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
condition->root_instruction()->shape().element_type() == PRED)
- << "While condition computation must return bool";
+ << "While condition computation must return bool; got: "
+ << ShapeUtil::HumanString(condition->root_instruction()->shape());
// Check that all while-related buffers share an allocation slice.
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
xla_while->shape(),
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 7ad33c8947..cf42e030c2 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -36,8 +36,8 @@ void DumpModule(const HloModule& module,
const string& message) {
hlo_graph_dumper::MaybeDumpHloModule(module, message);
- VLOG(2) << "HLO " << message << ":";
- XLA_VLOG_LINES(2, module.ToString());
+ VLOG(3) << "HLO " << message << ":";
+ XLA_VLOG_LINES(3, module.ToString());
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 43683ebe21..d8ed473138 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -370,6 +370,7 @@ xla_test(
xla_test(
name = "params_test",
srcs = ["params_test.cc"],
+ shard_count = 30,
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 3001813dd4..9f3b66e256 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -258,7 +258,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
LOG(WARNING) << "performing exact comparison of floating point numbers";
} else {
TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
- expected.shape().element_type() == PRED);
+ expected.shape().element_type() == PRED)
+ << ShapeUtil::HumanString(expected.shape());
}
auto expect_equal = [&](const Literal& actual, const string& error_message) {
LiteralTestUtil::ExpectEqual(expected, actual, error_message);
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index aa84b8ff1a..3301e4c6ee 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -251,6 +251,85 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
ComputeAndCompareR1<float>(&builder, sum, param_data, ErrorSpec(0.0001f));
}
+// TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too
+// much space in parameter memory for the kernel.
+//
+// TODO(b/65526061) Fails on CPU on 2017-09-10 due to a timeout in LLVM
+// compilation.
+XLA_TEST_F(ParamsTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(ThreeThousandParameters))) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<std::unique_ptr<GlobalData>> param_data_owner;
+ ComputationDataHandle sum_handle = builder.ConstantR0<float>(0.0f);
+ float target = 0.0;
+ constexpr int kParamCount = 3000;
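+ // Build sum = p0 + p1 + ... + p(kParamCount - 1); target tracks the
+ // expected value 0 + 1 + ... + (kParamCount - 1) on the host side.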
+ for (int i = 0; i < kParamCount; ++i) {
+ target += i;
+ std::unique_ptr<Literal> literal = Literal::CreateR0<float>(i);
+ param_data_owner.push_back(
+ std::move(client_->TransferToServer(*literal)).ValueOrDie());
+ ComputationDataHandle param =
+ builder.Parameter(i, literal->shape(), "param");
+ sum_handle = builder.Add(sum_handle, param);
+ }
+
+ std::vector<GlobalData*> param_data;
+ param_data.reserve(param_data_owner.size());
+ for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
+ param_data.push_back(data.get());
+ }
+
+ ComputeAndCompareR0<float>(&builder, target, param_data, ErrorSpec(0.0001f));
+}
+
+// TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too
+// much space in parameter memory for the kernel.
+//
+// TODO(b/65526061) Fails on CPU on 2017-09-10 due to a timeout in LLVM
+// compilation.
+XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
+ ThreeThousandParametersAndOutputElements))) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<std::unique_ptr<GlobalData>> param_data_owner;
+ ComputationDataHandle sum_handle = builder.ConstantR1<int32>({0, 0});
+ int32 target = 0;
+ constexpr int kParamCount = 3000;
+ std::vector<ComputationDataHandle> params;
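+ // Accumulate as above, but also keep each parameter handle so it can
+ // feed its own output element.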
+ for (int i = 0; i < kParamCount; ++i) {
+ target += i;
+ std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
+ param_data_owner.push_back(
+ std::move(client_->TransferToServer(*literal)).ValueOrDie());
+ ComputationDataHandle param =
+ builder.Parameter(i, literal->shape(), "param");
+ params.push_back(param);
+ sum_handle = builder.Add(sum_handle, param);
+ }
+
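+ // One output element per parameter: outputs[i] = params[i] + sum.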
+ std::vector<ComputationDataHandle> outputs;
+ for (int i = 0; i < kParamCount; ++i) {
+ outputs.push_back(builder.Add(params[i], sum_handle));
+ }
+
+ builder.Tuple(outputs);
+
+ std::vector<GlobalData*> param_data;
+ param_data.reserve(param_data_owner.size());
+ for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
+ param_data.push_back(data.get());
+ }
+
+ std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<const Literal*> ptrs;
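+ // Each expected element is {target + i, target + i}: sum_handle evaluates
+ // to {target, target}, and outputs[i] adds params[i] = {i, i} to it.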
+ for (int i = 0; i < kParamCount; ++i) {
+ elements.push_back(Literal::CreateR1<int32>({target + i, target + i}));
+ ptrs.push_back(elements.back().get());
+ }
+ ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data);
+}
+
XLA_TEST_F(ParamsTest,
DISABLED_ON_CPU_PARALLEL(TupleOfR1ParametersAddedTogether)) {
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index cafaf5bcc6..1865004911 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -770,6 +771,97 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) {
}
}
+// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11.
+TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) {
+ auto element_shape = ShapeUtil::MakeShape(F32, {2});
+
+ ComputationBuilder outer(client_, "outer");
+ auto p = outer.Parameter(0, element_shape, "param");
+ auto t = outer.Tuple({p, outer.ConstantR1<float>({1, 1})});
+
+ TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<Shape> tuple_shape,
+ outer.GetShape(t));
+
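+ // Condition: keep iterating while any element of tuple element 0 is 42.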
+ ComputationBuilder cond(client_, "cond");
+ auto cond_t = cond.Parameter(0, *tuple_shape, "t");
+ TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0),
+ cond.ConstantR1<float>({42, 42})),
+ &cond)
+ .status());
+
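+ // Body: duplicate tuple element 1, so the initial state {{42, 42}, {1, 1}}
+ // becomes {{1, 1}, {1, 1}} after one iteration.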
+ ComputationBuilder body(client_, "body");
+ auto body_t = body.Parameter(0, *tuple_shape, "t");
+ auto e = body.GetTupleElement(body_t, 1);
+ body.Tuple({e, e});
+
+ TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
+ TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
+ outer.While(cond_computation, body_computation, t);
+
+ auto expected_element = Literal::CreateR1<float>({1, 1});
+ auto expected =
+ Literal::MakeTuple({expected_element.get(), expected_element.get()});
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> parameter_data,
+ client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
+ ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ ErrorSpec(1e-6));
+}
+
+// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11.
+TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithBroadcast)) {
+ auto element_shape = ShapeUtil::MakeShape(F32, {2});
+
+ ComputationBuilder outer(client_, "outer");
+ auto p = outer.Parameter(0, element_shape, "param");
+
+ ComputationBuilder cond(client_, "cond");
+ auto cond_t = cond.Parameter(0, element_shape, "t");
+ TF_ASSERT_OK(
+ Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42})), &cond).status());
+
+ ComputationBuilder body(client_, "body");
+ auto body_t = body.Parameter(0, element_shape, "t");
+ auto e = body.Broadcast(body.ConstantR0<float>(1.0), {2});
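+ // The broadcast is the last instruction built and therefore the body's
+ // root, so each iteration replaces the state with {1, 1}; e is otherwise
+ // unused.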
+
+ TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
+ TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
+ outer.While(cond_computation, body_computation, p);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> parameter_data,
+ client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
+ ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
+ ErrorSpec(1e-6));
+}
+
+TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
+ auto element_shape = ShapeUtil::MakeShape(F32, {});
+
+ ComputationBuilder outer(client_, "outer");
+ auto p = outer.Parameter(0, element_shape, "param");
+
+ ComputationBuilder cond(client_, "cond");
+ auto cond_t = cond.Parameter(0, element_shape, "t");
+ cond.Eq(cond_t, cond.ConstantR0<float>(42));
+
+ ComputationBuilder body(client_, "body");
+ auto body_t = body.Parameter(0, element_shape, "t");
+ auto tuple =
+ body.Tuple({body_t, body.Add(body_t, body.ConstantR0<float>(1))});
+ auto e = body.GetTupleElement(tuple, 1);
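+ // GetTupleElement is the last instruction built and therefore the body's
+ // root, so each iteration rebinds the state to t + 1.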
+
+ TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
+ TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
+ outer.While(cond_computation, body_computation, p);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> parameter_data,
+ client_->TransferToServer(*Literal::CreateR0<float>(42)));
+ ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
+ ErrorSpec(1e-6));
+}
+
// Tests nested while loops.
//
// int32 result = 0;