aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/compute_constant_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/compute_constant_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc26
1 files changed, 12 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 8226b6de3f..3b0414a604 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test {
LOG(FATAL) << "invalid client_type value";
}
- StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
- Client* client, const XlaOp& operand, XlaBuilder* builder,
- Layout* output_layout = nullptr) {
+ StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
+ XlaBuilder* builder,
+ Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
@@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test {
XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
builder, nullptr));
- return literal->Get<Scalar>({});
+ return literal.Get<Scalar>({});
}
bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
@@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<int32>({4, 6});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}