diff options
13 files changed, 277 insertions, 59 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index c6976fd849..7bc6e8d860 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -808,7 +809,8 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); } -XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { +XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -816,12 +818,14 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { dimension_numbers.add_lhs_contracting_dimensions( lhs_shape.dimensions_size() == 1 ? 
0 : 1); dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers); + return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto); }); } -XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers) { +XlaOp XlaBuilder::DotGeneral( + const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -830,6 +834,9 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); *instr.mutable_dot_dimension_numbers() = dimension_numbers; + if (precision_config_proto != nullptr) { + *instr.mutable_precision_config() = *precision_config_proto; + } return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); }); } @@ -883,28 +890,31 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - Padding padding, int64 feature_group_count) { + Padding padding, int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count); + feature_group_count, precision_config_proto); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - int64 feature_group_count) { + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count); + 
feature_group_count, precision_config_proto); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -931,7 +941,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( return ConvGeneral(lhs, rhs, window_strides, MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), - dimension_numbers, feature_group_count); + dimension_numbers, feature_group_count, + precision_config_proto); }); } @@ -940,9 +951,11 @@ XlaOp XlaBuilder::ConvGeneral( tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers, feature_group_count); + dimension_numbers, feature_group_count, + precision_config_proto); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -952,7 +965,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -979,6 +993,10 @@ XlaOp XlaBuilder::ConvGeneralDilated( *instr.mutable_convolution_dimension_numbers() = dimension_numbers; 
instr.set_feature_group_count(feature_group_count); + if (precision_config_proto != nullptr) { + *instr.mutable_precision_config() = *precision_config_proto; + } + return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs}); }); @@ -2548,48 +2566,57 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); } -XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) { - return lhs.builder()->Dot(lhs, rhs); +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto) { + return lhs.builder()->Dot(lhs, rhs, precision_config_proto); } XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers) { - return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto) { + return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, + precision_config_proto); } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, - int64 feature_group_count) { + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count); + feature_group_count, precision_config_proto); } XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - int64 feature_group_count) { + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding, feature_group_count); + padding, feature_group_count, + precision_config_proto); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& 
dimension_numbers, - int64 feature_group_count) { - return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides, - padding, dimension_numbers, - feature_group_count); + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { + return lhs.builder()->ConvWithGeneralDimensions( + lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, + precision_config_proto); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, - dimension_numbers, feature_group_count); + dimension_numbers, feature_group_count, + precision_config_proto); } XlaOp ConvGeneralDilated( @@ -2599,10 +2626,11 @@ XlaOp ConvGeneralDilated( tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count); + dimension_numbers, feature_group_count, precision_config_proto); } XlaOp Fft(const XlaOp& operand, FftType fft_type, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 089967147f..8d9ec9a18a 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -501,17 +501,21 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. 
- XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a general dot instruction onto the computation. - XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + XlaOp DotGeneral( + const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). @@ -519,7 +523,8 @@ class XlaBuilder { const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -527,7 +532,8 @@ class XlaBuilder { const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. 
@@ -536,7 +542,8 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. @@ -547,7 +554,8 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. @@ -1146,28 +1154,34 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); - friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - Padding padding, int64 feature_group_count); + Padding padding, int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - int64 feature_group_count); + int64 
feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count); + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count); + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, @@ -1175,7 +1189,8 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count); + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1626,17 +1641,20 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. -XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a general dot instruction onto the computation. 
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). @@ -1644,7 +1662,8 @@ XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -1652,7 +1671,8 @@ XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. 
@@ -1660,7 +1680,8 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. @@ -1671,7 +1692,8 @@ XlaOp ConvGeneralDilated( tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 0a040b5d16..b86b7d2e71 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -268,7 +268,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot); StatusOr<HloInstruction*> OptimizeDotOfConcatHelper( - const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot); @@ -829,18 +829,18 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat( TF_ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, - OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, + 
OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs, rhs_contracting_dim, /*swapped=*/false)); if (optimized_lhs_concat) { return optimized_lhs_concat; } - return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, + return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs, lhs_contracting_dim, /*swapped=*/true); } StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( - const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && lhs->concatenate_dimension() == lhs_contracting_dim && @@ -939,11 +939,12 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( } auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); + dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums)); + new_dot->set_precision_config(dot.precision_config()); if (add_result) { add_result = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_shape, HloOpcode::kAdd, add_result, new_dot)); + dot.shape(), HloOpcode::kAdd, add_result, new_dot)); } else { add_result = new_dot; } @@ -1042,6 +1043,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather( auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( memoized_shape, left_operand, right_operand, dnums)); + memoized_inst->set_precision_config(dot->precision_config()); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? 
lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1139,6 +1141,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers)); + new_dot->set_precision_config(dot->precision_config()); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -2297,6 +2300,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); + dot->set_precision_config(convolution->precision_config()); + return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index b226e7ecb0..be6fbcc9e3 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -64,6 +64,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + new_dot->set_precision_config(batch_dot->precision_config()); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, MakeReshapeHlo(batch_dot->shape(), new_dot)); diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 8affa08b65..9c81a86bbb 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -224,6 +224,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { auto new_convolution = HloInstruction::CreateConvolve( 
convolution->shape(), convolution->mutable_operand(0), new_filter, convolution->window(), dim_numbers, /*feature_group_count=*/1); + new_convolution->set_precision_config(convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 0985b9297f..098ce17a56 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -132,6 +132,7 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) { HloInstruction* new_conv = module->entry_computation()->AddInstruction( HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, hlo->window(), new_dnums)); + new_conv->set_precision_config(hlo->precision_config()); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 12faed6967..09cb10d6ee 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -136,6 +136,7 @@ Status DecomposeBatchDot(HloInstruction* dot) { dot_dnums.add_rhs_contracting_dimensions(0); auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + dot_r2->set_precision_config(dot->precision_config()); // Reshape Dot to R3 so we can concat along batch dimension. 
auto dot_r3 = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index fa218657fe..12b609a60f 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 51 +// Next ID: 52 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -171,6 +171,9 @@ message HloInstructionProto { bool is_host_transfer = 47; xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; + + // Precision configuration for the instruction. Has backend-specific meaning. + xla.PrecisionConfigProto precision_config = 51; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8a9856c1da..9d795da100 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -444,6 +444,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + instruction->precision_config_ = proto.precision_config(); if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = @@ -1019,6 +1020,7 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); + derived_instruction->set_precision_config(precision_config_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1279,6 +1281,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( } break; } + // SetupDerivedInstruction will setup the precision_config_ field. 
SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); clone->set_raw_backend_config_string(backend_config_); @@ -2000,6 +2003,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString( extra.push_back(DotDimensionNumbersToString()); } + string precision_config_string = PrecisionConfigToString(); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2121,6 +2129,7 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); + *proto.mutable_precision_config() = precision_config_; if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); @@ -2819,6 +2828,11 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); } +string PrecisionToString(const PrecisionConfigProto::Precision& precision) { + return tensorflow::str_util::Lowercase( + PrecisionConfigProto::Precision_Name(precision)); +} + string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums) { // lhs_dims[i] is the symbol of the logical dimension i for the lhs @@ -2889,6 +2903,44 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { return found->second; } +string HloInstruction::PrecisionConfigToString() const { + if (precision_config_.operand_precision().empty()) { + return ""; + } + return StrCat( + "operand_precision={", + Join(precision_config_.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfigProto::Precision_IsValid(precision)) + << precision; + StrAppend( + out, + PrecisionToString( + static_cast<PrecisionConfigProto::Precision>(precision))); + 
}), + "}"); +} + +StatusOr<PrecisionConfigProto::Precision> StringToPrecision( + const string& name) { + static std::unordered_map<string, PrecisionConfigProto::Precision>* map = [] { + static auto* map = + new std::unordered_map<string, PrecisionConfigProto::Precision>; + for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) { + if (PrecisionConfigProto::Precision_IsValid(i)) { + auto value = static_cast<PrecisionConfigProto::Precision>(i); + (*map)[PrecisionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown precision"); + } + return found->second; +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 69397a4b37..21710bd31d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -102,6 +102,7 @@ class HloPrintOptions { return HloPrintOptions() .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) .set_print_metadata(false) + .set_print_backend_config(false) .set_compact_operands(true) .set_print_operand_shape(true) .set_print_program_shape(false) @@ -183,7 +184,7 @@ class HloPrintOptions { return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } - bool print_backend_config() const { return print_metadata_; } + bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } @@ -858,6 +859,11 @@ class HloInstruction { return false; } + if (!ContainersEqual(precision_config_.operand_precision(), + 
other.precision_config_.operand_precision())) { + return false; + } + return IdenticalSlowPath(other, eq_computations); } @@ -1105,6 +1111,9 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; + // Returns the dump string of the precision configuration. + string PrecisionConfigToString() const; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1248,6 +1257,20 @@ class HloInstruction { static StatusOr<string> BackendConfigToRawString( const tensorflow::protobuf::Message& proto); + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfigProto& precision_config() const { + return precision_config_; + } + void set_precision_config(const PrecisionConfigProto& precision_config) { + precision_config_ = precision_config; + } + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1653,6 +1676,10 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfigProto precision_config_; + // String identifier for instruction. 
string name_; @@ -1675,10 +1702,12 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string PrecisionToString(const PrecisionConfigProto::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); StatusOr<RandomDistribution> StringToRandomDistribution(const string& name); +StatusOr<PrecisionConfigProto::Precision> StringToPrecision(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index b4793998ec..ede55510d3 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -155,6 +155,7 @@ class HloParser { kFusionKind, kDistribution, kDomain, + kPrecisionList, }; struct AttrConfig { @@ -220,6 +221,7 @@ class HloParser { bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad); bool ParseSliceRanges(SliceRanges* result); + bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector<tensorflow::int64>* result); @@ -238,6 +240,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); + bool ParsePrecision(PrecisionConfigProto::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -502,6 +505,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; + optional<std::vector<PrecisionConfigProto::Precision>> 
operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -1366,6 +1373,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } + if (operand_precision) { + PrecisionConfigProto precision_config; + *precision_config.mutable_operand_precision() = {operand_precision->begin(), + operand_precision->end()}; + instruction->set_precision_config(precision_config); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -2343,6 +2356,16 @@ bool HloParser::ParseAttributeHelper( case AttrTy::kDomain: { return ParseDomain(static_cast<DomainData*>(attr_out_ptr)); } + case AttrTy::kPrecisionList: { + std::vector<PrecisionConfigProto::Precision> result; + if (!ParsePrecisionList(&result)) { + return false; + } + static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { @@ -2615,6 +2638,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); } +// precisionlist ::= start precision_elements end +// precision_elements +// ::= /*empty*/ +// ::= precision_val (delim precision_val)* +bool HloParser::ParsePrecisionList( + std::vector<PrecisionConfigProto::Precision>* result) { + auto parse_and_add_item = [&]() { + PrecisionConfigProto::Precision item; + if (!ParsePrecision(&item)) { + return false; + } + result->push_back(item); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2941,6 +2982,23 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } +bool 
HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { + VLOG(1) << "ParsePrecision"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects precision"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToPrecision(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects precision but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + bool HloParser::ParseInt64(tensorflow::int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 49e1f87319..530f40e4b2 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -109,6 +109,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot( dot->shape(), new_lhs, new_rhs, new_dim_numbers); + new_dot->set_precision_config(dot->precision_config()); return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } @@ -178,6 +179,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); + new_conv->set_precision_config(convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 27aa94c2cb..9451e0c315 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -569,3 +569,18 @@ message ReplicaGroup { // ids matters in some op (e.g., all-to-all).
repeated int64 replica_ids = 1; } + +// Used to indicate the precision configuration. It has backend specific +// meaning. +message PrecisionConfigProto { + enum Precision { + DEFAULT = 0; + HIGH = 1; + HIGHEST = 2; + + // Next: 3 + } + repeated Precision operand_precision = 1; + + // Next: 2 +} |