diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_client/xla_builder.h')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.h | 740 |
1 files changed, 588 insertions, 152 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index fe31774b86..980e84e40c 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -22,7 +22,8 @@ limitations under the License. #include <utility> #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -53,7 +54,16 @@ class XlaOp { } ~XlaOp() = default; - XlaBuilder* builder() const { return builder_; } + // Precondition: !IsUninitialized(). + // + // It's very common to do foo.builder()->bar(). Without this precondition, if + // foo.builder() is null, the call to bar will segfault at some point possibly + // deep in the callstack when we finally dereference `this`. The precondition + // lets us avoid this tricky-to-debug problem. + XlaBuilder* builder() const { + CHECK(builder_ != nullptr); + return builder_; + } // Returns true if the XlaOp represents valid, non-erroneous value. bool valid() const { return handle_ >= 0; } @@ -158,6 +168,93 @@ class XlaBuilder { die_immediately_on_error_ = enabled; } + // Default dimension numbers used for a 2D convolution. 
+ static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. + StatusOr<XlaComputation> Build(); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. 
If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const; + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr<Shape> GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr<ProgramShape> GetProgramShape() const; + + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const Status& error); + + // A helper function that converts a StatusOr<XlaOp> into an XlaOp. + // If the Status was an error, reports the error to builder and returns an + // invalid XlaOp handle. + XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op); + + // A helper function that runs a function that returns a StatusOr<XlaOp> and + // returns an XlaOp. + XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator); + + // Returns true if 'operand' is a compile-time constant. 
A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + StatusOr<bool> IsConstant(const XlaOp& operand) const; + + private: // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -230,6 +327,27 @@ class XlaBuilder { XlaOp Broadcast(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes); + // Performs in-dimension-style broadcast. + // + // Operand specifies the input to be broadcast. "shape" is expected output + // shape. "broadcast_dimensions" are the dimensions to be broadcasting into. + // Dimension numbers in broadcast_dimensions map to individual dimensions + // of the operand, and specify what dimension of the output shape they + // should be broadcast. + // e.g. + // Say operand = [1, 2], i.e., a 1D tensor with 2 elements. + // and dimension of shape is [2,2]. + // Specifying {1} as brodcast_dimension will generate output + // [1 , 2] + // [1 , 2] + // On the other hand, specifying {0} as broadcast_dimension + // will generate output + // [1 , 1] + // [2 , 2] + XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config // specifies the padding amount for each dimension. @@ -378,26 +496,6 @@ class XlaBuilder { XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers); - // Default dimension numbers used for a 2D convolution. 
- static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Returns an error if the convolution dimension numbers have conflicts. - static Status Validate(const ConvolutionDimensionNumbers& dnum); - // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, @@ -444,6 +542,8 @@ class XlaBuilder { // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. XlaOp Infeed(const Shape& shape, const string& config = ""); + XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config = ""); // Enqueues an outfeed instruction onto the computation. This instruction // generates outgoing data transfers for the given data. @@ -453,6 +553,9 @@ class XlaBuilder { // will occur. void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); + XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); // Enqueues a call instruction onto the computation. 
XlaOp Call(const XlaComputation& computation, @@ -663,16 +766,6 @@ class XlaBuilder { // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). - XlaOp SqrtF32(const XlaOp& operand); - - // Enqueues a float32 square instruction onto the computation. - // (float32 is specified as there is an implicit float32 2.0f constant - // exponent). - XlaOp SquareF32(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); @@ -695,14 +788,6 @@ class XlaBuilder { XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a float32 reciprocal instruction onto the computation. - // (float32 is specified as there is an implicit float32 -1.0f constant - // exponent). - // - // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the - // shape of the operand. - XlaOp ReciprocalF32(const XlaOp& operand); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); @@ -717,7 +802,24 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<int64> dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. - XlaOp Sort(const XlaOp& operand); + // If only keys are provided: + // * If the keys are an rank-1 tensor (an array), the result is a sorted array + // of keys, in ascending order. + // * If the keys have higher rank, the keys are sorted along the provided + // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension + // value of 0 will indepenently sort every column, and a dimension value of 1 + // will independently sort each row. If no dimension number is provided, then + // the last dimension is chosen by default. 
+ // + // If both keys and values are provided: + // * The keys and the values must tensors with the same dimensions. The + // element types of the tensors may be different. + // * The result is a tuple that consists of a sorted tensor of keys (along the + // provided dimension, as above) as the first element, and a tensor with their + // corresponding values as the second element. + XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt, + int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -755,22 +857,35 @@ class XlaBuilder { const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. + // Enqueues a Send node onto the computation for device-to-device + // communication, to send the given operand to a Recv instruction that shares + // the same channel handle. void Send(const XlaOp& operand, const ChannelHandle& handle); + XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + + // Enqueues a Send node which sends data to the host. + XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle); + + // Enqueues a Recv node which receives data from the host. + XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + + // Enqueues an AfterAll operation with no operands producing a token-shaped + // value. + XlaOp CreateToken(); + + // Enqueues an AfterAll operation with no operands producing a token-shaped + // value. + XlaOp AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens); // Enqueues a Recv node onto the computation. 
The data comes from a Send // instruction that shares the same channel handle and its shape must // be the same as the given shape. XlaOp Recv(const Shape& shape, const ChannelHandle& handle); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on any parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. - // - // This tests whether a computation is a compile-time constant without - // evaluating the computation. - StatusOr<bool> IsConstant(const XlaOp& operand) const; + XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); // Normalizes operand across spatial and batch dimensions for each feature. // @@ -810,65 +925,6 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - // Returns a new XlaBuilder whose resultant Computation is used only by this - // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error - // behavior as the parent. - std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name); - - // Builds the computation with the requested operations, or returns a non-ok - // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. - StatusOr<XlaComputation> Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent XlaBuilder and returns an empty computation if building failed. - // This function is intended to be used where the returned XlaComputation is - // only used by the parent XlaBuilder and hence further operation on the - // returned XlaComputation will simply be error'ed out if an error occurred - // while building this computation. If the built computation is to be used by - // a XlaBuilder other than the parent XlaBuilder then Build() should be used - // instead. - XlaComputation BuildAndNoteError(); - - // Returns a subgraph that roots on the given root. 
If the root is not a - // compile-time constant (see `IsConstant`), returns an error. - // - // This will copy the needed ops/computations to the subgraph. - StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const; - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // XlaOp and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - // Returns the shape of the given op. - StatusOr<Shape> GetShape(const XlaOp& op) const; - - // Returns the (inferred) result for the current computation's shape. - StatusOr<ProgramShape> GetProgramShape() const; - - // Reports an error to the builder, by - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to - // Build()/GetShape()/...) - // * dying if die_immediately_on_error_ is true. - // Returns an XlaOp with an invalid handle but a valid builder. This value can - // be returned in place of a value in APIs that return an XlaOp. - XlaOp ReportError(const Status& error); - - // A helper function that converts a StatusOr<XlaOp> into an XlaOp. - // If the Status was an error, reports the error to builder and returns an - // invalid XlaOp handle. - XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op); - - // A helper function that runs a function that returns a StatusOr<XlaOp> and - // returns an XlaOp. 
- XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator); - - private: StatusOr<XlaOp> AddInstruction( HloInstructionProto&& instr, HloOpcode opcode, tensorflow::gtl::ArraySlice<XlaOp> operands = {}); @@ -971,6 +1027,306 @@ class XlaBuilder { bool die_immediately_on_error_ = false; XlaBuilder* parent_builder_{nullptr}; + + friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, + const Shape& shape, const string& name); + friend XlaOp ConstantLiteral(XlaBuilder* builder, + const LiteralSlice& literal); + template <typename NativeT> + friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); + template <typename NativeT> + friend XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<NativeT> values); + friend XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values); + template <typename NativeT> + friend XlaOp ConstantR2( + XlaBuilder* builder, + std::initializer_list<std::initializer_list<NativeT>> values); + template <typename NativeT> + friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array<NativeT>& values, + const Layout& layout); + template <typename NativeT> + friend XlaOp ConstantFromArray(XlaBuilder* builder, + const Array<NativeT>& values); + template <typename NativeT> + friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D<NativeT>& values, + const Layout& layout); + template <typename NativeT> + friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D<NativeT>& values); + template <typename NativeT> + friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D<NativeT>& values, + const Layout& layout); + template <typename NativeT> + friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D<NativeT>& values); + template <typename NativeT> + friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D<NativeT>& values, + const Layout& layout); + template 
<typename NativeT> + friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D<NativeT>& values); + + template <typename NativeT> + friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + + friend XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> broadcast_sizes); + + friend XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + + friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<int64> new_sizes); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> new_sizes); + + friend XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions); + + friend XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> start_indices, + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> strides); + + friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); + + friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice<int64> slice_sizes); + + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + friend XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> operands, + int64 dimension); + + friend void Trace(const string& tag, const XlaOp& operand); + + friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false); + friend XlaOp Tuple(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> elements); + friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> 
broadcast_dimensions); + friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + Padding padding); + friend XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + friend XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + tensorflow::gtl::ArraySlice<int64> lhs_dilation, + tensorflow::gtl::ArraySlice<int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length); + friend 
XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config); + friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice<XlaOp> operands); + friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice<XlaOp> operands, + const Shape& shape); + friend XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Conj(const XlaOp& operand); + friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Not(const XlaOp& operand); + friend 
XlaOp ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); + friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + friend XlaOp ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding); + friend XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> replica_group_ids); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> replica_group_ids, + const tensorflow::gtl::optional<ChannelHandle>& channel_id); + friend XlaOp SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + 
tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp Abs(const XlaOp& operand); + friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Exp(const XlaOp& operand); + friend XlaOp Expm1(const XlaOp& operand); + friend XlaOp Floor(const XlaOp& operand); + friend XlaOp Ceil(const XlaOp& operand); + friend XlaOp Round(const XlaOp& operand); + friend XlaOp Log(const XlaOp& operand); + friend XlaOp Log1p(const XlaOp& operand); + friend XlaOp Sign(const XlaOp& operand); + friend XlaOp Clz(const XlaOp& operand); + friend XlaOp Cos(const XlaOp& operand); + friend XlaOp Sin(const XlaOp& operand); + friend XlaOp Tanh(const XlaOp& operand); + friend XlaOp Real(const XlaOp& operand); + friend XlaOp Imag(const XlaOp& operand); + friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp IsFinite(const XlaOp& operand); + // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota + // in xla/client/lib/numeric.h with this (renamed to xla::Iota). 
+ friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + friend XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp Neg(const XlaOp& operand); + friend XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> permutation); + friend XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions); + friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values, + int64 dimension); + friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + friend XlaOp Map(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<XlaOp> static_operands); + friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape); + friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + friend XlaOp While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init); + friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + friend void Send(const XlaOp& operand, const ChannelHandle& handle); + friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + friend XlaOp BatchNormInference(const XlaOp& operand, const 
XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const ChannelHandle& handle); + friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config); + friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); + friend XlaOp CreateToken(XlaBuilder* builder); + friend XlaOp AfterAll(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> tokens); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1087,6 +1443,27 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); XlaOp Broadcast(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes); +// Performs in-dimension-style broadcast. +// +// Operand specifies the input to be broadcast. "shape" is expected output +// shape. "broadcast_dimensions" are the dimensions to be broadcasting into. +// Dimension numbers in broadcast_dimensions map to individual dimensions +// of the operand, and specify what dimension of the output shape they +// should be broadcast. +// e.g. +// Say operand = [1, 2], i.e., a 1D tensor with 2 elements. +// and dimension of shape is [2,2]. 
+// Specifying {1} as brodcast_dimension will generate output +// [1 , 2] +// [1 , 2] +// On the other hand, specifying {0} as broadcast_dimension +// will generate output +// [1 , 1] +// [2 , 2] +XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config // specifies the padding amount for each dimension. @@ -1281,6 +1658,13 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type, XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config = ""); +// Variant of Infeed which takes a token-shaped operand and produces a +// two-element tuple containing the data value and a token-shaped value. +// Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config = ""); + // Enqueues an outfeed instruction onto the computation. This instruction // generates outgoing data transfers for the given data. // @@ -1290,6 +1674,13 @@ XlaOp Infeed(XlaBuilder* builder, const Shape& shape, void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); +// Variant of Outfeed which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); + // Enqueues a call instruction onto the computation. 
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, tensorflow::gtl::ArraySlice<XlaOp> operands); @@ -1498,16 +1889,6 @@ XlaOp Real(const XlaOp& operand); // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); -// Enqueues a float32 sqrt instruction onto the computation. -// (float32 is specified as there is an implicit float32 0.5f constant -// exponent). -XlaOp SqrtF32(const XlaOp& operand); - -// Enqueues a float32 square instruction onto the computation. -// (float32 is specified as there is an implicit float32 2.0f constant -// exponent). -XlaOp SquareF32(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); @@ -1528,14 +1909,6 @@ XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); // identical. XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); -// Enqueues a float32 reciprocal instruction onto the computation. -// (float32 is specified as there is an implicit float32 -1.0f constant -// exponent). -// -// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the -// shape of the operand. -XlaOp ReciprocalF32(const XlaOp& operand); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); @@ -1549,7 +1922,24 @@ XlaOp Transpose(const XlaOp& operand, XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. -XlaOp Sort(const XlaOp& operand); +// If only keys are provided: +// * If the keys are a rank-1 tensor (an array), the result is a sorted array +// of keys, in ascending order. +// * If the keys have higher rank, the keys are sorted along the provided +// dimension. 
For example, for a rank-2 tensor (a matrix) of keys, a dimension +// value of 0 will independently sort every column, and a dimension value of 1 +// will independently sort each row. If no dimension number is provided, then +// the last dimension is chosen by default. +// +// If both keys and values are provided: +// * The keys and the values must be tensors with the same dimensions. The +// element types of the tensors may be different. +// * The result is a tuple that consists of a sorted tensor of keys (along the +// provided dimension, as above) as the first element, and a tensor with their +// corresponding values as the second element. +XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt, + int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -1587,16 +1977,59 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); -// Enqueues a Send node onto the computation, to send the given operand to -// a Recv instruction that shares the same channel handle. +// Enqueues a Send node onto the computation for device-to-device +// communication. This operation sends the given operand to +// a Recv instruction in a different computation that shares the same channel +// handle. void Send(const XlaOp& operand, const ChannelHandle& handle); -// Enqueues a Recv node onto the computation. The data comes from a Send -// instruction that shares the same channel handle and its shape must -// be the same as the given shape. +// Variant of Send which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. 
+XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + +// Enqueues a Recv node onto the computation for device-to-device +// communication. The data comes from a Send instruction in a different +// computation that shares the same channel handle and its shape must be the +// same as the given shape. XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); +// Variant of Recv which takes a token-shaped operand and produces a two-element +// tuple containing the data value and a token-shaped value. Tokens are used +// for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues a Send node which transfers data from the device to the host. The +// 'shape_with_layout' argument defines the layout of the data transferred; its +// shape must be compatible with the shape of the operand. The operand must be +// array-shaped. +// TODO(b/111544877): Support tuple shapes. +XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle); + +// Enqueues a Recv node which transfers data from the host to the device. The +// given shape must contain a layout and must be an array. +// TODO(b/111544877): Support tuple shapes. +XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues an operation (AfterAll) with no operands that produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// This is a separate method from AfterAll to facilitate the removal of +// operand-less AfterAll instructions. +// TODO(b/110532604): Remove this function when all tokens are derived from a +// single token generated or passed into the entry computation. 
+XlaOp CreateToken(XlaBuilder* builder); + +// Enqueues an AfterAll instruction which produces a token-shaped value and +// takes a variadic number of token-shaped operands. The number of operands must +// be greater than zero. Used for joining tokens. +XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens); + // Normalizes operand across spatial and batch dimensions for each feature. // // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` @@ -1639,12 +2072,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template <typename NativeT> XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0<NativeT>(value)); + return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) { - return ConstantLiteral(*Literal::CreateR1<NativeT>(values)); + return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values)); } template <typename NativeT> @@ -1656,44 +2089,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); + return ConstantLiteral(*LiteralUtil::CreateR1(values)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR2( std::initializer_list<std::initializer_list<NativeT>> values) { - return ConstantLiteral(*Literal::CreateR2<NativeT>(values)); + return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values)); } template <typename NativeT> XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateFromArrayWithLayout<NativeT>(values, layout)); + *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); } template <typename NativeT> XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) { - return 
ConstantLiteral(*Literal::CreateFromArray<NativeT>(values)); + return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D<NativeT>& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateFromArrayWithLayout<NativeT>(values, layout)); + *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values)); + return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D<NativeT>& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); + *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); } template <typename NativeT> @@ -1716,13 +2149,13 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) { template <typename NativeT> XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *Literal::CreateR0<NativeT>(value)); + return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value)); } template <typename NativeT> XlaOp ConstantR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> values) { - return ConstantLiteral(builder, *Literal::CreateR1<NativeT>(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values)); } template <typename NativeT> @@ -1735,13 +2168,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *Literal::CreateR1(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); } template <typename NativeT> XlaOp 
ConstantR2(XlaBuilder* builder, std::initializer_list<std::initializer_list<NativeT>> values) { - return ConstantLiteral(builder, *Literal::CreateR2<NativeT>(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values)); } template <typename NativeT> @@ -1749,12 +2182,14 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array<NativeT>& values, const Layout& layout) { return ConstantLiteral( - builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout)); + builder, + *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); } template <typename NativeT> XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) { - return ConstantLiteral(builder, *Literal::CreateFromArray<NativeT>(values)); + return ConstantLiteral(builder, + *LiteralUtil::CreateFromArray<NativeT>(values)); } template <typename NativeT> @@ -1762,14 +2197,15 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D<NativeT>& values, const Layout& layout) { return ConstantLiteral( - builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout)); + builder, + *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); } template <typename NativeT> XlaOp ConstantR2FromArray2D(XlaBuilder* builder, const Array2D<NativeT>& values) { return ConstantLiteral(builder, - *Literal::CreateR2FromArray2D<NativeT>(values)); + *LiteralUtil::CreateR2FromArray2D<NativeT>(values)); } template <typename NativeT> @@ -1778,7 +2214,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); + *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); } template <typename NativeT> |