aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-12 18:43:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-12 18:47:39 -0800
commit7022e6b62908f688cae446f1abbc667c5273db04 (patch)
treea1c5da6d93aed4c7dc6c821076629bea8d5936db
parent4476407836c5f0de0b746269af1714a469a09c32 (diff)
[XLA] Add some scalar tests.
PiperOrigin-RevId: 181819491
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc7
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc47
3 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index b5e7570778..debf2d2d31 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -69,6 +69,13 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
}
};
+XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR0<float>(2.1f);
+
+ ComputeAndCompareR0<float>(&builder, 2.1f, {}, error_spec_);
+}
+
XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<float>(2.1f));
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 65489cfff1..a8bca70d85 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -57,6 +57,20 @@ XLA_TEST_F(TupleTest, TupleConstant) {
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
+// Tests a tuple made of scalar constants.
+XLA_TEST_F(TupleTest, TupleScalarConstant) {
+ ComputationBuilder builder(client_, TestName());
+
+ const float constant_scalar1 = 7.3f;
+ const float constant_scalar2 = 1.2f;
+ auto value =
+ Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(),
+ Literal::CreateR0<float>(constant_scalar2).get()});
+
+ auto result = builder.ConstantLiteral(*value);
+ ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+}
+
// Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreate) {
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 7e7f6b1486..52157b837c 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -563,6 +563,53 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
}
+TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
+ std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(S32, {})};
+ Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
+
+ // Create a computation for the condition.
+ // Repeat for 5 iterations.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ builder.Gt(builder.ConstantR0<int32>(5), iteration);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body.
+ // Add 1 to the iteration variable and set the other tuple element to a
+ // constant.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ auto result =
+ builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
+ builder.ConstantR0<int32>(7)});
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, "while");
+ auto init = builder.Tuple(
+ {builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)});
+ auto result = builder.While(condition, body, init);
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ *builder.GetShape(result).ConsumeValueOrDie());
+
+ auto expected_counter = Literal::CreateR0<int32>(5);
+ auto expected_data = Literal::CreateR0<int32>(7);
+ auto expected =
+ Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+}
+
// Tests two while nodes when the result type T is a Tuple and the second
// while node uses the result of the first while node which is used in two
// nodes.