aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/xla_client/xla_builder.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_client/xla_builder.h')
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h740
1 files changed, 588 insertions, 152 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index fe31774b86..980e84e40c 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -22,7 +22,8 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -53,7 +54,16 @@ class XlaOp {
}
~XlaOp() = default;
- XlaBuilder* builder() const { return builder_; }
+ // Precondition: !IsUninitialized().
+ //
+ // It's very common to do foo.builder()->bar(). Without this precondition, if
+ // foo.builder() is null, the call to bar will segfault at some point possibly
+ // deep in the callstack when we finally dereference `this`. The precondition
+ // lets us avoid this tricky-to-debug problem.
+ XlaBuilder* builder() const {
+ CHECK(builder_ != nullptr);
+ return builder_;
+ }
// Returns true if the XlaOp represents valid, non-erroneous value.
bool valid() const { return handle_ >= 0; }
@@ -158,6 +168,93 @@ class XlaBuilder {
die_immediately_on_error_ = enabled;
}
+ // Default dimension numbers used for a 2D convolution.
+ static constexpr int64 kConvBatchDimension = 0;
+ static constexpr int64 kConvFeatureDimension = 1;
+ static constexpr int64 kConvFirstSpatialDimension = 2;
+ static constexpr int64 kConvSecondSpatialDimension = 3;
+ static constexpr int64 kConvKernelOutputDimension = 0;
+ static constexpr int64 kConvKernelInputDimension = 1;
+ static constexpr int64 kConvKernelFirstSpatialDimension = 2;
+ static constexpr int64 kConvKernelSecondSpatialDimension = 3;
+
+ // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
+ // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
+ // the kernel operand
+ // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
+ static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
+ int num_spatial_dims = 2);
+
+ // Returns an error if the convolution dimension numbers have conflicts.
+ static Status Validate(const ConvolutionDimensionNumbers& dnum);
+
+ // Returns a new XlaBuilder whose resultant Computation is used only by this
+ // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
+ // behavior as the parent.
+ std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
+
+ // Builds the computation with the requested operations, or returns a non-ok
+ // status. Note that all ops that have been enqueued will be moved to the
+ // computation being returned.
+ StatusOr<XlaComputation> Build();
+
+ // Builds the computation with the requested operations, or notes an error in
+ // the parent XlaBuilder and returns an empty computation if building failed.
+ // This function is intended to be used where the returned XlaComputation is
+ // only used by the parent XlaBuilder and hence further operation on the
+ // returned XlaComputation will simply be error'ed out if an error occurred
+ // while building this computation. If the built computation is to be used by
+ // a XlaBuilder other than the parent XlaBuilder then Build() should be used
+ // instead.
+ XlaComputation BuildAndNoteError();
+
+ // Returns a subgraph that roots on the given root. If the root is not a
+ // compile-time constant (see `IsConstant`), returns an error.
+ //
+ // This will copy the needed ops/computations to the subgraph.
+ StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
+
+ // Returns the first error that was encountered while building the
+ // computation. When an error is encountered, by default we return a vacuous
+ // XlaOp and inform the user of the error that occurred while
+ // building the computation when they make a final call to Build().
+ //
+ // See also set_die_immediately_on_error().
+ Status first_error() const { return first_error_; }
+
+ // Returns the shape of the given op.
+ StatusOr<Shape> GetShape(const XlaOp& op) const;
+
+ // Returns the (inferred) result for the current computation's shape.
+ StatusOr<ProgramShape> GetProgramShape() const;
+
+ // Reports an error to the builder, by
+ // * storing it internally and capturing a backtrace if it's the first error
+ // (this deferred value will be produced on the call to
+ // Build()/GetShape()/...)
+ // * dying if die_immediately_on_error_ is true.
+ // Returns an XlaOp with an invalid handle but a valid builder. This value can
+ // be returned in place of a value in APIs that return an XlaOp.
+ XlaOp ReportError(const Status& error);
+
+ // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
+ // If the Status was an error, reports the error to builder and returns an
+ // invalid XlaOp handle.
+ XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
+
+ // A helper function that runs a function that returns a StatusOr<XlaOp> and
+ // returns an XlaOp.
+ XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
+
+ // Returns true if 'operand' is a compile-time constant. A compile-time
+ // constant does not depend on any parameters, or on stateful operators such
+ // as `RngNormal` or `Infeed`.
+ //
+ // This tests whether a computation is a compile-time constant without
+ // evaluating the computation.
+ StatusOr<bool> IsConstant(const XlaOp& operand) const;
+
+ private:
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(int64 parameter_number, const Shape& shape,
@@ -230,6 +327,27 @@ class XlaBuilder {
XlaOp Broadcast(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ // Performs in-dimension-style broadcast.
+ //
+ // Operand specifies the input to be broadcast. "shape" is expected output
+ // shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
+ // Dimension numbers in broadcast_dimensions map to individual dimensions
+ // of the operand, and specify what dimension of the output shape they
+ // should be broadcast.
+ // e.g.
+ // Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
+ // and dimension of shape is [2,2].
+ // Specifying {1} as brodcast_dimension will generate output
+ // [1 , 2]
+ // [1 , 2]
+ // On the other hand, specifying {0} as broadcast_dimension
+ // will generate output
+ // [1 , 1]
+ // [2 , 2]
+ XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
@@ -378,26 +496,6 @@ class XlaBuilder {
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers);
- // Default dimension numbers used for a 2D convolution.
- static constexpr int64 kConvBatchDimension = 0;
- static constexpr int64 kConvFeatureDimension = 1;
- static constexpr int64 kConvFirstSpatialDimension = 2;
- static constexpr int64 kConvSecondSpatialDimension = 3;
- static constexpr int64 kConvKernelOutputDimension = 0;
- static constexpr int64 kConvKernelInputDimension = 1;
- static constexpr int64 kConvKernelFirstSpatialDimension = 2;
- static constexpr int64 kConvKernelSecondSpatialDimension = 3;
-
- // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
- // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
- // the kernel operand
- // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
- static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
- int num_spatial_dims = 2);
-
- // Returns an error if the convolution dimension numbers have conflicts.
- static Status Validate(const ConvolutionDimensionNumbers& dnum);
-
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
@@ -444,6 +542,8 @@ class XlaBuilder {
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(const Shape& shape, const string& config = "");
+ XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config = "");
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
@@ -453,6 +553,9 @@ class XlaBuilder {
// will occur.
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
+ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
// Enqueues a call instruction onto the computation.
XlaOp Call(const XlaComputation& computation,
@@ -663,16 +766,6 @@ class XlaBuilder {
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
- // Enqueues a float32 sqrt instruction onto the computation.
- // (float32 is specified as there is an implicit float32 0.5f constant
- // exponent).
- XlaOp SqrtF32(const XlaOp& operand);
-
- // Enqueues a float32 square instruction onto the computation.
- // (float32 is specified as there is an implicit float32 2.0f constant
- // exponent).
- XlaOp SquareF32(const XlaOp& operand);
-
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
@@ -695,14 +788,6 @@ class XlaBuilder {
XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
- // Enqueues a float32 reciprocal instruction onto the computation.
- // (float32 is specified as there is an implicit float32 -1.0f constant
- // exponent).
- //
- // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
- // shape of the operand.
- XlaOp ReciprocalF32(const XlaOp& operand);
-
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
@@ -717,7 +802,24 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
- XlaOp Sort(const XlaOp& operand);
+ // If only keys are provided:
+ // * If the keys are an rank-1 tensor (an array), the result is a sorted array
+ // of keys, in ascending order.
+ // * If the keys have higher rank, the keys are sorted along the provided
+ // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
+ // value of 0 will indepenently sort every column, and a dimension value of 1
+ // will independently sort each row. If no dimension number is provided, then
+ // the last dimension is chosen by default.
+ //
+ // If both keys and values are provided:
+ // * The keys and the values must tensors with the same dimensions. The
+ // element types of the tensors may be different.
+ // * The result is a tuple that consists of a sorted tensor of keys (along the
+ // provided dimension, as above) as the first element, and a tensor with their
+ // corresponding values as the second element.
+ XlaOp Sort(XlaOp keys,
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ int64 dimension = -1);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
@@ -755,22 +857,35 @@ class XlaBuilder {
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
- // Enqueues a Send node onto the computation, to send the given operand to
- // a Recv instruction that shares the same channel handle.
+ // Enqueues a Send node onto the computation for device-to-device
+ // communication, to send the given operand to a Recv instruction that shares
+ // the same channel handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
+ XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+
+ // Enqueues a Send node which sends data to the host.
+ XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout, const ChannelHandle& handle);
+
+ // Enqueues a Recv node which receives data from the host.
+ XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+ // Enqueues an AfterAll operation with no operands producing a token-shaped
+ // value.
+ XlaOp CreateToken();
+
+ // Enqueues an AfterAll operation with no operands producing a token-shaped
+ // value.
+ XlaOp AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens);
// Enqueues a Recv node onto the computation. The data comes from a Send
// instruction that shares the same channel handle and its shape must
// be the same as the given shape.
XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
-
- // Returns true if 'operand' is a compile-time constant. A compile-time
- // constant does not depend on any parameters, or on stateful operators such
- // as `RngNormal` or `Infeed`.
- //
- // This tests whether a computation is a compile-time constant without
- // evaluating the computation.
- StatusOr<bool> IsConstant(const XlaOp& operand) const;
+ XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
// Normalizes operand across spatial and batch dimensions for each feature.
//
@@ -810,65 +925,6 @@ class XlaBuilder {
const XlaOp& grad_output, float epsilon,
int64 feature_index);
- // Returns a new XlaBuilder whose resultant Computation is used only by this
- // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
- // behavior as the parent.
- std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
-
- // Builds the computation with the requested operations, or returns a non-ok
- // status. Note that all ops that have been enqueued will be moved to the
- // computation being returned.
- StatusOr<XlaComputation> Build();
-
- // Builds the computation with the requested operations, or notes an error in
- // the parent XlaBuilder and returns an empty computation if building failed.
- // This function is intended to be used where the returned XlaComputation is
- // only used by the parent XlaBuilder and hence further operation on the
- // returned XlaComputation will simply be error'ed out if an error occurred
- // while building this computation. If the built computation is to be used by
- // a XlaBuilder other than the parent XlaBuilder then Build() should be used
- // instead.
- XlaComputation BuildAndNoteError();
-
- // Returns a subgraph that roots on the given root. If the root is not a
- // compile-time constant (see `IsConstant`), returns an error.
- //
- // This will copy the needed ops/computations to the subgraph.
- StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
-
- // Returns the first error that was encountered while building the
- // computation. When an error is encountered, by default we return a vacuous
- // XlaOp and inform the user of the error that occurred while
- // building the computation when they make a final call to Build().
- //
- // See also set_die_immediately_on_error().
- Status first_error() const { return first_error_; }
-
- // Returns the shape of the given op.
- StatusOr<Shape> GetShape(const XlaOp& op) const;
-
- // Returns the (inferred) result for the current computation's shape.
- StatusOr<ProgramShape> GetProgramShape() const;
-
- // Reports an error to the builder, by
- // * storing it internally and capturing a backtrace if it's the first error
- // (this deferred value will be produced on the call to
- // Build()/GetShape()/...)
- // * dying if die_immediately_on_error_ is true.
- // Returns an XlaOp with an invalid handle but a valid builder. This value can
- // be returned in place of a value in APIs that return an XlaOp.
- XlaOp ReportError(const Status& error);
-
- // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
- // If the Status was an error, reports the error to builder and returns an
- // invalid XlaOp handle.
- XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
-
- // A helper function that runs a function that returns a StatusOr<XlaOp> and
- // returns an XlaOp.
- XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
-
- private:
StatusOr<XlaOp> AddInstruction(
HloInstructionProto&& instr, HloOpcode opcode,
tensorflow::gtl::ArraySlice<XlaOp> operands = {});
@@ -971,6 +1027,306 @@ class XlaBuilder {
bool die_immediately_on_error_ = false;
XlaBuilder* parent_builder_{nullptr};
+
+ friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
+ const Shape& shape, const string& name);
+ friend XlaOp ConstantLiteral(XlaBuilder* builder,
+ const LiteralSlice& literal);
+ template <typename NativeT>
+ friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
+ template <typename NativeT>
+ friend XlaOp ConstantR1(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> values);
+ friend XlaOp ConstantR1(XlaBuilder* builder,
+ const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR2(
+ XlaBuilder* builder,
+ std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
+ const Array<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantFromArray(XlaBuilder* builder,
+ const Array<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
+ const Array2D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
+ const Array2D<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
+ const Array3D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
+ const Array3D<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
+ const Array4D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
+ const Array4D<NativeT>& values);
+
+ template <typename NativeT>
+ friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
+
+ friend XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ friend XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config);
+
+ friend XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ friend XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ friend XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ friend XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
+ int64 limit_index, int64 stride, int64 dimno);
+
+ friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices);
+
+ friend XlaOp ConcatInDim(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ int64 dimension);
+
+ friend void Trace(const string& tag, const XlaOp& operand);
+
+ friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
+ const XlaOp& on_false);
+ friend XlaOp Tuple(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> elements);
+ friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
+ friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+ friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+ friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+ friend XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ friend XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp ConvGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+ friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
+ const string& config);
+ friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config);
+ friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+ friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape);
+ friend XlaOp HostCompute(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape);
+ friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Conj(const XlaOp& operand);
+ friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Not(const XlaOp& operand);
+ friend XlaOp ShiftLeft(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation);
+ friend XlaOp ReduceWindow(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ friend XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ friend XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids);
+ friend XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id);
+ friend XlaOp SelectAndScatter(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+ friend XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+ friend XlaOp Abs(const XlaOp& operand);
+ friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Exp(const XlaOp& operand);
+ friend XlaOp Expm1(const XlaOp& operand);
+ friend XlaOp Floor(const XlaOp& operand);
+ friend XlaOp Ceil(const XlaOp& operand);
+ friend XlaOp Round(const XlaOp& operand);
+ friend XlaOp Log(const XlaOp& operand);
+ friend XlaOp Log1p(const XlaOp& operand);
+ friend XlaOp Sign(const XlaOp& operand);
+ friend XlaOp Clz(const XlaOp& operand);
+ friend XlaOp Cos(const XlaOp& operand);
+ friend XlaOp Sin(const XlaOp& operand);
+ friend XlaOp Tanh(const XlaOp& operand);
+ friend XlaOp Real(const XlaOp& operand);
+ friend XlaOp Imag(const XlaOp& operand);
+ friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp IsFinite(const XlaOp& operand);
+ // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota
+ // in xla/client/lib/numeric.h with this (renamed to xla::Iota).
+ friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
+ friend XlaOp ConvertElementType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+ friend XlaOp BitcastConvertType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+ friend XlaOp Neg(const XlaOp& operand);
+ friend XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+ friend XlaOp Rev(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+ friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension);
+ friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
+ friend XlaOp Map(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands);
+ friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
+ const Shape& shape);
+ friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
+ friend XlaOp While(const XlaComputation& condition,
+ const XlaComputation& body, const XlaOp& init);
+ friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation);
+ friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits);
+ friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+ friend void Send(const XlaOp& operand, const ChannelHandle& handle);
+ friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index);
+ friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index);
+ friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index);
+ friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+ friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const ChannelHandle& handle);
+ friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config);
+ friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+ friend XlaOp CreateToken(XlaBuilder* builder);
+ friend XlaOp AfterAll(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> tokens);
};
// RAII-style object: sets the current sharding assignment in builder on
@@ -1087,6 +1443,27 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
XlaOp Broadcast(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+// Performs in-dimension-style broadcast.
+//
+// Operand specifies the input to be broadcast. "shape" is expected output
+// shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
+// Dimension numbers in broadcast_dimensions map to individual dimensions
+// of the operand, and specify what dimension of the output shape they
+// should be broadcast.
+// e.g.
+// Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
+// and dimension of shape is [2,2].
+// Specifying {1} as brodcast_dimension will generate output
+// [1 , 2]
+// [1 , 2]
+// On the other hand, specifying {0} as broadcast_dimension
+// will generate output
+// [1 , 1]
+// [2 , 2]
+XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
@@ -1281,6 +1658,13 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type,
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
const string& config = "");
+// Variant of Infeed which takes a token-shaped operand and produces a
+// two-element tuple containing the data value and a token-shaped value.
+// Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config = "");
+
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
@@ -1290,6 +1674,13 @@ XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
+// Variant of Outfeed which takes a token-shaped operand and produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+
// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands);
@@ -1498,16 +1889,6 @@ XlaOp Real(const XlaOp& operand);
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
-// Enqueues a float32 sqrt instruction onto the computation.
-// (float32 is specified as there is an implicit float32 0.5f constant
-// exponent).
-XlaOp SqrtF32(const XlaOp& operand);
-
-// Enqueues a float32 square instruction onto the computation.
-// (float32 is specified as there is an implicit float32 2.0f constant
-// exponent).
-XlaOp SquareF32(const XlaOp& operand);
-
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
@@ -1528,14 +1909,6 @@ XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
// identical.
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
-// Enqueues a float32 reciprocal instruction onto the computation.
-// (float32 is specified as there is an implicit float32 -1.0f constant
-// exponent).
-//
-// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
-// shape of the operand.
-XlaOp ReciprocalF32(const XlaOp& operand);
-
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
@@ -1549,7 +1922,24 @@ XlaOp Transpose(const XlaOp& operand,
XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
-XlaOp Sort(const XlaOp& operand);
+// If only keys are provided:
+// * If the keys are an rank-1 tensor (an array), the result is a sorted array
+// of keys, in ascending order.
+// * If the keys have higher rank, the keys are sorted along the provided
+// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
+// value of 0 will indepenently sort every column, and a dimension value of 1
+// will independently sort each row. If no dimension number is provided, then
+// the last dimension is chosen by default.
+//
+// If both keys and values are provided:
+// * The keys and the values must tensors with the same dimensions. The
+// element types of the tensors may be different.
+// * The result is a tuple that consists of a sorted tensor of keys (along the
+// provided dimension, as above) as the first element, and a tensor with their
+// corresponding values as the second element.
+XlaOp Sort(XlaOp keys,
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ int64 dimension = -1);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
@@ -1587,16 +1977,59 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
-// Enqueues a Send node onto the computation, to send the given operand to
-// a Recv instruction that shares the same channel handle.
+// Enqueues a Send node onto the computation for device-to-device
+// communication. This operation sends the given operand to
+// a Recv instruction in a different computation that shares the same channel
+// handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
-// Enqueues a Recv node onto the computation. The data comes from a Send
-// instruction that shares the same channel handle and its shape must
-// be the same as the given shape.
+// Variant of Send which takes a token-shaped operand and produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+
+// Enqueues a Recv node onto the computation for device-to-device
+// communication. The data comes from a Send instruction in a different
+// computation that shares the same channel handle and its shape must be the
+// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle);
+// Variant of Recv which takes a token-shaped operand and produces a two-element
+// tuple containing the data value and a token-shaped value. Tokens are used
+// for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+// Enqueues a Send node which transfers data from the device to the host. The
+// 'shape_with_layout' argument defines the layout of the data transferred; its
+// shape must be compatible with the shape of the operand. The operand must be
+// array-shaped.
+// TODO(b/111544877): Support tuple shapes.
+XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout, const ChannelHandle& handle);
+
+// Enqueues a Recv node which transfers data from the host to the device. The
+// given shape must contain a layout and must be an array.
+// TODO(b/111544877): Support tuple shapes.
+XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+// Enqueues an operation (AfterAll) with no operands that produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// This is a separate method from AfterAll to facility the removal of
+// operand-less AfterAll instructions.
+// TODO(b/110532604): Remove this function when all tokens are derived from a
+// single token generated or passed into the entry computation.
+XlaOp CreateToken(XlaBuilder* builder);
+
+// Enqueues an AfterAll instruction which produces a token-shaped value and
+// takes a variadic number of token-shaped operands. The number of operands must
+// be greater than zero. Used for joining tokens.
+XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens);
+
// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
@@ -1639,12 +2072,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
+ return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -1656,44 +2089,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*Literal::CreateR1(values));
+ return ConstantLiteral(*LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@@ -1716,13 +2149,13 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *Literal::CreateR0<NativeT>(value));
+ return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(builder, *Literal::CreateR1<NativeT>(values));
+ return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -1735,13 +2168,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *Literal::CreateR1(values));
+ return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *Literal::CreateR2<NativeT>(values));
+ return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
@@ -1749,12 +2182,14 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder,
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
- return ConstantLiteral(builder, *Literal::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(builder,
+ *LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
@@ -1762,14 +2197,15 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder,
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
- *Literal::CreateR2FromArray2D<NativeT>(values));
+ *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
@@ -1778,7 +2214,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Layout& layout) {
return ConstantLiteral(
builder,
- *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>