diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_builder.h')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.h | 97 |
1 files changed, 46 insertions, 51 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 59fbc664f2..58e8f4e7fa 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -496,20 +496,19 @@ class XlaBuilder { // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. - XlaOp DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). @@ -518,7 +517,7 @@ class XlaBuilder { absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -527,29 +526,27 @@ class XlaBuilder { absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. - XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. - XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, - absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, + absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. @@ -1150,32 +1147,30 @@ class XlaBuilder { friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> broadcast_dimensions); friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_number, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, @@ -1183,8 +1178,7 @@ class XlaBuilder { absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span<const int64> fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1629,27 +1623,27 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -1657,7 +1651,7 @@ XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. @@ -1666,17 +1660,18 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, + absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. |