aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/channel_tracker.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/channel_tracker.cc')
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.cc91
1 files changed, 91 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc
new file mode 100644
index 0000000000..b3784c36ff
--- /dev/null
+++ b/tensorflow/compiler/xla/service/channel_tracker.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/channel_tracker.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+ChannelTracker::ChannelTracker() : next_channel_(1) {}
+
+ChannelHandle ChannelTracker::NewChannel() {
+ tensorflow::mutex_lock lock(channel_mutex_);
+
+ // Create a new channel handle with a unique value.
+ const ChannelHandle new_handle = AllocateHandle();
+
+ // Register a channel object associated with the handle.
+ Channel channel;
+ channel.has_sender = false;
+ channel.receiver_count = 0;
+ opaque_to_channel_[new_handle.handle()] = channel;
+
+ return new_handle;
+}
+
+Status ChannelTracker::RegisterSend(const ChannelHandle& handle) {
+ tensorflow::mutex_lock lock(channel_mutex_);
+ return RegisterSendInternal(handle);
+}
+
+Status ChannelTracker::RegisterRecv(const ChannelHandle& handle) {
+ tensorflow::mutex_lock lock(channel_mutex_);
+ return RegisterRecvInternal(handle);
+}
+
+ChannelHandle ChannelTracker::AllocateHandle() {
+ int64 handle_value = next_channel_++;
+ ChannelHandle result;
+ result.set_handle(handle_value);
+ return result;
+}
+
+Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
+ if (opaque_to_channel_.count(handle.handle()) == 0) {
+ return NotFound("channel handle not found: %lld", handle.handle());
+ }
+ Channel& channel = opaque_to_channel_[handle.handle()];
+ if (channel.has_sender) {
+ return FailedPrecondition("channel handle is already used by a sender");
+ }
+ channel.has_sender = true;
+ return Status::OK();
+}
+
+Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
+ if (opaque_to_channel_.count(handle.handle()) == 0) {
+ return NotFound("channel handle not found: %lld", handle.handle());
+ }
+ Channel& channel = opaque_to_channel_[handle.handle()];
+ // TODO(b/33942691): Allow more than 1 receivers for broadcast.
+ if (channel.receiver_count >= 1) {
+ return FailedPrecondition("channel handle is already used by a receiver");
+ }
+ channel.receiver_count += 1;
+ return Status::OK();
+}
+
+} // namespace xla