diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_client/xla_builder.h')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.h | 59 |
1 files changed, 51 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 2be6f4a553..8359d936b7 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -54,7 +54,16 @@ class XlaOp { } ~XlaOp() = default; - XlaBuilder* builder() const { return builder_; } + // Precondition: !IsUninitialized(). + // + // It's very common to do foo.builder()->bar(). Without this precondition, if + // foo.builder() is null, the call to bar will segfault at some point possibly + // deep in the callstack when we finally dereference `this`. The precondition + // lets us avoid this tricky-to-debug problem. + XlaBuilder* builder() const { + CHECK(builder_ != nullptr); + return builder_; + } // Returns true if the XlaOp represents valid, non-erroneous value. bool valid() const { return handle_ >= 0; } @@ -848,12 +857,21 @@ class XlaBuilder { const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. + // Enqueues a Send node onto the computation for device-to-device + // communication, to send the given operand to a Recv instruction that shares + // the same channel handle. void Send(const XlaOp& operand, const ChannelHandle& handle); XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, const ChannelHandle& handle); + // Enqueues a Send node which sends data to the host. + XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle); + + // Enqueues a Recv node which receives data from the host. + XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + // Enqueues an AfterAll operation with no operands producing a token-shaped // value. XlaOp CreateToken(); @@ -1244,6 +1262,9 @@ class XlaBuilder { friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); + // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota + // in xla/client/lib/numeric.h with this (renamed to xla::Iota). + friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); friend XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); friend XlaOp BitcastConvertType(const XlaOp& operand, @@ -1293,6 +1314,11 @@ class XlaBuilder { const ChannelHandle& handle); friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); + friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const ChannelHandle& handle); + friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, const string& config); friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, @@ -1951,8 +1977,10 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); -// Enqueues a Send node onto the computation, to send the given operand to -// a Recv instruction that shares the same channel handle. +// Enqueues a Send node onto the computation for device-to-device +// communication. This operation sends the given operand to +// a Recv instruction in a different computation that shares the same channel +// handle. void Send(const XlaOp& operand, const ChannelHandle& handle); // Variant of Send which takes a token-shaped operand and produces a @@ -1961,9 +1989,10 @@ void Send(const XlaOp& operand, const ChannelHandle& handle); XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, const ChannelHandle& handle); -// Enqueues a Recv node onto the computation. The data comes from a Send -// instruction that shares the same channel handle and its shape must -// be the same as the given shape. +// Enqueues a Recv node onto the computation for device-to-device +// communication. The data comes from a Send instruction in a different +// computation that shares the same channel handle and its shape must be the +// same as the given shape. XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); @@ -1974,6 +2003,20 @@ XlaOp Recv(XlaBuilder* builder, const Shape& shape, XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); +// Enqueues a Send node which transfers data from the device to the host. The +// 'shape_with_layout' argument defines the layout of the data transferred; its +// shape must be compatible with the shape of the operand. The operand must be +// array-shaped. +// TODO(b/111544877): Support tuple shapes. +XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle); + +// Enqueues a Recv node which transfers data from the host to the device. The +// given shape must contain a layout and must be an array. +// TODO(b/111544877): Support tuple shapes. +XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + // Enqueues an operation (AfterAll) with no operands that produces a // token-shaped value. Tokens are used for ordering side-effecting operations. // This is a separate method from AfterAll to facility the removal of |