about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/client/xla_client/xla_builder.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_client/xla_builder.h')
-rw-r--r-- tensorflow/compiler/xla/client/xla_client/xla_builder.h 59
1 file changed, 51 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 2be6f4a553..8359d936b7 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -54,7 +54,16 @@ class XlaOp {
}
~XlaOp() = default;
- XlaBuilder* builder() const { return builder_; }
+ // Precondition: !IsUninitialized().
+ //
+ // It's very common to do foo.builder()->bar(). Without this precondition, if
+ // foo.builder() is null, the call to bar will segfault at some point possibly
+ // deep in the callstack when we finally dereference `this`. The precondition
+ // lets us avoid this tricky-to-debug problem.
+ XlaBuilder* builder() const {
+ CHECK(builder_ != nullptr);
+ return builder_;
+ }
// Returns true if the XlaOp represents a valid, non-erroneous value.
bool valid() const { return handle_ >= 0; }
@@ -848,12 +857,21 @@ class XlaBuilder {
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
- // Enqueues a Send node onto the computation, to send the given operand to
- // a Recv instruction that shares the same channel handle.
+ // Enqueues a Send node onto the computation for device-to-device
+ // communication, to send the given operand to a Recv instruction that shares
+ // the same channel handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
const ChannelHandle& handle);
+ // Enqueues a Send node which sends data to the host.
+ XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout, const ChannelHandle& handle);
+
+ // Enqueues a Recv node which receives data from the host.
+ XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
// Enqueues an AfterAll operation with no operands producing a token-shaped
// value.
XlaOp CreateToken();
@@ -1244,6 +1262,9 @@ class XlaBuilder {
friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
friend XlaOp IsFinite(const XlaOp& operand);
+ // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota
+ // in xla/client/lib/numeric.h with this (renamed to xla::Iota).
+ friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
friend XlaOp ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type);
friend XlaOp BitcastConvertType(const XlaOp& operand,
@@ -1293,6 +1314,11 @@ class XlaBuilder {
const ChannelHandle& handle);
friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
const ChannelHandle& handle);
+ friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const ChannelHandle& handle);
+ friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
const string& config);
friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
@@ -1951,8 +1977,10 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
-// Enqueues a Send node onto the computation, to send the given operand to
-// a Recv instruction that shares the same channel handle.
+// Enqueues a Send node onto the computation for device-to-device
+// communication. This operation sends the given operand to
+// a Recv instruction in a different computation that shares the same channel
+// handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
// Variant of Send which takes a token-shaped operand and produces a
@@ -1961,9 +1989,10 @@ void Send(const XlaOp& operand, const ChannelHandle& handle);
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
const ChannelHandle& handle);
-// Enqueues a Recv node onto the computation. The data comes from a Send
-// instruction that shares the same channel handle and its shape must
-// be the same as the given shape.
+// Enqueues a Recv node onto the computation for device-to-device
+// communication. The data comes from a Send instruction in a different
+// computation that shares the same channel handle and its shape must be the
+// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle);
@@ -1974,6 +2003,20 @@ XlaOp Recv(XlaBuilder* builder, const Shape& shape,
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
const ChannelHandle& handle);
+// Enqueues a Send node which transfers data from the device to the host. The
+// 'shape_with_layout' argument defines the layout of the data transferred; its
+// shape must be compatible with the shape of the operand. The operand must be
+// array-shaped.
+// TODO(b/111544877): Support tuple shapes.
+XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout, const ChannelHandle& handle);
+
+// Enqueues a Recv node which transfers data from the host to the device. The
+// given shape must contain a layout and must be an array.
+// TODO(b/111544877): Support tuple shapes.
+XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of