diff options
author | 2018-01-12 18:43:55 -0800 | |
---|---|---|
committer | 2018-01-12 18:47:39 -0800 | |
commit | 7022e6b62908f688cae446f1abbc667c5273db04 (patch) | |
tree | a1c5da6d93aed4c7dc6c821076629bea8d5936db | |
parent | 4476407836c5f0de0b746269af1714a469a09c32 (diff) |
[XLA] Add some scalar tests.
PiperOrigin-RevId: 181819491
-rw-r--r-- | tensorflow/compiler/xla/tests/scalar_computations_test.cc | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/tuple_test.cc | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/while_test.cc | 47 |
3 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index b5e7570778..debf2d2d31 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -69,6 +69,13 @@ class ScalarComputationsTest : public ClientLibraryTestBase { } }; +XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR0<float>(2.1f); + + ComputeAndCompareR0<float>(&builder, 2.1f, {}, error_spec_); +} + XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { ComputationBuilder builder(client_, TestName()); builder.Neg(builder.ConstantR0<float>(2.1f)); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 65489cfff1..a8bca70d85 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -57,6 +57,20 @@ XLA_TEST_F(TupleTest, TupleConstant) { ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } +// Tests a tuple made of scalar constants. +XLA_TEST_F(TupleTest, TupleScalarConstant) { + ComputationBuilder builder(client_, TestName()); + + const float constant_scalar1 = 7.3f; + const float constant_scalar2 = 1.2f; + auto value = + Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(), + Literal::CreateR0<float>(constant_scalar2).get()}); + + auto result = builder.ConstantLiteral(*value); + ComputeAndCompareTuple(&builder, *value, {}, error_spec_); +} + // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreate) { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 7e7f6b1486..52157b837c 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -563,6 +563,53 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } +TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { + std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(S32, {})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0<int32>(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and set the other tuple element to a + // constant. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto result = + builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)), + builder.ConstantR0<int32>(7)}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0<int32>(5); + auto expected_data = Literal::CreateR0<int32>(7); + auto expected = + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + // Tests two while nodes when the result type T is a Tuple and the second // while node uses the result of the first while node which is used in two // nodes. |