aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/compute_constant_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/compute_constant_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc50
1 files changed, 25 insertions, 25 deletions
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index ba22530f1c..672fb06de6 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.ConstantR0<int32>(42);
+ auto computation = ConstantR0<int32>(&b, 42);
EXPECT_TRUE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<int32>(client, computation, &b);
@@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
+ Add(ConstantR0<float>(&b, 42.5f), ConstantR0<float>(&b, 1.5f));
EXPECT_TRUE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
- ShapeUtil::MakeShape(F32, {}));
+ RngUniform(ConstantR0<float>(&b, 1.1f), ConstantR0<float>(&b, 2.1f),
+ ShapeUtil::MakeShape(F32, {}));
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
+ auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param");
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR0<float>(1.0f),
- b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ Add(ConstantR0<float>(&b, 1.0f),
+ Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"));
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0");
auto constant_4 =
- b.Add(b.ConstantR0<float>(2.5f), b.ConstantR0<float>(1.5f));
- auto not_constant_a = b.Add(constant_4, param_a);
+ Add(ConstantR0<float>(&b, 2.5f), ConstantR0<float>(&b, 1.5f));
+ auto not_constant_a = Add(constant_4, param_a);
- auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1");
+ auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1");
auto constant_9 =
- b.Mul(b.ConstantR0<float>(2.0f), b.ConstantR0<float>(4.5f));
- auto not_constant_b = b.Add(param_b, constant_9);
+ Mul(ConstantR0<float>(&b, 2.0f), ConstantR0<float>(&b, 4.5f));
+ auto not_constant_b = Add(param_b, constant_9);
- auto constant_13 = b.Add(constant_4, constant_9);
- b.Add(not_constant_b, b.Add(constant_13, not_constant_a));
+ auto constant_13 = Add(constant_4, constant_9);
+ Add(not_constant_b, Add(constant_13, not_constant_a));
EXPECT_TRUE(IsConstant(constant_13, &b));
@@ -201,13 +201,13 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
+ Add(ConstantR1<int32>(&b, {1, 2}), ConstantR1<int32>(&b, {3, 4}));
EXPECT_TRUE(IsConstant(computation, &b));
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<int32>({4, 6});
+ LiteralUtil::CreateR1<int32>({4, 6});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
@@ -216,12 +216,12 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
+ auto computation = Div(ConstantR0<int32>(&b, 15), ConstantR0<int32>(&b, 3));
EXPECT_TRUE(IsConstant(computation, &b));
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
@@ -237,13 +237,13 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
TF_ASSERT_OK_AND_ASSIGN(
auto computed, ComputeConstantLiteral(
client,
- b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
- b.ConstantR2<int32>({{10, 20}, {30, 40}})),
+ Add(ConstantR2<int32>(&b, {{1, 2}, {3, 4}}),
+ ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
- LayoutUtil::MakeLayout(layout));
+ LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
expected_literal->shape(), computed->shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));