aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/client.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/client.h')
-rw-r--r--tensorflow/compiler/xla/client/client.h16
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 68f0d0ac78..be50cebfcc 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -20,8 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/global_data.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -178,10 +178,15 @@ class Client {
StatusOr<std::unique_ptr<ProgramShape>> GetComputationShape(
const XlaComputation& computation);
- // Creates a channel handle that can be used to transfer data between
- // two computations via a pair of Send and Recv instructions.
+ // Creates a channel handle that can be used to transfer data between two
+ // computations on different devices via a pair of Send and Recv instructions.
StatusOr<ChannelHandle> CreateChannelHandle();
+ // Create a channel for communicating with the host via a SendtoHost or
+ // RecvFromHost operation.
+ StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle();
+ StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle();
+
StatusOr<XlaComputation> LoadSnapshot(const HloSnapshot& module);
ServiceInterface* stub() { return stub_; }
@@ -192,6 +197,9 @@ class Client {
StatusOr<string> ExecutionStatsAsString(const XlaComputation& computation,
const ExecutionProfile& profile);
+ StatusOr<ChannelHandle> CreateChannelHandleByType(
+ ChannelHandle::ChannelType type);
+
ServiceInterface* stub_; // Stub that this client is connected on.
TF_DISALLOW_COPY_AND_ASSIGN(Client);