diff options
author | Justin Lebar <jlebar@google.com> | 2018-08-20 20:20:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-20 20:23:24 -0700 |
commit | e924d67bff8c4fb58c8316d00b662f8d1e80eb95 (patch) | |
tree | bf1b0f5b9d0c699150295f98187b19d6a10710a6 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | |
parent | 49115abfd39d30506679d9fdc572ccd2f7c22dbe (diff) |
[XLA] Use absl::make_unique instead of xla::MakeUnique.
Same for WrapUnique.
PiperOrigin-RevId: 209531124
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 108 |
1 files changed, 55 insertions, 53 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 71c30e19a2..dea2a31920 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -30,7 +31,6 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -384,7 +384,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { int64 feature_index_value = feature_index->literal().Get<int64>({}); thunk_sequence_->emplace_back( - MakeUnique<CudnnBatchNormForwardInferenceThunk>( + absl::make_unique<CudnnBatchNormForwardInferenceThunk>( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -414,7 +414,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); thunk_sequence_->emplace_back( - MakeUnique<CudnnBatchNormForwardTrainingThunk>( + absl::make_unique<CudnnBatchNormForwardTrainingThunk>( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -444,19 +444,20 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_grad_offset = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - thunk_sequence_->emplace_back(MakeUnique<CudnnBatchNormBackwardThunk>( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*mean=*/GetAllocationSlice(*custom_call->operand(2)), - /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), - /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_grad_data=*/output_grad_data, - /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + thunk_sequence_->emplace_back( + absl::make_unique<CudnnBatchNormBackwardThunk>( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); return Status::OK(); } @@ -476,7 +477,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { const auto& target = custom_call->custom_call_target(); std::unique_ptr<ConvolutionThunk> thunk; if (target == kCudnnConvForwardCallTarget) { - thunk = MakeUnique<ConvolutionThunk>( + thunk = absl::make_unique<ConvolutionThunk>( CudnnConvKind::kForward, /*input_buffer=*/lhs_slice, /*filter_buffer=*/rhs_slice, @@ -490,7 +491,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { backend_config.algorithm(), backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = MakeUnique<ConvolutionThunk>( + thunk = absl::make_unique<ConvolutionThunk>( CudnnConvKind::kBackwardInput, /*input_buffer=*/conv_result_slice, /*filter_buffer=*/rhs_slice, @@ -504,7 +505,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { backend_config.algorithm(), backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = MakeUnique<ConvolutionThunk>( + thunk = absl::make_unique<ConvolutionThunk>( CudnnConvKind::kBackwardFilter, /*input_buffer=*/lhs_slice, /*filter_buffer=*/conv_result_slice, @@ -577,7 +578,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunks.push_back( BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), fusion)); + absl::make_unique<SequentialThunk>(std::move(thunks), fusion)); std::vector<IrArray> parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); @@ -1719,7 +1720,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { thunks.push_back( BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), reduce)); + absl::make_unique<SequentialThunk>(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), {[&](const IrArray::Index& index) { @@ -1761,7 +1762,7 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } - thunk_sequence_->emplace_back(MakeUnique<TupleThunk>( + thunk_sequence_->emplace_back(absl::make_unique<TupleThunk>( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } @@ -1793,8 +1794,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( thunks.push_back(std::move(initializer_thunk)); thunks.push_back(BuildKernelThunk(select_and_scatter, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter)); + thunk_sequence_->emplace_back(absl::make_unique<SequentialThunk>( + std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -2019,7 +2020,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { thunks.push_back(std::move(rng_thunk)); thunks.push_back(std::move(increment_seed_thunk)); thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), rng)); + absl::make_unique<SequentialThunk>(std::move(thunks), rng)); return Status::OK(); } @@ -2044,7 +2045,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { auto values_destination = GetAllocationSlice(*sort, values_shape_index); if (keys_destination != GetAllocationSlice(*keys)) { - thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*keys), /*destination_buffer=*/keys_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); @@ -2052,7 +2053,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { if (values != nullptr && values_destination != GetAllocationSlice(*values)) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. - thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*values), /*destination_buffer=*/values_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); @@ -2104,7 +2105,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), sort)); + absl::make_unique<SequentialThunk>(std::move(thunks), sort)); return Status::OK(); } @@ -2131,7 +2132,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (crs->operand_count() == 1) { CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - thunk_sequence_->push_back(MakeUnique<DeviceToDeviceCopyThunk>( + thunk_sequence_->push_back(absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); @@ -2146,17 +2147,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() .GetUniqueSlice(crs, {i}) .ValueOrDie()); - thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. - thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers, - GetAllocationSlice(*crs), nullptr)); + thunks.push_back(absl::make_unique<TupleThunk>( + tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( - MakeUnique<SequentialThunk>(std::move(thunks), crs)); + absl::make_unique<SequentialThunk>(std::move(thunks), crs)); return Status::OK(); } @@ -2390,7 +2391,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique<KernelThunk>( + return absl::make_unique<KernelThunk>( non_constant_buffers, llvm_ir::AsString(kernel->getName()), implements_whole_instruction ? inst : nullptr, unroll_factor); } @@ -2399,7 +2400,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return MakeUnique<HostToDeviceCopyThunk>( + return absl::make_unique<HostToDeviceCopyThunk>( /*source_address=*/operand->literal().untyped_data(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2411,7 +2412,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk( std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique<DeviceToDeviceCopyThunk>( + return absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*operand), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2431,7 +2432,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk( .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); }); - return MakeUnique<InfeedThunk>(slices, inst); + return absl::make_unique<InfeedThunk>(slices, inst); } std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk( @@ -2448,7 +2449,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk( *slice = status_or_slice.ConsumeValueOrDie(); } }); - return MakeUnique<OutfeedThunk>(std::move(slices), inst); + return absl::make_unique<OutfeedThunk>(std::move(slices), inst); } namespace { @@ -2471,7 +2472,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk( if (inst->opcode() == HloOpcode::kDot) { const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); - return MakeUnique<GemmThunk>( + return absl::make_unique<GemmThunk>( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2513,7 +2514,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk( const HloInstruction* rhs = inst->operand(rhs_parameter->parameter_number()); - return MakeUnique<GemmThunk>( + return absl::make_unique<GemmThunk>( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2530,11 +2531,12 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk( std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(), - /*input_buffer=*/GetAllocationSlice(*operand), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/operand->shape(), - /*output_shape=*/inst->shape(), inst); + return absl::make_unique<FftThunk>( + inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); } StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( @@ -2584,8 +2586,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( ArraySlice<uint8> literal_bytes( reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes); if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return { - MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)}; + return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index), + nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2602,7 +2604,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique<Memset32BitValueThunk>( + return {absl::make_unique<Memset32BitValueThunk>( pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } @@ -2613,7 +2615,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique<Memset32BitValueThunk>( + return {absl::make_unique<Memset32BitValueThunk>( word, GetAllocationSlice(*hlo, index), nullptr)}; } } @@ -2765,7 +2767,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique<WhileThunk>( + return absl::make_unique<WhileThunk>( GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition.ConsumeThunkSequence(), ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2783,8 +2785,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique<ForThunk>(loop_limit, - ir_emitter_body.ConsumeThunkSequence(), hlo); + return absl::make_unique<ForThunk>( + loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); } std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk( @@ -2804,7 +2806,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk( ir_emitter_context_); TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); - return MakeUnique<ConditionalThunk>( + return absl::make_unique<ConditionalThunk>( GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo->operand(1)), GetAllocationSlice(*hlo->operand(2)), |