diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/client.h')
-rw-r--r-- | tensorflow/compiler/xla/client/client.h | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 68f0d0ac78..be50cebfcc 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -20,8 +20,8 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -178,10 +178,15 @@ class Client { StatusOr<std::unique_ptr<ProgramShape>> GetComputationShape( const XlaComputation& computation); - // Creates a channel handle that can be used to transfer data between - // two computations via a pair of Send and Recv instructions. + // Creates a channel handle that can be used to transfer data between two + // computations on different devices via a pair of Send and Recv instructions. StatusOr<ChannelHandle> CreateChannelHandle(); + // Create a channel for communicating with the host via a SendtoHost or + // RecvFromHost operation. + StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle(); + StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle(); + StatusOr<XlaComputation> LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } @@ -192,6 +197,9 @@ class Client { StatusOr<string> ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); + StatusOr<ChannelHandle> CreateChannelHandleByType( + ChannelHandle::ChannelType type); + ServiceInterface* stub_; // Stub that this client is connected on. TF_DISALLOW_COPY_AND_ASSIGN(Client); |