diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/client.cc')
-rw-r--r-- | tensorflow/compiler/xla/client/client.cc | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 3d596a6e65..d0ce5e8a6a 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,9 +18,10 @@ limitations under the License. #include <string> #include <utility> +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -409,8 +410,10 @@ StatusOr<string> Client::ExecutionStatsAsString( return string("[Execution Statistics] not available."); } -StatusOr<ChannelHandle> Client::CreateChannelHandle() { +StatusOr<ChannelHandle> Client::CreateChannelHandleByType( + ChannelHandle::ChannelType type) { CreateChannelHandleRequest request; + request.set_channel_type(type); CreateChannelHandleResponse response; VLOG(1) << "making create channel handle request"; @@ -424,4 +427,16 @@ StatusOr<ChannelHandle> Client::CreateChannelHandle() { return response.channel(); } +StatusOr<ChannelHandle> Client::CreateChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE); +} + +StatusOr<ChannelHandle> Client::CreateHostToDeviceChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE); +} + +StatusOr<ChannelHandle> Client::CreateDeviceToHostChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST); +} + } // namespace xla |