diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_builder.cc')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 82 |
1 files changed, 37 insertions, 45 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7f2125f74c..887b970661 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -820,7 +820,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -828,14 +828,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, dimension_numbers.add_lhs_contracting_dimensions( lhs_shape.dimensions_size() == 1 ? 0 : 1); dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto); + return DotGeneral(lhs, rhs, dimension_numbers, precision_config); }); } -XlaOp XlaBuilder::DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -844,8 +843,8 @@ XlaOp XlaBuilder::DotGeneral( ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); *instr.mutable_dot_dimension_numbers() = dimension_numbers; - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); }); @@ -899,28 +898,26 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -948,7 +945,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); }); } @@ -956,11 +953,10 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -968,8 +964,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( absl::Span<const std::pair<int64, int64>> padding, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -996,8 +991,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kConvolution, @@ -2594,43 +2589,40 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->Dot(lhs, rhs, precision_config_proto); + const PrecisionConfig* precision_config) { + return lhs.builder()->Dot(lhs, rhs, precision_config); } XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, - precision_config_proto); + precision_config); } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding, feature_group_count, - precision_config_proto); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + int64 feature_group_count, + const PrecisionConfig* precision_config) { + return lhs.builder()->ConvWithGeneralPadding( + lhs, rhs, window_strides, padding, feature_group_count, precision_config); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, @@ -2638,10 +2630,10 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -2651,10 +2643,10 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count, precision_config_proto); + dimension_numbers, feature_group_count, precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, |