aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_analysis.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_analysis.cc19
1 files changed, 7 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index c3c2603c7e..541b117e02 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
HloEvaluator evaluator(/*max_loop_iterations=*/0);
auto* while_init = while_op->mutable_operand(0);
auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
- StatusOr<std::unique_ptr<Literal>> indvar_init_result =
- evaluator.Evaluate(indvar_init);
+ StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
if (!indvar_init_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable init: "
<< indvar_init_result.status();
@@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
// The initial value of the induction variable.
- std::unique_ptr<Literal> indvar_iter_val =
- std::move(indvar_init_result).ValueOrDie();
+ Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
for (int64 trip_count = 0; trip_count != max_value_returned + 1;
++trip_count) {
auto* while_cond = while_op->while_condition();
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_indvar = NonConstantOperand(while_cond_root);
- StatusOr<std::unique_ptr<Literal>> result =
- evaluator.EvaluateWithSubstitutions(
- while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
+ while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
if (!result.ok()) {
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
+ if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
// Calculate the value of the induction variable after one iteration of the
// loop, and check whether the while condition is true with this new value.
- StatusOr<std::unique_ptr<Literal>> indvar_next_result =
- evaluator.EvaluateWithSubstitutions(
- while_body_indvar_update,
- {{while_body_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
+ while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
if (!indvar_next_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable update: "
<< indvar_next_result.status();