author    David Majnemer <majnemer@google.com> 2018-09-05 14:01:26 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-05 14:06:30 -0700
commit    9059375e16a563af1cc208a8f4cb898a4892a396 (patch)
tree      0292f7a0a766aa7c960f704514fc8e827e78357f
parent    11caab3c138d06390344c88a4149f1897e3d780d (diff)
[XLA] Rename PrecisionConfigProto to PrecisionConfig
The "Proto" suffix adds little clarity but makes a long type name even longer. PiperOrigin-RevId: 211693871
 tensorflow/compiler/tests/xla_ops_test.py | 10
 tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc | 2
 tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc | 2
 tensorflow/compiler/tf2xla/lib/batch_dot.cc | 4
 tensorflow/compiler/tf2xla/lib/batch_dot.h | 10
 tensorflow/compiler/tf2xla/lib/cholesky.cc | 4
 tensorflow/compiler/tf2xla/lib/cholesky.h | 6
 tensorflow/compiler/tf2xla/lib/qr.cc | 6
 tensorflow/compiler/tf2xla/lib/qr.h | 3
 tensorflow/compiler/tf2xla/lib/triangular_solve.cc | 12
 tensorflow/compiler/tf2xla/lib/triangular_solve.h | 9
 tensorflow/compiler/tf2xla/ops/xla_ops.cc | 4
 tensorflow/compiler/xla/client/xla_builder.cc | 82
 tensorflow/compiler/xla/client/xla_builder.h | 97
 tensorflow/compiler/xla/reference_util.cc | 4
 tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 11
 tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 4
 tensorflow/compiler/xla/service/buffer_assignment_test.cc | 4
 tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 4
 tensorflow/compiler/xla/service/graphviz_example.cc | 4
 tensorflow/compiler/xla/service/hlo.proto | 2
 tensorflow/compiler/xla/service/hlo_computation_test.cc | 12
 tensorflow/compiler/xla/service/hlo_creation_utils.cc | 9
 tensorflow/compiler/xla/service/hlo_creation_utils.h | 9
 tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 4
 tensorflow/compiler/xla/service/hlo_evaluator.cc | 2
 tensorflow/compiler/xla/service/hlo_evaluator.h | 2
 tensorflow/compiler/xla/service/hlo_instruction.cc | 47
 tensorflow/compiler/xla/service/hlo_instruction.h | 16
 tensorflow/compiler/xla/service/hlo_instruction_test.cc | 6
 tensorflow/compiler/xla/service/hlo_instructions.cc | 2
 tensorflow/compiler/xla/service/hlo_instructions.h | 2
 tensorflow/compiler/xla/service/hlo_parser.cc | 26
 tensorflow/compiler/xla/service/indexed_array_analysis.cc | 8
 tensorflow/compiler/xla/service/indexed_array_analysis.h | 13
 tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 4
 tensorflow/compiler/xla/tests/hlo_test_base.cc | 6
 tensorflow/compiler/xla/tests/hlo_test_base.h | 2
 tensorflow/compiler/xla/xla_data.proto | 2
 39 files changed, 218 insertions(+), 238 deletions(-)
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index b2f026df6c..3f928a1bea 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -97,9 +97,9 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32))
- PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT,
- xla_data_pb2.PrecisionConfigProto.HIGH,
- xla_data_pb2.PrecisionConfigProto.HIGHEST)
+ PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfig.DEFAULT,
+ xla_data_pb2.PrecisionConfig.HIGH,
+ xla_data_pb2.PrecisionConfig.HIGHEST)
@parameterized.parameters(*PRECISION_VALUES)
def testConv(self, precision):
@@ -120,7 +120,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
precision_config = None
if precision:
- precision_config = xla_data_pb2.PrecisionConfigProto()
+ precision_config = xla_data_pb2.PrecisionConfig()
precision_config.operand_precision.extend([precision, precision])
return xla.conv(
lhs,
@@ -151,7 +151,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
dnums.rhs_batch_dimensions.append(0)
precision_config = None
if precision:
- precision_config = xla_data_pb2.PrecisionConfigProto()
+ precision_config = xla_data_pb2.PrecisionConfig()
precision_config.operand_precision.extend([precision, precision])
return xla.dot_general(
lhs,
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
index 8848623868..fecc7c556e 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
@@ -84,7 +84,7 @@ class XlaConvOp : public XlaOpKernel {
private:
xla::ConvolutionDimensionNumbers dnums_;
- xla::PrecisionConfigProto precision_config_;
+ xla::PrecisionConfig precision_config_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp);
};
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
index 2fed53e5c0..40b15b5579 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
@@ -54,7 +54,7 @@ class XlaDotOp : public XlaOpKernel {
private:
xla::DotDimensionNumbers dnums_;
- xla::PrecisionConfigProto precision_config_;
+ xla::PrecisionConfig precision_config_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
};
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index d8c050d09e..64f2d781a6 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -28,7 +28,7 @@ namespace tensorflow {
xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
bool transpose_y, bool conjugate_x, bool conjugate_y,
- xla::PrecisionConfigProto::Precision precision) {
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
@@ -96,7 +96,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
y = xla::Conj(y);
}
- xla::PrecisionConfigProto precision_proto;
+ xla::PrecisionConfig precision_proto;
precision_proto.add_operand_precision(precision);
precision_proto.add_operand_precision(precision);
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index 6cfccd5553..6edd63a4d3 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -43,11 +43,11 @@ namespace tensorflow {
// It is computed as:
//
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
- bool transpose_y = false, bool conjugate_x = false,
- bool conjugate_y = false,
- xla::PrecisionConfigProto::Precision precision =
- xla::PrecisionConfigProto::DEFAULT);
+xla::XlaOp BatchDot(
+ xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
+ bool transpose_y = false, bool conjugate_x = false,
+ bool conjugate_y = false,
+ xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
} // namespace tensorflow
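A hedged sketch of what the reflowed BatchDot declaration above means at a call site (names here are illustrative; `x` and `y` are assumed xla::XlaOp values):

    // All boolean flags default to false; only the precision is overridden.
    xla::XlaOp out = tensorflow::BatchDot(
        x, y, /*transpose_x=*/false, /*transpose_y=*/false,
        /*conjugate_x=*/false, /*conjugate_y=*/false,
        xla::PrecisionConfig::HIGHEST);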
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index c50a8de33e..ab3d0a5668 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -50,7 +50,7 @@ namespace {
// l[..., j, j]
// return l
xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
- xla::PrecisionConfigProto::Precision precision) {
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -150,7 +150,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
} // namespace
xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
- xla::PrecisionConfigProto::Precision precision) {
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 60cd7ded53..9a561c34b9 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -30,9 +30,9 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256,
- xla::PrecisionConfigProto::Precision precision =
- xla::PrecisionConfigProto::HIGHEST);
+xla::XlaOp Cholesky(
+ xla::XlaOp a, int64 block_size = 256,
+ xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
index 0a140fa93c..6b3f2b6e06 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.cc
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -150,7 +150,7 @@ struct QRBlockResult {
xla::XlaOp vs; // Shape: [..., m, n]
};
xla::StatusOr<QRBlockResult> QRBlock(
- xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) {
+ xla::XlaOp a, xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape);
@@ -257,7 +257,7 @@ xla::StatusOr<QRBlockResult> QRBlock(
xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs,
xla::XlaOp taus, int64 m, int64 n,
- xla::PrecisionConfigProto::Precision precision) {
+ xla::PrecisionConfig::Precision precision) {
std::vector<int64> batch_dim_indices(batch_dims.size());
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
int64 n_index = batch_dims.size() + 1;
@@ -332,7 +332,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
// rather than WY transformations.
xla::StatusOr<QRDecompositionResult> QRDecomposition(
xla::XlaOp a, bool full_matrices, int64 block_size,
- xla::PrecisionConfigProto::Precision precision) {
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape);
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
index 8a389fb7b0..24b537ac8b 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.h
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -35,8 +35,7 @@ struct QRDecompositionResult {
xla::StatusOr<QRDecompositionResult> QRDecomposition(
xla::XlaOp a, bool full_matrices, int64 block_size = 128,
- xla::PrecisionConfigProto::Precision precision =
- xla::PrecisionConfigProto::HIGHEST);
+ xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 37b2240b45..6524c2a9b1 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -110,9 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
});
}
-xla::XlaOp InvertDiagonalBlocks(
- xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a,
- xla::PrecisionConfigProto::Precision precision) {
+xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
+ bool transpose_a, bool conjugate_a,
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = diag_blocks.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
// Input is a batch of square lower triangular square matrices. Its shape is
@@ -216,7 +216,7 @@ xla::XlaOp InvertDiagonalBlocks(
dnums.add_rhs_batch_dimensions(0);
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
- xla::PrecisionConfigProto precision_proto;
+ xla::PrecisionConfig precision_proto;
precision_proto.add_operand_precision(precision);
precision_proto.add_operand_precision(precision);
auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
@@ -245,7 +245,7 @@ xla::XlaOp InvertDiagonalBlocks(
xla::XlaOp SolveWithInvertedDiagonalBlocks(
xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side,
bool lower, bool transpose_a, bool conjugate_a,
- xla::PrecisionConfigProto::Precision precision) {
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape,
@@ -346,7 +346,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(
xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
bool lower, bool transpose_a, bool conjugate_a,
int64 block_size,
- xla::PrecisionConfigProto::Precision precision) {
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index ac42a48352..2303234f36 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -57,11 +57,10 @@ namespace tensorflow {
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
-xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
- bool lower, bool transpose_a, bool conjugate_a,
- int64 block_size = 128,
- xla::PrecisionConfigProto::Precision precision =
- xla::PrecisionConfigProto::HIGHEST);
+xla::XlaOp TriangularSolve(
+ xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a,
+ bool conjugate_a, int64 block_size = 128,
+ xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace tensorflow
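Likewise for the reflowed TriangularSolve declaration, a hedged call-site sketch assuming `a` and `b` were built on the same XlaBuilder; leaving off the last two arguments picks up block_size = 128 and xla::PrecisionConfig::HIGHEST:

    // Solve a x = b with a lower-triangular `a`, using the defaults.
    xla::XlaOp x = tensorflow::TriangularSolve(
        a, b, /*left_side=*/true, /*lower=*/true,
        /*transpose_a=*/false, /*conjugate_a=*/false);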
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 2cd9ae799f..68cfdc1785 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -83,7 +83,7 @@ lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
-precision_config: a serialized xla::PrecisionConfigProto proto.
+precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaDot")
@@ -102,7 +102,7 @@ Wraps the XLA ConvGeneralDilated operator, documented at
lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
-precision_config: a serialized xla::PrecisionConfigProto proto.
+precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 7f2125f74c..887b970661 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -820,7 +820,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -828,14 +828,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
dimension_numbers.add_lhs_contracting_dimensions(
lhs_shape.dimensions_size() == 1 ? 0 : 1);
dimension_numbers.add_rhs_contracting_dimensions(0);
- return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto);
+ return DotGeneral(lhs, rhs, dimension_numbers, precision_config);
});
}
-XlaOp XlaBuilder::DotGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto* precision_config_proto) {
+XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -844,8 +843,8 @@ XlaOp XlaBuilder::DotGeneral(
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
dimension_numbers));
*instr.mutable_dot_dimension_numbers() = dimension_numbers;
- if (precision_config_proto != nullptr) {
- *instr.mutable_precision_config() = *precision_config_proto;
+ if (precision_config != nullptr) {
+ *instr.mutable_precision_config() = *precision_config;
}
return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
});
@@ -899,28 +898,26 @@ Status XlaBuilder::VerifyConvolution(
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
- feature_group_count, precision_config_proto);
+ feature_group_count, precision_config);
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ConvGeneral(lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
- feature_group_count, precision_config_proto);
+ feature_group_count, precision_config);
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -948,7 +945,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
MakePadding(base_area_dimensions, window_dimensions,
window_strides, padding),
dimension_numbers, feature_group_count,
- precision_config_proto);
+ precision_config);
});
}
@@ -956,11 +953,10 @@ XlaOp XlaBuilder::ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
dimension_numbers, feature_group_count,
- precision_config_proto);
+ precision_config);
}
XlaOp XlaBuilder::ConvGeneralDilated(
@@ -968,8 +964,7 @@ XlaOp XlaBuilder::ConvGeneralDilated(
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -996,8 +991,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
instr.set_feature_group_count(feature_group_count);
- if (precision_config_proto != nullptr) {
- *instr.mutable_precision_config() = *precision_config_proto;
+ if (precision_config != nullptr) {
+ *instr.mutable_precision_config() = *precision_config;
}
return AddInstruction(std::move(instr), HloOpcode::kConvolution,
@@ -2594,43 +2589,40 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto) {
- return lhs.builder()->Dot(lhs, rhs, precision_config_proto);
+ const PrecisionConfig* precision_config) {
+ return lhs.builder()->Dot(lhs, rhs, precision_config);
}
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
- precision_config_proto);
+ precision_config);
}
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
- feature_group_count, precision_config_proto);
+ feature_group_count, precision_config);
}
-XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
- return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
- padding, feature_group_count,
- precision_config_proto);
+XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count,
+ const PrecisionConfig* precision_config) {
+ return lhs.builder()->ConvWithGeneralPadding(
+ lhs, rhs, window_strides, padding, feature_group_count, precision_config);
}
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return lhs.builder()->ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
- precision_config_proto);
+ precision_config);
}
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
@@ -2638,10 +2630,10 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
dimension_numbers, feature_group_count,
- precision_config_proto);
+ precision_config);
}
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
@@ -2651,10 +2643,10 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+ const PrecisionConfig* precision_config) {
return lhs.builder()->ConvGeneralDilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
- dimension_numbers, feature_group_count, precision_config_proto);
+ dimension_numbers, feature_group_count, precision_config);
}
XlaOp Fft(const XlaOp& operand, FftType fft_type,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 59fbc664f2..58e8f4e7fa 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -496,20 +496,19 @@ class XlaBuilder {
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a general dot instruction onto the computation.
- XlaOp DotGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
@@ -518,7 +517,7 @@ class XlaBuilder {
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
@@ -527,29 +526,27 @@ class XlaBuilder {
absl::Span<const int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
- XlaOp ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
- XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- absl::Span<const int64> lhs_dilation,
- absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
@@ -1150,32 +1147,30 @@ class XlaBuilder {
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto);
+ const PrecisionConfig* precision_config);
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_number,
- const PrecisionConfigProto* precision_config_proto);
+ const PrecisionConfig* precision_config);
friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ const PrecisionConfig* precision_config);
friend XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ int64 feature_group_count, const PrecisionConfig* precision_config);
friend XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ int64 feature_group_count, const PrecisionConfig* precision_config);
friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ const PrecisionConfig* precision_config);
friend XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
@@ -1183,8 +1178,7 @@ class XlaBuilder {
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
+ int64 feature_group_count, const PrecisionConfig* precision_config);
friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
absl::Span<const int64> fft_length);
friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
@@ -1629,27 +1623,27 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
-XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
@@ -1657,7 +1651,7 @@ XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
@@ -1666,17 +1660,18 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
-XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
- absl::Span<const std::pair<int64, int64>> padding,
- absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1,
- const PrecisionConfigProto* precision_config_proto = nullptr);
+XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 8a05d1b0d7..9f1afa2671 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -574,9 +574,9 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- /*new_size=*/2, PrecisionConfigProto::DEFAULT);
+ /*new_size=*/2, PrecisionConfig::DEFAULT);
b.AddInstruction(HloInstruction::CreateConvolve(
shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
window, dnums, precision_config));
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 0db74bd038..aa40fba9bb 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2379,9 +2379,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
// Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
// after the transformation.
- PrecisionConfigProto precision_config;
- precision_config.add_operand_precision(PrecisionConfigProto::HIGH);
- precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST);
+ PrecisionConfig precision_config;
+ precision_config.add_operand_precision(PrecisionConfig::HIGH);
+ precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
orig_conv->set_precision_config(precision_config);
auto module = CreateNewModule();
@@ -2401,9 +2401,8 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
conv->operand(1)->shape().dimensions(2),
conv->operand(1)->shape().dimensions(3),
testcase.expected_conv_window));
- EXPECT_THAT(
- conv->precision_config().operand_precision(),
- ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST));
+ EXPECT_THAT(conv->precision_config().operand_precision(),
+ ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
}
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index d480d72297..933cf873e0 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -308,9 +308,9 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- 2, PrecisionConfigProto::DEFAULT);
+ 2, PrecisionConfig::DEFAULT);
HloInstruction* dot = builder.AddInstruction(
HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 7398f105a0..56bd67fb55 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1490,9 +1490,9 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- 2, PrecisionConfigProto::DEFAULT);
+ 2, PrecisionConfig::DEFAULT);
auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
shape_2x4, param_a, param_b, dot_dnums, precision_config));
auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 6bd0a2dd90..0fea462c85 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -38,9 +38,9 @@ std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- 2, PrecisionConfigProto::DEFAULT);
+ 2, PrecisionConfig::DEFAULT);
return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
precision_config);
}
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index 0a49d85c6d..ef70b68877 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -112,9 +112,9 @@ std::unique_ptr<HloModule> MakeBigGraph() {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- /*new_size=*/2, PrecisionConfigProto::DEFAULT);
+ /*new_size=*/2, PrecisionConfig::DEFAULT);
auto dot = builder.AddInstruction(HloInstruction::CreateDot(
vshape, clamp, param_v0, dot_dnums, precision_config));
auto tuple = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 58b7af93eb..99d0cf50ca 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -172,7 +172,7 @@ message HloInstructionProto {
xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
// Precision configuration for the instruction. Has backend-specific meaning.
- xla.PrecisionConfigProto precision_config = 51;
+ xla.PrecisionConfig precision_config = 51;
// Collective permute field.
repeated SourceTarget source_target_pairs = 52;
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index a2c1ce34c6..2aaaef1d36 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -601,9 +601,9 @@ TEST_F(HloComputationTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- 2, PrecisionConfigProto::DEFAULT);
+ 2, PrecisionConfig::DEFAULT);
builder.AddInstruction(
HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
@@ -636,9 +636,9 @@ TEST_F(HloComputationTest, StringificationIndent) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- 2, PrecisionConfigProto::DEFAULT);
+ 2, PrecisionConfig::DEFAULT);
builder.AddInstruction(
HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
@@ -672,9 +672,9 @@ TEST_F(HloComputationTest, StringificationCanonical) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- 2, PrecisionConfigProto::DEFAULT);
+ 2, PrecisionConfig::DEFAULT);
builder.AddInstruction(
HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index a6ae0337a5..a3fcc0fefa 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -63,7 +63,7 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto& precision_config) {
+ const PrecisionConfig& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(Shape convolve_shape,
@@ -167,10 +167,9 @@ StatusOr<HloInstruction*> MakeConcatHlo(
HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
}
-StatusOr<HloInstruction*> MakeDotHlo(
- HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config) {
+StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 1c82956907..b22058abb4 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -50,7 +50,7 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto& precision_config);
+ const PrecisionConfig& precision_config);
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
@@ -98,10 +98,9 @@ StatusOr<HloInstruction*> MakeConcatHlo(
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
-StatusOr<HloInstruction*> MakeDotHlo(
- HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config);
+StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config);
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 62eea2b06c..72b236801a 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2334,9 +2334,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- 2, PrecisionConfigProto::DEFAULT);
+ 2, PrecisionConfig::DEFAULT);
auto dot = builder.AddInstruction(
HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index ffb3451164..d0d955fea8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -345,7 +345,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config, const Literal& lhs,
+ const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.CloneToUnique());
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index e13af8e999..72252bafc7 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -116,7 +116,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config, const Literal& lhs,
+ const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs);
protected:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index f25761ac70..471a12d6aa 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -347,9 +347,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_window());
TF_RET_CHECK(proto.has_convolution_dimension_numbers());
- PrecisionConfigProto precision_config = proto.precision_config();
+ PrecisionConfig precision_config = proto.precision_config();
precision_config.mutable_operand_precision()->Resize(
- proto.operand_ids_size(), PrecisionConfigProto::DEFAULT);
+ proto.operand_ids_size(), PrecisionConfig::DEFAULT);
instruction = CreateConvolve(
proto.shape(), operands(0), operands(1),
std::max<int64>(proto.feature_group_count(), 1), proto.window(),
@@ -475,7 +475,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
if (instruction->opcode() == HloOpcode::kDot) {
instruction->precision_config_ = proto.precision_config();
instruction->precision_config_.mutable_operand_precision()->Resize(
- instruction->operand_count(), PrecisionConfigProto::DEFAULT);
+ instruction->operand_count(), PrecisionConfig::DEFAULT);
TF_RET_CHECK(proto.has_dot_dimension_numbers());
instruction->dot_dimension_numbers_ =
absl::make_unique<DotDimensionNumbers>(
@@ -657,7 +657,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto& precision_config) {
+ const PrecisionConfig& precision_config) {
return absl::make_unique<HloConvolutionInstruction>(
shape, lhs, rhs, feature_group_count, window, dimension_numbers,
precision_config);
@@ -673,7 +673,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto& precision_config) {
+ const PrecisionConfig& precision_config) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
instruction->AppendOperand(lhs);
@@ -2888,8 +2888,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) {
return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
}
-string PrecisionToString(const PrecisionConfigProto::Precision& precision) {
- return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision));
+string PrecisionToString(const PrecisionConfig::Precision& precision) {
+ return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
}
string ConvolutionDimensionNumbersToString(
@@ -2967,32 +2967,31 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
string HloInstruction::PrecisionConfigToString() const {
if (absl::c_all_of(
precision_config_.operand_precision(), [](int32 precision) {
- return static_cast<PrecisionConfigProto::Precision>(precision) ==
- PrecisionConfigProto::DEFAULT;
+ return static_cast<PrecisionConfig::Precision>(precision) ==
+ PrecisionConfig::DEFAULT;
})) {
return "";
}
return StrCat(
"operand_precision={",
- StrJoin(precision_config_.operand_precision(), ",",
- [](string* out, int32 precision) {
- CHECK(PrecisionConfigProto::Precision_IsValid(precision))
- << precision;
- StrAppend(out, PrecisionToString(
- static_cast<PrecisionConfigProto::Precision>(
- precision)));
- }),
+ StrJoin(
+ precision_config_.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
+ StrAppend(out,
+ PrecisionToString(
+ static_cast<PrecisionConfig::Precision>(precision)));
+ }),
"}");
}
-StatusOr<PrecisionConfigProto::Precision> StringToPrecision(
- const string& name) {
- static std::unordered_map<string, PrecisionConfigProto::Precision>* map = [] {
+StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
+ static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
static auto* map =
- new std::unordered_map<string, PrecisionConfigProto::Precision>;
- for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) {
- if (PrecisionConfigProto::Precision_IsValid(i)) {
- auto value = static_cast<PrecisionConfigProto::Precision>(i);
+ new std::unordered_map<string, PrecisionConfig::Precision>;
+ for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
+ if (PrecisionConfig::Precision_IsValid(i)) {
+ auto value = static_cast<PrecisionConfig::Precision>(i);
(*map)[PrecisionToString(value)] = value;
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 55d592ff94..691f8155f9 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -407,7 +407,7 @@ class HloInstruction {
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto& precision_config);
+ const PrecisionConfig& precision_config);
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
@@ -419,7 +419,7 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto& precision_config);
+ const PrecisionConfig& precision_config);
// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
// of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
@@ -1262,10 +1262,8 @@ class HloInstruction {
// information. Transformations to other HLOs will not preserve this
// information but it is presumed that the alternate lowering is strictly
// superior.
- const PrecisionConfigProto& precision_config() const {
- return precision_config_;
- }
- void set_precision_config(const PrecisionConfigProto& precision_config) {
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+ void set_precision_config(const PrecisionConfig& precision_config) {
precision_config_ = precision_config;
}
@@ -1680,7 +1678,7 @@ class HloInstruction {
// Information used to communicate to the implementation about the algorithm
// used to produce results. See the documentation on precision_config().
- PrecisionConfigProto precision_config_;
+ PrecisionConfig precision_config_;
// String identifier for instruction.
string name_;
@@ -1704,12 +1702,12 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
-string PrecisionToString(const PrecisionConfigProto::Precision& precision);
+string PrecisionToString(const PrecisionConfig::Precision& precision);
string ConvolutionDimensionNumbersToString(
const ConvolutionDimensionNumbers& dnums);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
-StatusOr<PrecisionConfigProto::Precision> StringToPrecision(const string& name);
+StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 9eab6eea80..c1b7c3832b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1752,9 +1752,9 @@ TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
auto* conv = module->entry_computation()->root_instruction();
auto clone = conv->Clone();
- EXPECT_THAT(clone->precision_config().operand_precision(),
- ::testing::ElementsAre(PrecisionConfigProto::HIGH,
- PrecisionConfigProto::DEFAULT));
+ EXPECT_THAT(
+ clone->precision_config().operand_precision(),
+ ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e3683aaec9..ad87aa1123 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1630,7 +1630,7 @@ HloConvolutionInstruction::HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto& precision_config)
+ const PrecisionConfig& precision_config)
: HloInstruction(HloOpcode::kConvolution, shape),
feature_group_count_(feature_group_count),
window_(window),
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 1c85aa4681..e1215a7566 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -944,7 +944,7 @@ class HloConvolutionInstruction : public HloInstruction {
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- const PrecisionConfigProto& precision_config);
+ const PrecisionConfig& precision_config);
const Window& window() const override { return window_; }
void set_window(const Window& window) override { window_ = window; }
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 62f01c4adb..0f26ed4235 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -221,7 +221,7 @@ class HloParser {
bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad);
bool ParseSliceRanges(SliceRanges* result);
- bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result);
+ bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
@@ -240,7 +240,7 @@ class HloParser {
bool ParseFftType(FftType* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
- bool ParsePrecision(PrecisionConfigProto::Precision* result);
+ bool ParsePrecision(PrecisionConfig::Precision* result);
bool ParseInt64(tensorflow::int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
@@ -909,7 +909,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
- optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
@@ -922,13 +922,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!feature_group_count) {
feature_group_count = 1;
}
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
- operands.size(), PrecisionConfigProto::DEFAULT);
+ operands.size(), PrecisionConfig::DEFAULT);
}
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
@@ -1279,7 +1279,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<std::vector<tensorflow::int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
- optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
@@ -1306,13 +1306,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
rhs_batch_dims->end()};
}
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
- operands.size(), PrecisionConfigProto::DEFAULT);
+ operands.size(), PrecisionConfig::DEFAULT);
}
instruction = builder->AddInstruction(HloInstruction::CreateDot(
@@ -2410,11 +2410,11 @@ bool HloParser::ParseAttributeHelper(
return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
}
case AttrTy::kPrecisionList: {
- std::vector<PrecisionConfigProto::Precision> result;
+ std::vector<PrecisionConfig::Precision> result;
if (!ParsePrecisionList(&result)) {
return false;
}
- static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>(
+ static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
attr_out_ptr)
->emplace(result);
return true;
@@ -2698,9 +2698,9 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
// ::= /*empty*/
// ::= precision_val (delim precision_val)*
bool HloParser::ParsePrecisionList(
- std::vector<PrecisionConfigProto::Precision>* result) {
+ std::vector<PrecisionConfig::Precision>* result) {
auto parse_and_add_item = [&]() {
- PrecisionConfigProto::Precision item;
+ PrecisionConfig::Precision item;
if (!ParsePrecision(&item)) {
return false;
}
@@ -3032,7 +3032,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
return true;
}
-bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) {
+bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
VLOG(1) << "ParsePrecision";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects random distribution");
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 4a71ee909b..37b774b8a5 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -1031,8 +1031,8 @@ bool CanFoldDotIntoIndexedArray(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config,
- ScalarIndexedConstantArray* lhs, ConstantArray* rhs) {
+ const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
+ ConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
if (!CanFoldDotIntoIndexedArray(
@@ -1066,7 +1066,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config, ConstantArray* lhs,
+ const PrecisionConfig& precision_config, ConstantArray* lhs,
ScalarIndexedConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
@@ -1101,7 +1101,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs) {
+ const PrecisionConfig& precision_config, Array* lhs, Array* rhs) {
// Intuitively, if
//
// - The LHS of a dot product is a gathered sequence of rows from a constant
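
Restating the truncated comment's intuition in our own words (a paraphrase, not text from the CL): when the LHS rows are gathered from a constant, the dot can be folded so the multiply happens once against the constant and the gather is applied to the product, roughly

  dot(gather(A, I), B) = gather(dot(A, B), I)

provided the gather indexes the non-contracting (row) dimension of A.
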
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index f21e784a4d..9746d176cc 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -267,17 +267,18 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config,
- ScalarIndexedConstantArray* lhs, ConstantArray* rhs);
+ const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
+ ConstantArray* rhs);
StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config, ConstantArray* lhs,
+ const PrecisionConfig& precision_config, ConstantArray* lhs,
ScalarIndexedConstantArray* rhs);
- StatusOr<Array*> ComputeArrayForDot(
- const Shape& shape, const DotDimensionNumbers& dim_numbers,
- const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs);
+ StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config,
+ Array* lhs, Array* rhs);
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index e3328203a6..2b2a2eb42a 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1064,9 +1064,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- PrecisionConfigProto precision_config;
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- /*new_size=*/2, PrecisionConfigProto::DEFAULT);
+ /*new_size=*/2, PrecisionConfig::DEFAULT);
auto dot = builder.AddInstruction(
HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index edab480091..3df99aac7d 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -121,10 +121,10 @@ StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
}
/* static */
-PrecisionConfigProto HloTestBase::DefaultPrecisionConfig(int operands) {
- PrecisionConfigProto precision_config;
+PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
+ PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
- operands, PrecisionConfigProto::DEFAULT);
+ operands, PrecisionConfig::DEFAULT);
return precision_config;
}
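
A typical (hypothetical) call site, assuming a test fixture derived from HloTestBase and the CreateDot signature shown earlier in this patch:

  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      data_shape, a, b, dot_dnums, DefaultPrecisionConfig(2)));
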
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 89e72a045e..21d77c0cc4 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -80,7 +80,7 @@ class HloTestBase : public ::testing::Test {
static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
HloModule* module);
- static PrecisionConfigProto DefaultPrecisionConfig(int operands);
+ static PrecisionConfig DefaultPrecisionConfig(int operands);
protected:
// This uses the interpreter backend as the reference backend and
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 8e43f275e1..dd329f1181 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -580,7 +580,7 @@ message SourceTarget {
 // Used to indicate the precision configuration. It has backend-specific
// meaning.
-message PrecisionConfigProto {
+message PrecisionConfig {
enum Precision {
DEFAULT = 0;
HIGH = 1;