about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/infeed_thunk.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/infeed_thunk.cc 98
1 files changed, 58 insertions, 40 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index 2b63d8727c..fee6d2af3b 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -27,51 +28,70 @@ InfeedThunk::InfeedThunk(
: Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
- VLOG(2) << "Infeeding to GPU ";
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();
+
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
+ ShapeTree<InfeedBuffer> infeed_buffers =
+ GetOrCreateInfeedManager()->BlockingGetNextDestination();
+
+ {
+ // The infeed buffer has an extra outer tuple with a token. Adjust the index
+ // accordingly.
+ ShapeIndex index = {0};
+ std::function<void(std::vector<void*>*)> copy_tuple_contents =
+ [&](std::vector<void*>* tuple_element_addresses) {
+ const Shape& shape = ShapeUtil::GetSubshape(infeed_buffers.shape(),
+ ShapeIndexView(index, 1));
+ // For the leaf buffers of the tuple copy the elements directly.
+ if (ShapeUtil::IsArray(shape)) {
+ const BufferAllocation::Slice& tuple_element_buffer =
+ infeed_slices_.element(index);
+ se::DeviceMemoryBase tuple_element_address =
+ buffer_allocations.GetDeviceAddress(tuple_element_buffer);
+
+ InfeedBuffer* buffer =
+ infeed_buffers.mutable_element(ShapeIndexView(index, 1));
+ stream->ThenMemcpy(&tuple_element_address,
+ *(buffer->device_memory()), buffer->length());
+ tuple_element_addresses->push_back(tuple_element_address.opaque());
+ return;
+ }
+
+ const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape);
+ index.push_back(0);
+ std::vector<void*> inner_tuple_element_addresses;
+ for (int64 i = 0; i < tuple_element_count; ++i) {
+ index.back() = i;
+ copy_tuple_contents(&inner_tuple_element_addresses);
+ }
+ index.pop_back();
+
+ // Create a buffer of pointers for non-leaf buffers.
+ CHECK_EQ(tuple_element_count, inner_tuple_element_addresses.size());
+ auto host_size = inner_tuple_element_addresses.size() * sizeof(void*);
+ se::DeviceMemoryBase tuple_address =
+ buffer_allocations.GetDeviceAddress(
+ infeed_slices_.element(index));
+ stream->ThenMemcpy(&tuple_address,
+ inner_tuple_element_addresses.data(), host_size);
+ tuple_element_addresses->push_back(tuple_address.opaque());
+ };
- // First copy the infeed data which is element 0 of the infeed instruction's
- // two-tuple output (the other element is a token).
- se::DeviceMemoryBase data_address =
- buffer_allocations.GetDeviceAddress(infeed_slices_.element({0}));
- InfeedManager* infeed_manager = GetOrCreateInfeedManager();
- std::vector<InfeedBuffer*> infeed_buffers;
- const Shape& data_shape =
- ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0);
- if (ShapeUtil::IsTuple(data_shape)) {
- CHECK(!ShapeUtil::IsNestedTuple(data_shape));
- // Transfer the tuple elements first.
std::vector<void*> tuple_element_addresses;
- for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) {
- const BufferAllocation::Slice& tuple_element_buffer =
- infeed_slices_.element({0, i});
- se::DeviceMemoryBase tuple_element_address =
- buffer_allocations.GetDeviceAddress(tuple_element_buffer);
-
- InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
- infeed_buffers.push_back(buffer);
- stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()),
- buffer->length());
- tuple_element_addresses.push_back(tuple_element_address.opaque());
- }
- // Transfer the tuple outer buffer.
- auto host_size = tuple_element_addresses.size() * sizeof(void*);
- stream->ThenMemcpy(&data_address, tuple_element_addresses.data(),
- host_size);
- } else {
- InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
- infeed_buffers.push_back(buffer);
- stream->ThenMemcpy(&data_address, *(buffer->device_memory()),
- buffer->length());
+ copy_tuple_contents(&tuple_element_addresses);
+ CHECK_EQ(1, tuple_element_addresses.size());
}
// Construct top-level tuple of infeed containing the data and the token. Use
// a nullptr for the token, it should never be dereferenced.
- std::vector<void*> infeed_addresses = {data_address.opaque(), nullptr};
+ se::DeviceMemoryBase data_address =
+ buffer_allocations.GetDeviceAddress(infeed_slices_.element({0}));
+ void* infeed_addresses[] = {data_address.opaque(), nullptr};
se::DeviceMemoryBase top_level_address =
buffer_allocations.GetDeviceAddress(infeed_slices_.element({}));
- stream->ThenMemcpy(&top_level_address, infeed_addresses.data(),
- 2 * sizeof(void*));
+ stream->ThenMemcpy(&top_level_address, infeed_addresses, 2 * sizeof(void*));
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
@@ -79,8 +99,6 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
stream, block_status.error_message().c_str());
}
- infeed_manager->ReleaseBuffers(infeed_buffers);
-
VLOG(2) << "Infeeding to GPU complete";
return Status::OK();
}