aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/client.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/client.cc')
-rw-r--r--tensorflow/compiler/xla/client/client.cc19
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 3d596a6e65..d0ce5e8a6a 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -18,9 +18,10 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -409,8 +410,10 @@ StatusOr<string> Client::ExecutionStatsAsString(
return string("[Execution Statistics] not available.");
}
-StatusOr<ChannelHandle> Client::CreateChannelHandle() {
+StatusOr<ChannelHandle> Client::CreateChannelHandleByType(
+ ChannelHandle::ChannelType type) {
CreateChannelHandleRequest request;
+ request.set_channel_type(type);
CreateChannelHandleResponse response;
VLOG(1) << "making create channel handle request";
@@ -424,4 +427,16 @@ StatusOr<ChannelHandle> Client::CreateChannelHandle() {
return response.channel();
}
+StatusOr<ChannelHandle> Client::CreateChannelHandle() {
+ return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE);
+}
+
+StatusOr<ChannelHandle> Client::CreateHostToDeviceChannelHandle() {
+ return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE);
+}
+
+StatusOr<ChannelHandle> Client::CreateDeviceToHostChannelHandle() {
+ return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST);
+}
+
} // namespace xla