aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/channel_tracker.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/channel_tracker.cc')
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.cc28
1 files changed, 25 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc
index a5b392cbc3..13008efed1 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.cc
+++ b/tensorflow/compiler/xla/service/channel_tracker.cc
@@ -31,16 +31,23 @@ namespace xla {
ChannelTracker::ChannelTracker() : next_channel_(1) {}
-ChannelHandle ChannelTracker::NewChannel() {
+StatusOr<ChannelHandle> ChannelTracker::NewChannel(
+ ChannelHandle::ChannelType type) {
+ if (type != ChannelHandle::DEVICE_TO_DEVICE &&
+ type != ChannelHandle::HOST_TO_DEVICE &&
+ type != ChannelHandle::DEVICE_TO_HOST) {
+ return InvalidArgument("Invalid channel type: %d", type);
+ }
tensorflow::mutex_lock lock(channel_mutex_);
// Create a new channel handle with a unique value.
- const ChannelHandle new_handle = AllocateHandle();
+ ChannelHandle new_handle = AllocateHandle(type);
// Register a channel object associated with the handle.
Channel channel;
channel.has_sender = false;
channel.receiver_count = 0;
+ channel.type = type;
opaque_to_channel_[new_handle.handle()] = channel;
return new_handle;
@@ -56,10 +63,11 @@ Status ChannelTracker::RegisterRecv(const ChannelHandle& handle) {
return RegisterRecvInternal(handle);
}
-ChannelHandle ChannelTracker::AllocateHandle() {
+ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) {
int64 handle_value = next_channel_++;
ChannelHandle result;
result.set_handle(handle_value);
+ result.set_type(type);
return result;
}
@@ -68,6 +76,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
return NotFound("channel handle not found: %lld", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
+ if (channel.type == ChannelHandle::HOST_TO_DEVICE) {
+ return FailedPrecondition(
+ "host-to-device channels cannot be used with a Send operation; "
+ "channel handle: %lld",
+ handle.handle());
+ }
+
if (channel.has_sender) {
return FailedPrecondition(
"when registering send, passed a channel handle that is already used "
@@ -83,6 +98,13 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
return NotFound("channel handle not found: %lld", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
+ if (channel.type == ChannelHandle::DEVICE_TO_HOST) {
+ return FailedPrecondition(
+ "device-to-host channels cannot be used with a Recv operation; "
+ "channel handle: %lld",
+ handle.handle());
+ }
+
// TODO(b/33942691): Allow more than 1 receivers for broadcast.
if (channel.receiver_count >= 1) {
return FailedPrecondition(