diff options
Diffstat (limited to 'tensorflow')
39 files changed, 218 insertions, 238 deletions
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index b2f026df6c..3f928a1bea 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -97,9 +97,9 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) - PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, - xla_data_pb2.PrecisionConfigProto.HIGH, - xla_data_pb2.PrecisionConfigProto.HIGHEST) + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfig.DEFAULT, + xla_data_pb2.PrecisionConfig.HIGH, + xla_data_pb2.PrecisionConfig.HIGHEST) @parameterized.parameters(*PRECISION_VALUES) def testConv(self, precision): @@ -120,7 +120,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) precision_config = None if precision: - precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config = xla_data_pb2.PrecisionConfig() precision_config.operand_precision.extend([precision, precision]) return xla.conv( lhs, @@ -151,7 +151,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dnums.rhs_batch_dimensions.append(0) precision_config = None if precision: - precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config = xla_data_pb2.PrecisionConfig() precision_config.operand_precision.extend([precision, precision]) return xla.dot_general( lhs, diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 8848623868..fecc7c556e 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -84,7 +84,7 @@ class XlaConvOp : public XlaOpKernel { private: xla::ConvolutionDimensionNumbers dnums_; - xla::PrecisionConfigProto precision_config_; + xla::PrecisionConfig precision_config_; TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 2fed53e5c0..40b15b5579 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -54,7 +54,7 @@ class XlaDotOp : public XlaOpKernel { private: xla::DotDimensionNumbers dnums_; - xla::PrecisionConfigProto precision_config_; + xla::PrecisionConfig precision_config_; TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); }; diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index d8c050d09e..64f2d781a6 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -28,7 +28,7 @@ namespace tensorflow { xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); @@ -96,7 +96,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, y = xla::Conj(y); } - xla::PrecisionConfigProto precision_proto; + xla::PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 6cfccd5553..6edd63a4d3 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -43,11 +43,11 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::DEFAULT); +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index c50a8de33e..ab3d0a5668 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -50,7 +50,7 @@ namespace { // l[..., j, j] // return l xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -150,7 +150,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 60cd7ded53..9a561c34b9 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -30,9 +30,9 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); +xla::XlaOp Cholesky( + xla::XlaOp a, int64 block_size = 256, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index 0a140fa93c..6b3f2b6e06 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -150,7 +150,7 @@ struct QRBlockResult { xla::XlaOp vs; // Shape: [..., m, n] }; xla::StatusOr<QRBlockResult> QRBlock( - xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) { + xla::XlaOp a, xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -257,7 +257,7 @@ xla::StatusOr<QRBlockResult> QRBlock( xla::StatusOr<xla::XlaOp> ComputeWYRepresentation( xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs, xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { std::vector<int64> batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; @@ -332,7 +332,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation( // rather than WY transformations. xla::StatusOr<QRDecompositionResult> QRDecomposition( xla::XlaOp a, bool full_matrices, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index 8a389fb7b0..24b537ac8b 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -35,8 +35,7 @@ struct QRDecompositionResult { xla::StatusOr<QRDecompositionResult> QRDecomposition( xla::XlaOp a, bool full_matrices, int64 block_size = 128, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 37b2240b45..6524c2a9b1 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -110,9 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks( - xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfigProto::Precision precision) { +xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, + bool transpose_a, bool conjugate_a, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = diag_blocks.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { // Input is a batch of square lower triangular square matrices. Its shape is @@ -216,7 +216,7 @@ xla::XlaOp InvertDiagonalBlocks( dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfigProto precision_proto; + xla::PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); @@ -245,7 +245,7 @@ xla::XlaOp InvertDiagonalBlocks( xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, @@ -346,7 +346,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index ac42a48352..2303234f36 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -57,11 +57,10 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size = 128, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); +xla::XlaOp TriangularSolve( + xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size = 128, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 2cd9ae799f..68cfdc1785 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -83,7 +83,7 @@ lhs_dilation: dilation to apply between input elements rhs_dilation: dilation to apply between kernel elements feature_group_count: number of feature groups for grouped convolution. dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. -precision_config: a serialized xla::PrecisionConfigProto proto. +precision_config: a serialized xla::PrecisionConfig proto. )doc"); REGISTER_OP("XlaDot") @@ -102,7 +102,7 @@ Wraps the XLA ConvGeneralDilated operator, documented at lhs: the LHS tensor rhs: the RHS tensor dimension_numbers: a serialized xla::DotDimensionNumbers proto. -precision_config: a serialized xla::PrecisionConfigProto proto. +precision_config: a serialized xla::PrecisionConfig proto. )doc"); REGISTER_OP("XlaDynamicUpdateSlice") diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7f2125f74c..887b970661 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -820,7 +820,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -828,14 +828,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, dimension_numbers.add_lhs_contracting_dimensions( lhs_shape.dimensions_size() == 1 ? 0 : 1); dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto); + return DotGeneral(lhs, rhs, dimension_numbers, precision_config); }); } -XlaOp XlaBuilder::DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -844,8 +843,8 @@ XlaOp XlaBuilder::DotGeneral( ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); *instr.mutable_dot_dimension_numbers() = dimension_numbers; - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); }); @@ -899,28 +898,26 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -948,7 +945,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); }); } @@ -956,11 +953,10 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -968,8 +964,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( absl::Span<const std::pair<int64, int64>> padding, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -996,8 +991,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kConvolution, @@ -2594,43 +2589,40 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->Dot(lhs, rhs, precision_config_proto); + const PrecisionConfig* precision_config) { + return lhs.builder()->Dot(lhs, rhs, precision_config); } XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, - precision_config_proto); + precision_config); } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding, feature_group_count, - precision_config_proto); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + int64 feature_group_count, + const PrecisionConfig* precision_config) { + return lhs.builder()->ConvWithGeneralPadding( + lhs, rhs, window_strides, padding, feature_group_count, precision_config); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, @@ -2638,10 +2630,10 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -2651,10 +2643,10 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count, precision_config_proto); + dimension_numbers, feature_group_count, precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 59fbc664f2..58e8f4e7fa 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -496,20 +496,19 @@ class XlaBuilder { // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. - XlaOp DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). @@ -518,7 +517,7 @@ class XlaBuilder { absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -527,29 +526,27 @@ class XlaBuilder { absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. - XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. - XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, - absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, + absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. @@ -1150,32 +1147,30 @@ class XlaBuilder { friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> broadcast_dimensions); friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_number, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, @@ -1183,8 +1178,7 @@ class XlaBuilder { absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span<const int64> fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1629,27 +1623,27 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -1657,7 +1651,7 @@ XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. @@ -1666,17 +1660,18 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, + absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 8a05d1b0d7..9f1afa2671 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -574,9 +574,9 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - /*new_size=*/2, PrecisionConfigProto::DEFAULT); + /*new_size=*/2, PrecisionConfig::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, precision_config)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 0db74bd038..aa40fba9bb 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2379,9 +2379,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) { // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place // after the transformation. - PrecisionConfigProto precision_config; - precision_config.add_operand_precision(PrecisionConfigProto::HIGH); - precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST); + PrecisionConfig precision_config; + precision_config.add_operand_precision(PrecisionConfig::HIGH); + precision_config.add_operand_precision(PrecisionConfig::HIGHEST); orig_conv->set_precision_config(precision_config); auto module = CreateNewModule(); @@ -2401,9 +2401,8 @@ TEST_P(ConvFilterPaddingTest, DoIt) { conv->operand(1)->shape().dimensions(2), conv->operand(1)->shape().dimensions(3), testcase.expected_conv_window)); - EXPECT_THAT( - conv->precision_config().operand_precision(), - ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST)); + EXPECT_THAT(conv->precision_config().operand_precision(), + ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST)); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index d480d72297..933cf873e0 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -308,9 +308,9 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 7398f105a0..56bd67fb55 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1490,9 +1490,9 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot( shape_2x4, param_a, param_b, dot_dnums, precision_config)); auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 6bd0a2dd90..0fea462c85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -38,9 +38,9 @@ std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, precision_config); } diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 0a49d85c6d..ef70b68877 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -112,9 +112,9 @@ std::unique_ptr<HloModule> MakeBigGraph() { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - /*new_size=*/2, PrecisionConfigProto::DEFAULT); + /*new_size=*/2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction(HloInstruction::CreateDot( vshape, clamp, param_v0, dot_dnums, precision_config)); auto tuple = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 58b7af93eb..99d0cf50ca 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -172,7 +172,7 @@ message HloInstructionProto { xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. - xla.PrecisionConfigProto precision_config = 51; + xla.PrecisionConfig precision_config = 51; // Collective permute field. repeated SourceTarget source_target_pairs = 52; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index a2c1ce34c6..2aaaef1d36 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -601,9 +601,9 @@ TEST_F(HloComputationTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); @@ -636,9 +636,9 @@ TEST_F(HloComputationTest, StringificationIndent) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); @@ -672,9 +672,9 @@ TEST_F(HloComputationTest, StringificationCanonical) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index a6ae0337a5..a3fcc0fefa 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -63,7 +63,7 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand, StatusOr<HloInstruction*> MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config) { + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape convolve_shape, @@ -167,10 +167,9 @@ StatusOr<HloInstruction*> MakeConcatHlo( HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); } -StatusOr<HloInstruction*> MakeDotHlo( - HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config) { +StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 1c82956907..b22058abb4 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -50,7 +50,7 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand, StatusOr<HloInstruction*> MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config); + const PrecisionConfig& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. @@ -98,10 +98,9 @@ StatusOr<HloInstruction*> MakeConcatHlo( // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). -StatusOr<HloInstruction*> MakeDotHlo( - HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config); +StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config); // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 62eea2b06c..72b236801a 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2334,9 +2334,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index ffb3451164..d0d955fea8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -345,7 +345,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp( StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp( const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, const Literal& lhs, + const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr<HloInstruction> lhs_instr = HloInstruction::CreateConstant(lhs.CloneToUnique()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index e13af8e999..72252bafc7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -116,7 +116,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { StatusOr<std::unique_ptr<Literal>> EvaluateDotOp( const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, const Literal& lhs, + const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); protected: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f25761ac70..471a12d6aa 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -347,9 +347,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); - PrecisionConfigProto precision_config = proto.precision_config(); + PrecisionConfig precision_config = proto.precision_config(); precision_config.mutable_operand_precision()->Resize( - proto.operand_ids_size(), PrecisionConfigProto::DEFAULT); + proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( proto.shape(), operands(0), operands(1), std::max<int64>(proto.feature_group_count(), 1), proto.window(), @@ -475,7 +475,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( if (instruction->opcode() == HloOpcode::kDot) { instruction->precision_config_ = proto.precision_config(); instruction->precision_config_.mutable_operand_precision()->Resize( - instruction->operand_count(), PrecisionConfigProto::DEFAULT); + instruction->operand_count(), PrecisionConfig::DEFAULT); TF_RET_CHECK(proto.has_dot_dimension_numbers()); instruction->dot_dimension_numbers_ = absl::make_unique<DotDimensionNumbers>( @@ -657,7 +657,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config) { + const PrecisionConfig& precision_config) { return absl::make_unique<HloConvolutionInstruction>( shape, lhs, rhs, feature_group_count, window, dimension_numbers, precision_config); @@ -673,7 +673,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config) { + const PrecisionConfig& precision_config) { auto instruction = absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); @@ -2888,8 +2888,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); } -string PrecisionToString(const PrecisionConfigProto::Precision& precision) { - return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision)); +string PrecisionToString(const PrecisionConfig::Precision& precision) { + return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2967,32 +2967,31 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { string HloInstruction::PrecisionConfigToString() const { if (absl::c_all_of( precision_config_.operand_precision(), [](int32 precision) { - return static_cast<PrecisionConfigProto::Precision>(precision) == - PrecisionConfigProto::DEFAULT; + return static_cast<PrecisionConfig::Precision>(precision) == + PrecisionConfig::DEFAULT; })) { return ""; } return StrCat( "operand_precision={", - StrJoin(precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - CHECK(PrecisionConfigProto::Precision_IsValid(precision)) - << precision; - StrAppend(out, PrecisionToString( - static_cast<PrecisionConfigProto::Precision>( - precision))); - }), + StrJoin( + precision_config_.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast<PrecisionConfig::Precision>(precision))); + }), "}"); } -StatusOr<PrecisionConfigProto::Precision> StringToPrecision( - const string& name) { - static std::unordered_map<string, PrecisionConfigProto::Precision>* map = [] { +StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) { + static std::unordered_map<string, PrecisionConfig::Precision>* map = [] { static auto* map = - new std::unordered_map<string, PrecisionConfigProto::Precision>; - for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) { - if (PrecisionConfigProto::Precision_IsValid(i)) { - auto value = static_cast<PrecisionConfigProto::Precision>(i); + new std::unordered_map<string, PrecisionConfig::Precision>; + for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) { + if (PrecisionConfig::Precision_IsValid(i)) { + auto value = static_cast<PrecisionConfig::Precision>(i); (*map)[PrecisionToString(value)] = value; } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 55d592ff94..691f8155f9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -407,7 +407,7 @@ class HloInstruction { const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config); + const PrecisionConfig& precision_config); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr<HloInstruction> CreateFft( @@ -419,7 +419,7 @@ class HloInstruction { static std::unique_ptr<HloInstruction> CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config); + const PrecisionConfig& precision_config); // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS @@ -1262,10 +1262,8 @@ class HloInstruction { // information. Transformations to other HLOs will not preserve this // information but it is presumed that the alternate lowering is strictly // superior. - const PrecisionConfigProto& precision_config() const { - return precision_config_; - } - void set_precision_config(const PrecisionConfigProto& precision_config) { + const PrecisionConfig& precision_config() const { return precision_config_; } + void set_precision_config(const PrecisionConfig& precision_config) { precision_config_ = precision_config; } @@ -1680,7 +1678,7 @@ class HloInstruction { // Information used to communicate to the implementation about the algorithm // used to produce results. See the documentation on precision_config(). - PrecisionConfigProto precision_config_; + PrecisionConfig precision_config_; // String identifier for instruction. string name_; @@ -1704,12 +1702,12 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); -string PrecisionToString(const PrecisionConfigProto::Precision& precision); +string PrecisionToString(const PrecisionConfig::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); StatusOr<RandomDistribution> StringToRandomDistribution(const string& name); -StatusOr<PrecisionConfigProto::Precision> StringToPrecision(const string& name); +StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 9eab6eea80..c1b7c3832b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1752,9 +1752,9 @@ TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) { auto* conv = module->entry_computation()->root_instruction(); auto clone = conv->Clone(); - EXPECT_THAT(clone->precision_config().operand_precision(), - ::testing::ElementsAre(PrecisionConfigProto::HIGH, - PrecisionConfigProto::DEFAULT)); + EXPECT_THAT( + clone->precision_config().operand_precision(), + ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT)); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e3683aaec9..ad87aa1123 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1630,7 +1630,7 @@ HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config) + const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), window_(window), diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 1c85aa4681..e1215a7566 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -944,7 +944,7 @@ class HloConvolutionInstruction : public HloInstruction { const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config); + const PrecisionConfig& precision_config); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 62f01c4adb..0f26ed4235 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -221,7 +221,7 @@ class HloParser { bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad); bool ParseSliceRanges(SliceRanges* result); - bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result); + bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector<tensorflow::int64>* result); @@ -240,7 +240,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParsePrecision(PrecisionConfigProto::Precision* result); + bool ParsePrecision(PrecisionConfig::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -909,7 +909,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; - optional<std::vector<PrecisionConfigProto::Precision>> operand_precision; + optional<std::vector<PrecisionConfig::Precision>> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -922,13 +922,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; if (operand_precision) { *precision_config.mutable_operand_precision() = { operand_precision->begin(), operand_precision->end()}; } else { precision_config.mutable_operand_precision()->Resize( - operands.size(), PrecisionConfigProto::DEFAULT); + operands.size(), PrecisionConfig::DEFAULT); } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( shape, /*lhs=*/operands[0], /*rhs=*/operands[1], @@ -1279,7 +1279,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<std::vector<tensorflow::int64>> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; - optional<std::vector<PrecisionConfigProto::Precision>> operand_precision; + optional<std::vector<PrecisionConfig::Precision>> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; @@ -1306,13 +1306,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, rhs_batch_dims->end()}; } - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; if (operand_precision) { *precision_config.mutable_operand_precision() = { operand_precision->begin(), operand_precision->end()}; } else { precision_config.mutable_operand_precision()->Resize( - operands.size(), PrecisionConfigProto::DEFAULT); + operands.size(), PrecisionConfig::DEFAULT); } instruction = builder->AddInstruction(HloInstruction::CreateDot( @@ -2410,11 +2410,11 @@ bool HloParser::ParseAttributeHelper( return ParseDomain(static_cast<DomainData*>(attr_out_ptr)); } case AttrTy::kPrecisionList: { - std::vector<PrecisionConfigProto::Precision> result; + std::vector<PrecisionConfig::Precision> result; if (!ParsePrecisionList(&result)) { return false; } - static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>( + static_cast<optional<std::vector<PrecisionConfig::Precision>>*>( attr_out_ptr) ->emplace(result); return true; @@ -2698,9 +2698,9 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= /*empty*/ // ::= precision_val (delim precision_val)* bool HloParser::ParsePrecisionList( - std::vector<PrecisionConfigProto::Precision>* result) { + std::vector<PrecisionConfig::Precision>* result) { auto parse_and_add_item = [&]() { - PrecisionConfigProto::Precision item; + PrecisionConfig::Precision item; if (!ParsePrecision(&item)) { return false; } @@ -3032,7 +3032,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { +bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { VLOG(1) << "ParsePrecision"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects random distribution"); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 4a71ee909b..37b774b8a5 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -1031,8 +1031,8 @@ bool CanFoldDotIntoIndexedArray( StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1066,7 +1066,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, ConstantArray* lhs, + const PrecisionConfig& precision_config, ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " << ToString(rhs); @@ -1101,7 +1101,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs) { + const PrecisionConfig& precision_config, Array* lhs, Array* rhs) { // Intuitively, if // // - The LHS of a dot product is a gathered sequence of rows from a constant diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index f21e784a4d..9746d176cc 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -267,17 +267,18 @@ class IndexedArrayAnalysis { StatusOr<Array*> ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs); + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs); StatusOr<Array*> ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, ConstantArray* lhs, + const PrecisionConfig& precision_config, ConstantArray* lhs, ScalarIndexedConstantArray* rhs); - StatusOr<Array*> ComputeArrayForDot( - const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs); + StatusOr<Array*> ComputeArrayForDot(const Shape& shape, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + Array* lhs, Array* rhs); // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index e3328203a6..2b2a2eb42a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1064,9 +1064,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - /*new_size=*/2, PrecisionConfigProto::DEFAULT); + /*new_size=*/2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index edab480091..3df99aac7d 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -121,10 +121,10 @@ StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass, } /* static */ -PrecisionConfigProto HloTestBase::DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; +PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); + operands, PrecisionConfig::DEFAULT); return precision_config; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 89e72a045e..21d77c0cc4 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -80,7 +80,7 @@ class HloTestBase : public ::testing::Test { static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass, HloModule* module); - static PrecisionConfigProto DefaultPrecisionConfig(int operands); + static PrecisionConfig DefaultPrecisionConfig(int operands); protected: // This uses the interpreter backend as the reference backend and diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 8e43f275e1..dd329f1181 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -580,7 +580,7 @@ message SourceTarget { // Used to indicate the precision configuration. It has backend specific // meaning. -message PrecisionConfigProto { +message PrecisionConfig { enum Precision { DEFAULT = 0; HIGH = 1; |