diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/infeed_thunk.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/infeed_thunk.cc | 98 |
1 file changed, 58 insertions, 40 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 2b63d8727c..fee6d2af3b 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -27,51 +28,70 @@ InfeedThunk::InfeedThunk( : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { - VLOG(2) << "Infeeding to GPU "; + se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); + + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); + ShapeTree<InfeedBuffer> infeed_buffers = + GetOrCreateInfeedManager()->BlockingGetNextDestination(); + + { + // The infeed buffer has an extra outer tuple with a token. Adjust the index + // accordingly. + ShapeIndex index = {0}; + std::function<void(std::vector<void*>*)> copy_tuple_contents = + [&](std::vector<void*>* tuple_element_addresses) { + const Shape& shape = ShapeUtil::GetSubshape(infeed_buffers.shape(), + ShapeIndexView(index, 1)); + // For the leaf buffers of the tuple copy the elements directly. 
+ if (ShapeUtil::IsArray(shape)) { + const BufferAllocation::Slice& tuple_element_buffer = + infeed_slices_.element(index); + se::DeviceMemoryBase tuple_element_address = + buffer_allocations.GetDeviceAddress(tuple_element_buffer); + + InfeedBuffer* buffer = + infeed_buffers.mutable_element(ShapeIndexView(index, 1)); + stream->ThenMemcpy(&tuple_element_address, + *(buffer->device_memory()), buffer->length()); + tuple_element_addresses->push_back(tuple_element_address.opaque()); + return; + } + + const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); + index.push_back(0); + std::vector<void*> inner_tuple_element_addresses; + for (int64 i = 0; i < tuple_element_count; ++i) { + index.back() = i; + copy_tuple_contents(&inner_tuple_element_addresses); + } + index.pop_back(); + + // Create a buffer of pointers for non-leaf buffers. + CHECK_EQ(tuple_element_count, inner_tuple_element_addresses.size()); + auto host_size = inner_tuple_element_addresses.size() * sizeof(void*); + se::DeviceMemoryBase tuple_address = + buffer_allocations.GetDeviceAddress( + infeed_slices_.element(index)); + stream->ThenMemcpy(&tuple_address, + inner_tuple_element_addresses.data(), host_size); + tuple_element_addresses->push_back(tuple_address.opaque()); + }; - // First copy the infeed data which is element 0 of the infeed instruction's - // two-tuple output (the other element is a token). - se::DeviceMemoryBase data_address = - buffer_allocations.GetDeviceAddress(infeed_slices_.element({0})); - InfeedManager* infeed_manager = GetOrCreateInfeedManager(); - std::vector<InfeedBuffer*> infeed_buffers; - const Shape& data_shape = - ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0); - if (ShapeUtil::IsTuple(data_shape)) { - CHECK(!ShapeUtil::IsNestedTuple(data_shape)); - // Transfer the tuple elements first. 
std::vector<void*> tuple_element_addresses; - for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) { - const BufferAllocation::Slice& tuple_element_buffer = - infeed_slices_.element({0, i}); - se::DeviceMemoryBase tuple_element_address = - buffer_allocations.GetDeviceAddress(tuple_element_buffer); - - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()), - buffer->length()); - tuple_element_addresses.push_back(tuple_element_address.opaque()); - } - // Transfer the tuple outer buffer. - auto host_size = tuple_element_addresses.size() * sizeof(void*); - stream->ThenMemcpy(&data_address, tuple_element_addresses.data(), - host_size); - } else { - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&data_address, *(buffer->device_memory()), - buffer->length()); + copy_tuple_contents(&tuple_element_addresses); + CHECK_EQ(1, tuple_element_addresses.size()); } // Construct top-level tuple of infeed containing the data and the token. Use // a nullptr for the token, it should never be dereferenced. 
- std::vector<void*> infeed_addresses = {data_address.opaque(), nullptr}; + se::DeviceMemoryBase data_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({0})); + void* infeed_addresses[] = {data_address.opaque(), nullptr}; se::DeviceMemoryBase top_level_address = buffer_allocations.GetDeviceAddress(infeed_slices_.element({})); - stream->ThenMemcpy(&top_level_address, infeed_addresses.data(), - 2 * sizeof(void*)); + stream->ThenMemcpy(&top_level_address, infeed_addresses, 2 * sizeof(void*)); Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { @@ -79,8 +99,6 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, stream, block_status.error_message().c_str()); } - infeed_manager->ReleaseBuffers(infeed_buffers); - VLOG(2) << "Infeeding to GPU complete"; return Status::OK(); } |