diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 203 |
1 files changed, 92 insertions, 111 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 63303aef1e..8fb17a0033 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -246,32 +246,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } Status HandleBitcastConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -978,10 +967,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = absl::make_unique<Literal>(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> out_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> out_index) { std::vector<int64> from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -1157,8 +1146,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast<ReturnT>(result_val); }; - auto result = absl::make_unique<Literal>(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func)); + Literal result(result_shape); + TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func)); parent_->evaluated_[conv] = std::move(result); return Status::OK(); @@ -1231,9 +1220,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } } - auto result = absl::make_unique<Literal>(dot->shape()); + Literal result(dot->shape()); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> result_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> result_index) { ElementwiseT result_val = static_cast<ElementwiseT>(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1280,8 +1269,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({}); - auto result = absl::make_unique<Literal>(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate<ReturnT>( + Literal result(pad->shape()); + TF_RETURN_IF_ERROR(result.Populate<ReturnT>( [&scalar](absl::Span<const int64> multi_index) { return scalar; })); const Literal& evaluated_operand = @@ -1289,7 +1278,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); - std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0); + std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. @@ -1311,8 +1300,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return true; } } - result->Set<ReturnT>(target_index, - evaluated_operand.Get<ReturnT>(input_index)); + result.Set<ReturnT>(target_index, + evaluated_operand.Get<ReturnT>(input_index)); return true; }; @@ -1439,16 +1428,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename NativeT> - StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) { + StatusOr<Literal> MapImpl(HloInstruction* map) { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = absl::make_unique<Literal>(map->shape()); + Literal result(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { - std::vector<std::unique_ptr<Literal>> arg_literals; + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + std::vector<Literal> arg_literals; arg_literals.reserve(operands.size()); // Construct scalar literal parameters to be passed to the map @@ -1463,16 +1452,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_literals.push_back(std::move(curr_val_literal)); } - std::unique_ptr<Literal> computed_result = - embedded_evaluator - .Evaluate<std::unique_ptr<Literal>>(*computation, - arg_literals) + Literal computed_result = + embedded_evaluator.Evaluate<Literal>(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. embedded_evaluator.ResetVisitStates(); - return computed_result->Get<ReturnT>({}); + return computed_result.Get<ReturnT>({}); })); return std::move(result); } @@ -1557,9 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess<ReturnT>(a, b); }); - auto result_literal = absl::make_unique<Literal>(keys_literal.shape()); - result_literal->PopulateR1(absl::Span<const ReturnT>(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + Literal result_literal(keys_literal.shape()); + result_literal.PopulateR1(absl::Span<const ReturnT>(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); return result_literal; }; @@ -1568,16 +1555,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto result_literal = absl::make_unique<Literal>(keys_literal.shape()); + Literal result_literal(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result = sort_r1(*r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( - *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + .Reshape({r1_length})); + auto r1_result = sort_r1(r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + r1_result, {0, 0}, {row, 0}, {1, r1_length})); } parent_->evaluated_[sort] = std::move(result_literal); } @@ -1651,9 +1638,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args); + absl::InlinedVector<Literal, 1> results(num_args); for (int64 i = 0; i < num_args; ++i) { - results[i] = absl::make_unique<Literal>(result_shape); + results[i] = Literal(result_shape); } Status eval_status; @@ -1667,7 +1654,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } for (int64 input = 0; input < num_args; ++input) { - TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>( + TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>( [&](absl::Span<const int64> multi_index) { if (!eval_status.ok()) { return init_scalars[input]; @@ -1703,8 +1690,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } // Evaluate computation with specified literal operands. - absl::InlinedVector<std::unique_ptr<Literal>, 1> - embedded_operands; + absl::InlinedVector<Literal, 1> embedded_operands; for (ReturnT value : result_values) { embedded_operands.push_back( LiteralUtil::CreateR0<ReturnT>(value)); @@ -1717,11 +1703,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_operands.size()); std::transform(embedded_operands.begin(), embedded_operands.end(), embedded_operands_ptrs.begin(), - [](const std::unique_ptr<Literal>& ptr) { - return ptr.get(); - }); + [](Literal& literal) { return &literal; }); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result, + TF_ASSIGN_OR_RETURN(Literal computed_result, embedded_evaluator.Evaluate<const Literal*>( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on @@ -1729,10 +1713,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_evaluator.ResetVisitStates(); // Assign computed result to result_val. if (!has_tuple_output) { - result_values[0] = computed_result->Get<ReturnT>({}); + result_values[0] = computed_result.Get<ReturnT>({}); } else { for (int64 i = 0; i < num_args; ++i) { - result_values[i] = computed_result->Get<ReturnT>( + result_values[i] = computed_result.Get<ReturnT>( /*multi_index=*/{}, /*shape_index=*/{i}); } } @@ -1748,9 +1732,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (!has_tuple_output) { parent_->evaluated_[reduce] = std::move(results[0]); } else { - auto tuple_result = absl::make_unique<Literal>(reduce->shape()); + Literal tuple_result(reduce->shape()); for (int64 i = 0; i < num_args; ++i) { - TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i})); + TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); } parent_->evaluated_[reduce] = std::move(tuple_result); } @@ -1781,10 +1765,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get<ReturnT>({}); - auto result = absl::make_unique<Literal>(select_and_scatter->shape()); + Literal result(select_and_scatter->shape()); // Initialize result array with the init value. - TF_RETURN_IF_ERROR(result->Populate<ReturnT>( + TF_RETURN_IF_ERROR(result.Populate<ReturnT>( [&](absl::Span<const int64> output_index) { return init_scalar; })); std::vector<int64> window_dimension_sizes; @@ -1834,15 +1818,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - curr_val_literal->Set({}, curr_val); - selected_val_literal->Set({}, *selected_val); - std::unique_ptr<Literal> computed_result = + curr_val_literal.Set({}, curr_val); + selected_val_literal.Set({}, *selected_val); + Literal computed_result = embedded_evaluator .Evaluate<const Literal*>( - *select, - {selected_val_literal.get(), curr_val_literal.get()}) + *select, {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); - bool selected = !computed_result->Get<bool>({}); + bool selected = !computed_result.Get<bool>({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1856,16 +1839,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (std::equal(operand_index.begin(), operand_index.end(), selected_index->begin())) { auto source = source_literal.Get<ReturnT>(source_index); - auto scattered = result->Get<ReturnT>(operand_index); - source_literal_scatter->Set({}, source); - scattered_literal->Set({}, scattered); - std::unique_ptr<Literal> computed_result = + auto scattered = result.Get<ReturnT>(operand_index); + source_literal_scatter.Set({}, source); + scattered_literal.Set({}, scattered); + Literal computed_result = embedded_evaluator - .Evaluate<const Literal*>(*scatter, - {source_literal_scatter.get(), - scattered_literal.get()}) + .Evaluate<const Literal*>( + *scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get<ReturnT>({})); + result.Set(operand_index, computed_result.Get<ReturnT>({})); // Clear visit states so that the we can use the evaluator again // on the same computation. embedded_evaluator.ResetVisitStates(); @@ -1916,10 +1899,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = absl::make_unique<Literal>(reduce_window->shape()); + Literal result(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> output_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1935,18 +1918,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0<ReturnT>(curr_val); const auto result_val_literal = LiteralUtil::CreateR0<ReturnT>(result_val); - std::unique_ptr<Literal> computed_result = + Literal computed_result = embedded_evaluator .Evaluate<const Literal*>( - *function, - {result_val_literal.get(), curr_val_literal.get()}) + *function, {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again // on the same computation. embedded_evaluator.ResetVisitStates(); - result_val = computed_result->Get<ReturnT>({}); + result_val = computed_result.Get<ReturnT>({}); }); return result_val; @@ -1961,7 +1943,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // literal (if there is one) to `reshaped_indices`. StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices( int64 index_vector_dim, const Literal& indices, - std::unique_ptr<Literal>* reshaped_indices) { + Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { return std::cref(indices); } @@ -1970,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { indices.shape().dimensions().end()); new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); - return std::cref(**reshaped_indices); + return std::cref(*reshaped_indices); } // Returns an ShapeUtil::IndexIterationSpace that iterates over the update @@ -2230,7 +2212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scatter->scatter_dimension_numbers(); const Literal& operand = parent_->GetEvaluatedLiteralFor(scatter->operand(0)); - std::unique_ptr<Literal> reshaped_scatter_indices; + Literal reshaped_scatter_indices; TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, ReshapedScatterIndices(dim_numbers.index_vector_dim(), parent_->GetEvaluatedLiteralFor( @@ -2260,7 +2242,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Initialize the result with the operand. This makes it easier to handle // the updates even when the indices are repeated. - std::unique_ptr<Literal> result = operand.CloneToUnique(); + Literal result = operand.Clone(); HloEvaluator embedded_evaluator; auto scatter_inner_loop_body = [&](absl::Span<const int64> update_window_index, @@ -2299,19 +2281,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } auto result_value_literal = - LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index)); + LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index)); auto update_value_literal = LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index)); - std::unique_ptr<Literal> updated_result = + Literal updated_result = embedded_evaluator .Evaluate<const Literal*>( *scatter->to_apply(), - {result_value_literal.get(), update_value_literal.get()}) + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. embedded_evaluator.ResetVisitStates(); - result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({})); + result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({})); return true; }; @@ -2359,9 +2341,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get<ReturnT>(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func)); + Literal result(shape); + TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); } @@ -2575,7 +2556,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (ShapeUtil::Rank(iota->shape()) > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], - result->Broadcast(iota->shape(), {iota->iota_dimension()})); + result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); parent_->evaluated_[iota] = std::move(result); @@ -2645,9 +2626,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename IndexT> - StatusOr<std::unique_ptr<Literal>> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { + StatusOr<Literal> DynamicSlice(const Literal& operand_literal, + const Literal& start_indices_literal, + const Shape& result_shape) { auto start_indices_typed = start_indices_literal.data<IndexT>(); std::vector<int64> start(start_indices_typed.begin(), start_indices_typed.end()); @@ -2660,9 +2641,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector<int64> operand_indices(start.size()); - auto result = absl::make_unique<Literal>(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2676,12 +2657,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename IndexT> - StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); + StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal, + const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.Clone(); auto start_indices_typed = start_indices_literal.data<IndexT>(); - const auto rank = ShapeUtil::Rank(result->shape()); + const auto rank = ShapeUtil::Rank(result.shape()); std::vector<int64> start(start_indices_typed.begin(), start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the @@ -2689,15 +2670,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { for (int64 i = 0; i < rank; ++i) { start[i] = std::min<int64>( std::max<int64>(0, start[i]), - result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + result.shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector<int64> result_index(rank, 0); auto func = [&](absl::Span<const int64> update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus<int64>()); - result->Set<ReturnT>(result_index, - update_literal.Get<ReturnT>(update_index)); + result.Set<ReturnT>(result_index, + update_literal.Get<ReturnT>(update_index)); return true; }; @@ -2710,7 +2691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } - StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp( + StatusOr<Literal> ElementWiseUnaryOp( HloInstruction* instruction, const std::function<ElementwiseT(ElementwiseT)>& unary_op) { const Literal& operand_literal = @@ -2723,7 +2704,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp( + StatusOr<Literal> ElementWiseBinaryOp( HloInstruction* instruction, const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>& binary_op) { @@ -2745,10 +2726,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = absl::make_unique<Literal>(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get<ReturnT>(multi_index), rhs_literal.Get<ReturnT>(multi_index)); @@ -2757,7 +2738,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename LhsType, typename RhsType, typename EhsType> - StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp( + StatusOr<Literal> ElementwiseTernaryOp( HloInstruction* instruction, const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) { const auto shape = instruction->shape(); @@ -2782,10 +2763,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = absl::make_unique<Literal>(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { return ternary_op(lhs_literal.Get<LhsType>(multi_index), rhs_literal.Get<RhsType>(multi_index), ehs_literal.Get<EhsType>(multi_index)); |