aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h203
1 files changed, 92 insertions, 111 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 63303aef1e..8fb17a0033 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -246,32 +246,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
-
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
Status HandleBitcastConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
@@ -978,10 +967,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -1157,8 +1146,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return static_cast<ReturnT>(result_val);
};
- auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
+ Literal result(result_shape);
+ TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
@@ -1231,9 +1220,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
}
- auto result = absl::make_unique<Literal>(dot->shape());
+ Literal result(dot->shape());
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@@ -1280,8 +1269,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(pad->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ Literal result(pad->shape());
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
@@ -1289,7 +1278,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
- std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+ std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
// Loop through each element of the operand, assign them to the
// corresponding index of the resulting padded literal.
@@ -1311,8 +1300,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return true;
}
}
- result->Set<ReturnT>(target_index,
- evaluated_operand.Get<ReturnT>(input_index));
+ result.Set<ReturnT>(target_index,
+ evaluated_operand.Get<ReturnT>(input_index));
return true;
};
@@ -1439,16 +1428,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT>
- StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+ StatusOr<Literal> MapImpl(HloInstruction* map) {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
- auto result = absl::make_unique<Literal>(map->shape());
+ Literal result(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ std::vector<Literal> arg_literals;
arg_literals.reserve(operands.size());
// Construct scalar literal parameters to be passed to the map
@@ -1463,16 +1452,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_literals.push_back(std::move(curr_val_literal));
}
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<std::unique_ptr<Literal>>(*computation,
- arg_literals)
+ Literal computed_result =
+ embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again on
// the same computation.
embedded_evaluator.ResetVisitStates();
- return computed_result->Get<ReturnT>({});
+ return computed_result.Get<ReturnT>({});
}));
return std::move(result);
}
@@ -1557,9 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ Literal result_literal(keys_literal.shape());
+ result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
return result_literal;
};
@@ -1568,16 +1555,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
+ Literal result_literal(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result = sort_r1(*r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
- *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+ .Reshape({r1_length}));
+ auto r1_result = sort_r1(r1_slice);
+ TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ r1_result, {0, 0}, {row, 0}, {1, r1_length}));
}
parent_->evaluated_[sort] = std::move(result_literal);
}
@@ -1651,9 +1638,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
+ absl::InlinedVector<Literal, 1> results(num_args);
for (int64 i = 0; i < num_args; ++i) {
- results[i] = absl::make_unique<Literal>(result_shape);
+ results[i] = Literal(result_shape);
}
Status eval_status;
@@ -1667,7 +1654,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
for (int64 input = 0; input < num_args; ++input) {
- TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
[&](absl::Span<const int64> multi_index) {
if (!eval_status.ok()) {
return init_scalars[input];
@@ -1703,8 +1690,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
// Evaluate computation with specified literal operands.
- absl::InlinedVector<std::unique_ptr<Literal>, 1>
- embedded_operands;
+ absl::InlinedVector<Literal, 1> embedded_operands;
for (ReturnT value : result_values) {
embedded_operands.push_back(
LiteralUtil::CreateR0<ReturnT>(value));
@@ -1717,11 +1703,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_operands.size());
std::transform(embedded_operands.begin(), embedded_operands.end(),
embedded_operands_ptrs.begin(),
- [](const std::unique_ptr<Literal>& ptr) {
- return ptr.get();
- });
+ [](Literal& literal) { return &literal; });
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ TF_ASSIGN_OR_RETURN(Literal computed_result,
embedded_evaluator.Evaluate<const Literal*>(
*function, embedded_operands_ptrs));
// Clear visit states so that we can use the evaluator again on
@@ -1729,10 +1713,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_evaluator.ResetVisitStates();
// Assign computed result to result_val.
if (!has_tuple_output) {
- result_values[0] = computed_result->Get<ReturnT>({});
+ result_values[0] = computed_result.Get<ReturnT>({});
} else {
for (int64 i = 0; i < num_args; ++i) {
- result_values[i] = computed_result->Get<ReturnT>(
+ result_values[i] = computed_result.Get<ReturnT>(
/*multi_index=*/{}, /*shape_index=*/{i});
}
}
@@ -1748,9 +1732,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!has_tuple_output) {
parent_->evaluated_[reduce] = std::move(results[0]);
} else {
- auto tuple_result = absl::make_unique<Literal>(reduce->shape());
+ Literal tuple_result(reduce->shape());
for (int64 i = 0; i < num_args; ++i) {
- TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
+ TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
}
parent_->evaluated_[reduce] = std::move(tuple_result);
}
@@ -1781,10 +1765,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(select_and_scatter->shape());
+ Literal result(select_and_scatter->shape());
// Initialize result array with the init value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
@@ -1834,15 +1818,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
selected_val = curr_val;
selected_index = operand_index;
}
- curr_val_literal->Set({}, curr_val);
- selected_val_literal->Set({}, *selected_val);
- std::unique_ptr<Literal> computed_result =
+ curr_val_literal.Set({}, curr_val);
+ selected_val_literal.Set({}, *selected_val);
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *select,
- {selected_val_literal.get(), curr_val_literal.get()})
+ *select, {&selected_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
- bool selected = !computed_result->Get<bool>({});
+ bool selected = !computed_result.Get<bool>({});
if (selected) {
selected_val = curr_val;
selected_index = operand_index;
@@ -1856,16 +1839,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (std::equal(operand_index.begin(), operand_index.end(),
selected_index->begin())) {
auto source = source_literal.Get<ReturnT>(source_index);
- auto scattered = result->Get<ReturnT>(operand_index);
- source_literal_scatter->Set({}, source);
- scattered_literal->Set({}, scattered);
- std::unique_ptr<Literal> computed_result =
+ auto scattered = result.Get<ReturnT>(operand_index);
+ source_literal_scatter.Set({}, source);
+ scattered_literal.Set({}, scattered);
+ Literal computed_result =
embedded_evaluator
- .Evaluate<const Literal*>(*scatter,
- {source_literal_scatter.get(),
- scattered_literal.get()})
+ .Evaluate<const Literal*>(
+ *scatter,
+ {&source_literal_scatter, &scattered_literal})
.ConsumeValueOrDie();
- result->Set(operand_index, computed_result->Get<ReturnT>({}));
+ result.Set(operand_index, computed_result.Get<ReturnT>({}));
// Clear visit states so that the we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
@@ -1916,10 +1899,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = absl::make_unique<Literal>(reduce_window->shape());
+ Literal result(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@@ -1935,18 +1918,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
- std::unique_ptr<Literal> computed_result =
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
+ *function, {&result_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again
// on the same computation.
embedded_evaluator.ResetVisitStates();
- result_val = computed_result->Get<ReturnT>({});
+ result_val = computed_result.Get<ReturnT>({});
});
return result_val;
@@ -1961,7 +1943,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64 index_vector_dim, const Literal& indices,
- std::unique_ptr<Literal>* reshaped_indices) {
+ Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(indices);
}
@@ -1970,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
- return std::cref(**reshaped_indices);
+ return std::cref(*reshaped_indices);
}
// Returns an ShapeUtil::IndexIterationSpace that iterates over the update
@@ -2230,7 +2212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
scatter->scatter_dimension_numbers();
const Literal& operand =
parent_->GetEvaluatedLiteralFor(scatter->operand(0));
- std::unique_ptr<Literal> reshaped_scatter_indices;
+ Literal reshaped_scatter_indices;
TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
ReshapedScatterIndices(dim_numbers.index_vector_dim(),
parent_->GetEvaluatedLiteralFor(
@@ -2260,7 +2242,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize the result with the operand. This makes it easier to handle
// the updates even when the indices are repeated.
- std::unique_ptr<Literal> result = operand.CloneToUnique();
+ Literal result = operand.Clone();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
[&](absl::Span<const int64> update_window_index,
@@ -2299,19 +2281,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result_value_literal =
- LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
+ LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
auto update_value_literal =
LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
- std::unique_ptr<Literal> updated_result =
+ Literal updated_result =
embedded_evaluator
.Evaluate<const Literal*>(
*scatter->to_apply(),
- {result_value_literal.get(), update_value_literal.get()})
+ {&result_value_literal, &update_value_literal})
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again on the
// same computation.
embedded_evaluator.ResetVisitStates();
- result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
+ result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
return true;
};
@@ -2359,9 +2341,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = LiteralUtil::CreateFromDimensions(
- shape.element_type(), AsInt64Slice(shape.dimensions()));
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ Literal result(shape);
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
}
@@ -2575,7 +2556,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (ShapeUtil::Rank(iota->shape()) > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
- result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+ result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
parent_->evaluated_[iota] = std::move(result);
@@ -2645,9 +2626,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicSlice(
- const Literal& operand_literal, const Literal& start_indices_literal,
- const Shape& result_shape) {
+ StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
+ const Literal& start_indices_literal,
+ const Shape& result_shape) {
auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
@@ -2660,9 +2641,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@@ -2676,12 +2657,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
- const Literal& operand_literal, const Literal& update_literal,
- const Literal& start_indices_literal) {
- auto result = operand_literal.CloneToUnique();
+ StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
+ const Literal& update_literal,
+ const Literal& start_indices_literal) {
+ auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
- const auto rank = ShapeUtil::Rank(result->shape());
+ const auto rank = ShapeUtil::Rank(result.shape());
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
@@ -2689,15 +2670,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
- result->shape().dimensions(i) - update_literal.shape().dimensions(i));
+ result.shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);
auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
- result->Set<ReturnT>(result_index,
- update_literal.Get<ReturnT>(update_index));
+ result.Set<ReturnT>(result_index,
+ update_literal.Get<ReturnT>(update_index));
return true;
};
@@ -2710,7 +2691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+ StatusOr<Literal> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
const Literal& operand_literal =
@@ -2723,7 +2704,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result_literal);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+ StatusOr<Literal> ElementWiseBinaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
@@ -2745,10 +2726,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@@ -2757,7 +2738,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename LhsType, typename RhsType, typename EhsType>
- StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+ StatusOr<Literal> ElementwiseTernaryOp(
HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const auto shape = instruction->shape();
@@ -2782,10 +2763,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));