aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/xla_builder.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_builder.h')
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h97
1 files changed, 46 insertions, 51 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 59fbc664f2..58e8f4e7fa 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -496,20 +496,19 @@ class XlaBuilder {
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a general dot instruction onto the computation.
- XlaOp DotGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
@@ -518,7 +517,7 @@ class XlaBuilder {
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
@@ -527,29 +526,27 @@ class XlaBuilder {
absl::Span<const int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
- XlaOp ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
- XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- absl::Span<const int64> lhs_dilation,
- absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
@@ -1150,32 +1147,30 @@ class XlaBuilder {
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto);
+ const PrecisionConfig* precision_config);
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_number,
- const PrecisionConfigProto* precision_config_proto);
+ const PrecisionConfig* precision_config);
friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ const PrecisionConfig* precision_config);
friend XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ int64 feature_group_count, const PrecisionConfig* precision_config);
friend XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ int64 feature_group_count, const PrecisionConfig* precision_config);
friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ const PrecisionConfig* precision_config);
friend XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
@@ -1183,8 +1178,7 @@ class XlaBuilder {
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ int64 feature_group_count, const PrecisionConfig* precision_config);
friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
absl::Span<const int64> fft_length);
friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
@@ -1629,27 +1623,27 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
-XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
@@ -1657,7 +1651,7 @@ XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
@@ -1666,17 +1660,18 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
-XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.