diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/compute_constant_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/compute_constant_test.cc | 50 |
1 files changed, 25 insertions, 25 deletions
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ba22530f1c..672fb06de6 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.ConstantR0<int32>(42); + auto computation = ConstantR0<int32>(&b, 42); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<int32>(client, computation, &b); @@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f)); + Add(ConstantR0<float>(&b, 42.5f), ConstantR0<float>(&b, 1.5f)); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); @@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f), - ShapeUtil::MakeShape(F32, {})); + RngUniform(ConstantR0<float>(&b, 1.1f), ConstantR0<float>(&b, 2.1f), + ShapeUtil::MakeShape(F32, {})); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); @@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); @@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0<float>(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Add(ConstantR0<float>(&b, 1.0f), + Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); @@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0"); auto constant_4 = - b.Add(b.ConstantR0<float>(2.5f), b.ConstantR0<float>(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + Add(ConstantR0<float>(&b, 2.5f), ConstantR0<float>(&b, 1.5f)); + auto not_constant_a = Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1"); auto constant_9 = - b.Mul(b.ConstantR0<float>(2.0f), b.ConstantR0<float>(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + Mul(ConstantR0<float>(&b, 2.0f), ConstantR0<float>(&b, 4.5f)); + auto not_constant_b = Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = Add(constant_4, constant_9); + Add(not_constant_b, Add(constant_13, not_constant_a)); EXPECT_TRUE(IsConstant(constant_13, &b)); @@ -201,13 +201,13 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4})); + Add(ConstantR1<int32>(&b, {1, 2}), ConstantR1<int32>(&b, {3, 4})); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); std::unique_ptr<Literal> expected_literal = - Literal::CreateR1<int32>({4, 6}); + LiteralUtil::CreateR1<int32>({4, 6}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -216,12 +216,12 @@ TEST_F(ComputeConstantTest, IntegerDivide) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3)); + auto computation = Div(ConstantR0<int32>(&b, 15), ConstantR0<int32>(&b, 3)); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5); + std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -237,13 +237,13 @@ XLA_TEST_F(ComputeConstantTest, Layout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, ComputeConstantLiteral( client, - b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}), - b.ConstantR2<int32>({{10, 20}, {30, 40}})), + Add(ConstantR2<int32>(&b, {{1, 2}, {3, 4}}), + ConstantR2<int32>(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); std::unique_ptr<Literal> expected_literal = - Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}}, - LayoutUtil::MakeLayout(layout)); + LiteralUtil::CreateR2WithLayout<int32>( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( expected_literal->shape(), computed->shape())); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); |