aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/xla_builder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_builder.cc')
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc82
1 files changed, 37 insertions, 45 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 7f2125f74c..887b970661 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -820,7 +820,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -828,14 +828,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
dimension_numbers.add_lhs_contracting_dimensions(
lhs_shape.dimensions_size() == 1 ? 0 : 1);
dimension_numbers.add_rhs_contracting_dimensions(0);
- return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto);
+ return DotGeneral(lhs, rhs, dimension_numbers, precision_config);
});
}
-XlaOp XlaBuilder::DotGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto* precision_config_proto) {
+XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -844,8 +843,8 @@ XlaOp XlaBuilder::DotGeneral(
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
dimension_numbers));
*instr.mutable_dot_dimension_numbers() = dimension_numbers;
- if (precision_config_proto != nullptr) {
- *instr.mutable_precision_config() = *precision_config_proto;
+ if (precision_config != nullptr) {
+ *instr.mutable_precision_config() = *precision_config;
}
return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
});
@@ -899,28 +898,26 @@ Status XlaBuilder::VerifyConvolution(
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
- feature_group_count, precision_config_proto);
+ feature_group_count, precision_config);
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ConvGeneral(lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
- feature_group_count, precision_config_proto);
+ feature_group_count, precision_config);
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -948,7 +945,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
MakePadding(base_area_dimensions, window_dimensions,
window_strides, padding),
dimension_numbers, feature_group_count,
- precision_config_proto);
+ precision_config);
});
}
@@ -956,11 +953,10 @@ XlaOp XlaBuilder::ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
dimension_numbers, feature_group_count,
- precision_config_proto);
+ precision_config);
}
XlaOp XlaBuilder::ConvGeneralDilated(
@@ -968,8 +964,7 @@ XlaOp XlaBuilder::ConvGeneralDilated(
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -996,8 +991,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
instr.set_feature_group_count(feature_group_count);
- if (precision_config_proto != nullptr) {
- *instr.mutable_precision_config() = *precision_config_proto;
+ if (precision_config != nullptr) {
+ *instr.mutable_precision_config() = *precision_config;
}
return AddInstruction(std::move(instr), HloOpcode::kConvolution,
@@ -2594,43 +2589,40 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto) {
- return lhs.builder()->Dot(lhs, rhs, precision_config_proto);
+ const PrecisionConfig* precision_config) {
+ return lhs.builder()->Dot(lhs, rhs, precision_config);
}
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
- precision_config_proto);
+ precision_config);
}
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
- feature_group_count, precision_config_proto);
+ feature_group_count, precision_config);
}
-XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
- return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
- padding, feature_group_count,
- precision_config_proto);
+XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count,
+ const PrecisionConfig* precision_config) {
+ return lhs.builder()->ConvWithGeneralPadding(
+ lhs, rhs, window_strides, padding, feature_group_count, precision_config);
}
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return lhs.builder()->ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
- precision_config_proto);
+ precision_config);
}
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
@@ -2638,10 +2630,10 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
dimension_numbers, feature_group_count,
- precision_config_proto);
+ precision_config);
}
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
@@ -2651,10 +2643,10 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return lhs.builder()->ConvGeneralDilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
- dimension_numbers, feature_group_count, precision_config_proto);
+ dimension_numbers, feature_group_count, precision_config);
}
XlaOp Fft(const XlaOp& operand, FftType fft_type,