aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 15:35:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 15:40:02 -0700
commite6cce55e57722d8ba587965b8ef511838c6d1391 (patch)
tree89491579681e64f2f36fe273dd66c77ead7df802
parent5b4e922701c6b1476bbfa0271783df5b3e855bf2 (diff)
Fix some build breakage due to de-std::unique_ptr cleanup.
PiperOrigin-RevId: 212347506
-rw-r--r--tensorflow/compiler/xla/service/cpu/sample_harness.cc30
-rw-r--r--tensorflow/compiler/xla/tools/show_literal.cc4
-rw-r--r--tensorflow/compiler/xla/tools/show_text_literal.cc16
3 files changed, 24 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index 942e2ddd39..55d5925642 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -37,21 +37,20 @@ int main(int argc, char** argv) {
xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie());
// Transfer parameters.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR2<float>(
- {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR2<float>(
+ {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
// Build computation.
xla::XlaBuilder builder("");
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p1, p0, {0});
xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
@@ -59,17 +58,16 @@ int main(int argc, char** argv) {
// Execute and transfer result of computation.
xla::ExecutionProfile profile;
- xla::StatusOr<std::unique_ptr<xla::Literal>> result =
- client->ExecuteAndTransfer(
- computation,
- /*arguments=*/{param0_data.get(), param1_data.get()},
- /*execution_options=*/nullptr,
- /*execution_profile=*/&profile);
- std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie();
+ xla::StatusOr<xla::Literal> result = client->ExecuteAndTransfer(
+ computation,
+ /*arguments=*/{param0_data.get(), param1_data.get()},
+ /*execution_options=*/nullptr,
+ /*execution_profile=*/&profile);
+ xla::Literal actual = result.ConsumeValueOrDie();
LOG(INFO) << absl::StrFormat("computation took %dns",
profile.compute_time_ns());
- LOG(INFO) << actual->ToString();
+ LOG(INFO) << actual.ToString();
return 0;
}
diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc
index 51909190a3..4f8852f8c1 100644
--- a/tensorflow/compiler/xla/tools/show_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_literal.cc
@@ -40,8 +40,8 @@ int main(int argc, char **argv) {
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal_proto));
- std::unique_ptr<xla::Literal> literal =
+ xla::Literal literal =
xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
- fprintf(stderr, "%s\n", literal->ToString().c_str());
+ fprintf(stderr, "%s\n", literal.ToString().c_str());
}
diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc
index 48c8374811..4b5c276bdf 100644
--- a/tensorflow/compiler/xla/tools/show_text_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_text_literal.cc
@@ -36,16 +36,16 @@ int main(int argc, char **argv) {
LOG(QFATAL) << "Usage: " << argv[0] << " <path-to-serialized-literal-text>";
}
- std::unique_ptr<xla::Literal> literal =
+ xla::Literal literal =
xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie();
- LOG(INFO) << "literal: " << *literal;
- fprintf(stderr, "%s\n", literal->ToString().c_str());
- if (literal->shape().element_type() == xla::F32) {
- float min = *std::min_element(literal->data<float>().begin(),
- literal->data<float>().end());
- float max = *std::max_element(literal->data<float>().begin(),
- literal->data<float>().end());
+ LOG(INFO) << "literal: " << literal;
+ fprintf(stderr, "%s\n", literal.ToString().c_str());
+ if (literal.shape().element_type() == xla::F32) {
+ float min = *std::min_element(literal.data<float>().begin(),
+ literal.data<float>().end());
+ float max = *std::max_element(literal.data<float>().begin(),
+ literal.data<float>().end());
fprintf(stderr, "min: %a=%f\n", min, min);
fprintf(stderr, "max: %a=%f\n", max, max);
}