-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc                           86
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h                            66
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc                 17
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification.cc              1
-rw-r--r--  tensorflow/compiler/xla/service/convolution_feature_group_converter.cc   1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc             1
-rw-r--r--  tensorflow/compiler/xla/service/dot_decomposer.cc                         1
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto                                 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc                       52
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h                        31
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc                            58
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.cc                      2
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto                                   15
13 files changed, 277 insertions, 59 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index c6976fd849..7bc6e8d860 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@@ -808,7 +809,8 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions);
}
-XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
+XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfigProto* precision_config_proto) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -816,12 +818,14 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
dimension_numbers.add_lhs_contracting_dimensions(
lhs_shape.dimensions_size() == 1 ? 0 : 1);
dimension_numbers.add_rhs_contracting_dimensions(0);
- return DotGeneral(lhs, rhs, dimension_numbers);
+ return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto);
});
}
-XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers) {
+XlaOp XlaBuilder::DotGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto* precision_config_proto) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -830,6 +834,9 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
dimension_numbers));
*instr.mutable_dot_dimension_numbers() = dimension_numbers;
+ if (precision_config_proto != nullptr) {
+ *instr.mutable_precision_config() = *precision_config_proto;
+ }
return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
});
}
@@ -883,28 +890,31 @@ Status XlaBuilder::VerifyConvolution(
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, int64 feature_group_count) {
+ Padding padding, int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
- feature_group_count);
+ feature_group_count, precision_config_proto);
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return ConvGeneral(lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
- feature_group_count);
+ feature_group_count, precision_config_proto);
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -931,7 +941,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
return ConvGeneral(lhs, rhs, window_strides,
MakePadding(base_area_dimensions, window_dimensions,
window_strides, padding),
- dimension_numbers, feature_group_count);
+ dimension_numbers, feature_group_count,
+ precision_config_proto);
});
}
@@ -940,9 +951,11 @@ XlaOp XlaBuilder::ConvGeneral(
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
- dimension_numbers, feature_group_count);
+ dimension_numbers, feature_group_count,
+ precision_config_proto);
}
XlaOp XlaBuilder::ConvGeneralDilated(
@@ -952,7 +965,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -979,6 +993,10 @@ XlaOp XlaBuilder::ConvGeneralDilated(
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
instr.set_feature_group_count(feature_group_count);
+ if (precision_config_proto != nullptr) {
+ *instr.mutable_precision_config() = *precision_config_proto;
+ }
+
return AddInstruction(std::move(instr), HloOpcode::kConvolution,
{lhs, rhs});
});
@@ -2548,48 +2566,57 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
return lhs.builder()->Le(lhs, rhs, broadcast_dimensions);
}
-XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) {
- return lhs.builder()->Dot(lhs, rhs);
+XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfigProto* precision_config_proto) {
+ return lhs.builder()->Dot(lhs, rhs, precision_config_proto);
}
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers) {
- return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto* precision_config_proto) {
+ return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
+ precision_config_proto);
}
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
- feature_group_count);
+ feature_group_count, precision_config_proto);
}
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
- padding, feature_group_count);
+ padding, feature_group_count,
+ precision_config_proto);
}
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
- return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides,
- padding, dimension_numbers,
- feature_group_count);
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
+ return lhs.builder()->ConvWithGeneralDimensions(
+ lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
+ precision_config_proto);
}
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
- dimension_numbers, feature_group_count);
+ dimension_numbers, feature_group_count,
+ precision_config_proto);
}
XlaOp ConvGeneralDilated(
@@ -2599,10 +2626,11 @@ XlaOp ConvGeneralDilated(
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return lhs.builder()->ConvGeneralDilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
- dimension_numbers, feature_group_count);
+ dimension_numbers, feature_group_count, precision_config_proto);
}
XlaOp Fft(const XlaOp& operand, FftType fft_type,
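
The new precision_config_proto parameter defaults to nullptr throughout, so
when it is omitted the instruction's precision_config field is simply left
unset. A minimal usage sketch (builder name, shapes, and operands are
illustrative, not taken from this patch):

    xla::XlaBuilder builder("precision_example");
    auto lhs = xla::Parameter(
        &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {4, 8}), "lhs");
    auto rhs = xla::Parameter(
        &builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {8, 16}), "rhs");

    // One entry per operand; here both operands request HIGHEST precision.
    xla::PrecisionConfigProto precision;
    precision.add_operand_precision(xla::PrecisionConfigProto::HIGHEST);
    precision.add_operand_precision(xla::PrecisionConfigProto::HIGHEST);

    // Passing nullptr (the default) reproduces the pre-patch behavior.
    auto dot = xla::Dot(lhs, rhs, &precision);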
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 089967147f..8d9ec9a18a 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -501,17 +501,21 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a dot instruction onto the computation.
- XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+ XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a general dot instruction onto the computation.
- XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
+ XlaOp DotGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
@@ -519,7 +523,8 @@ class XlaBuilder {
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
@@ -527,7 +532,8 @@ class XlaBuilder {
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
@@ -536,7 +542,8 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
@@ -547,7 +554,8 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
@@ -1146,28 +1154,34 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+ friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfigProto* precision_config_proto);
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
+                          const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto* precision_config_proto);
friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, int64 feature_group_count);
+ Padding padding, int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto);
friend XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count);
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto);
friend XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto);
friend XlaOp ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto);
friend XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
@@ -1175,7 +1189,8 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto);
friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length);
friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
@@ -1626,17 +1641,20 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a dot instruction onto the computation.
-XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
@@ -1644,7 +1662,8 @@ XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
@@ -1652,7 +1671,8 @@ XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
@@ -1660,7 +1680,8 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
@@ -1671,7 +1692,8 @@ XlaOp ConvGeneralDilated(
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
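
Because every added parameter is defaulted (nullptr for the precision config,
1 for feature_group_count), the header change is source-compatible. A sketch,
reusing the hypothetical lhs/rhs from the example above:

    // Both calls enqueue the same dot; the first leaves precision_config
    // unset, the second just says so explicitly.
    auto d1 = xla::Dot(lhs, rhs);
    auto d2 = xla::Dot(lhs, rhs, /*precision_config_proto=*/nullptr);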
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 0a040b5d16..b86b7d2e71 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -268,7 +268,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
- const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
+ const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);
StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
@@ -829,18 +829,18 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
TF_ASSIGN_OR_RETURN(
HloInstruction * optimized_lhs_concat,
- OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs,
+ OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
rhs_contracting_dim, /*swapped=*/false));
if (optimized_lhs_concat) {
return optimized_lhs_concat;
}
- return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs,
+ return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
lhs_contracting_dim, /*swapped=*/true);
}
StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
- const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
+ const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) {
bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
lhs->concatenate_dimension() == lhs_contracting_dim &&
@@ -939,11 +939,12 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
}
auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums));
+ dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums));
+ new_dot->set_precision_config(dot.precision_config());
if (add_result) {
add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
- dot_shape, HloOpcode::kAdd, add_result, new_dot));
+ dot.shape(), HloOpcode::kAdd, add_result, new_dot));
} else {
add_result = new_dot;
}
@@ -1042,6 +1043,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
memoized_shape, left_operand, right_operand, dnums));
+ memoized_inst->set_precision_config(dot->precision_config());
// Get pair {start, 0} or {0, start}.
HloInstruction* original_start_indices =
lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
@@ -1139,6 +1141,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
rhs->mutable_operand(0), lhs->mutable_operand(0),
dot_dimension_numbers));
+ new_dot->set_precision_config(dot->precision_config());
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
@@ -2297,6 +2300,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
+ dot->set_precision_config(convolution->precision_config());
+
return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
}
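
The recurring pattern in this pass (and in the passes below) is that any
rewrite producing a new kDot or kConvolution must copy the original
instruction's precision configuration. A hedged sketch of that pattern as a
free-standing helper (the helper itself is invented for illustration; only
CreateDot and set_precision_config come from this patch):

    // Builds a replacement dot and propagates the precision configuration
    // from the instruction being rewritten.
    xla::HloInstruction* AddDotPreservingPrecision(
        xla::HloComputation* computation, const xla::HloInstruction& old_dot,
        const xla::Shape& shape, xla::HloInstruction* new_lhs,
        xla::HloInstruction* new_rhs, const xla::DotDimensionNumbers& dnums) {
      auto* new_dot = computation->AddInstruction(
          xla::HloInstruction::CreateDot(shape, new_lhs, new_rhs, dnums));
      new_dot->set_precision_config(old_dot.precision_config());
      return new_dot;
    }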
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index b226e7ecb0..be6fbcc9e3 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -64,6 +64,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
+ new_dot->set_precision_config(batch_dot->precision_config());
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
MakeReshapeHlo(batch_dot->shape(), new_dot));
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 8affa08b65..9c81a86bbb 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -224,6 +224,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
auto new_convolution = HloInstruction::CreateConvolve(
convolution->shape(), convolution->mutable_operand(0), new_filter,
convolution->window(), dim_numbers, /*feature_group_count=*/1);
+ new_convolution->set_precision_config(convolution->precision_config());
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
convolution, std::move(new_convolution)));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index 0985b9297f..098ce17a56 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -132,6 +132,7 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
HloInstruction* new_conv = module->entry_computation()->AddInstruction(
HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
hlo->window(), new_dnums));
+ new_conv->set_precision_config(hlo->precision_config());
// Reshape the output back to the shape of the original convolution.
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index 12faed6967..09cb10d6ee 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -136,6 +136,7 @@ Status DecomposeBatchDot(HloInstruction* dot) {
dot_dnums.add_rhs_contracting_dimensions(0);
auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));
+ dot_r2->set_precision_config(dot->precision_config());
// Reshape Dot to R3 so we can concat along batch dimension.
auto dot_r3 = computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index fa218657fe..12b609a60f 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 51
+// Next ID: 52
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -171,6 +171,9 @@ message HloInstructionProto {
bool is_host_transfer = 47;
xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
+
+ // Precision configuration for the instruction. Has backend-specific meaning.
+ xla.PrecisionConfigProto precision_config = 51;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8a9856c1da..9d795da100 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -444,6 +444,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
+ instruction->precision_config_ = proto.precision_config();
if (proto.has_dot_dimension_numbers()) {
instruction->dot_dimension_numbers_ =
@@ -1019,6 +1020,7 @@ void HloInstruction::SetupDerivedInstruction(
derived_instruction->clear_sharding();
}
derived_instruction->set_metadata(metadata_);
+ derived_instruction->set_precision_config(precision_config_);
}
bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -1279,6 +1281,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
}
break;
}
+  // SetupDerivedInstruction will set up the precision_config_ field.
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
clone->set_raw_backend_config_string(backend_config_);
@@ -2000,6 +2003,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(DotDimensionNumbersToString());
}
+ string precision_config_string = PrecisionConfigToString();
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
if (opcode() == HloOpcode::kWhile) {
@@ -2121,6 +2129,7 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
+ *proto.mutable_precision_config() = precision_config_;
if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
@@ -2819,6 +2828,11 @@ string RandomDistributionToString(const RandomDistribution& distribution) {
return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
}
+string PrecisionToString(const PrecisionConfigProto::Precision& precision) {
+ return tensorflow::str_util::Lowercase(
+ PrecisionConfigProto::Precision_Name(precision));
+}
+
string ConvolutionDimensionNumbersToString(
const ConvolutionDimensionNumbers& dnums) {
// lhs_dims[i] is the symbol of the logical dimension i for the lhs
@@ -2889,6 +2903,44 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
return found->second;
}
+string HloInstruction::PrecisionConfigToString() const {
+ if (precision_config_.operand_precision().empty()) {
+ return "";
+ }
+ return StrCat(
+ "operand_precision={",
+ Join(precision_config_.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfigProto::Precision_IsValid(precision))
+ << precision;
+ StrAppend(
+ out,
+ PrecisionToString(
+ static_cast<PrecisionConfigProto::Precision>(precision)));
+ }),
+ "}");
+}
+
+StatusOr<PrecisionConfigProto::Precision> StringToPrecision(
+ const string& name) {
+ static std::unordered_map<string, PrecisionConfigProto::Precision>* map = [] {
+ static auto* map =
+ new std::unordered_map<string, PrecisionConfigProto::Precision>;
+ for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) {
+ if (PrecisionConfigProto::Precision_IsValid(i)) {
+ auto value = static_cast<PrecisionConfigProto::Precision>(i);
+ (*map)[PrecisionToString(value)] = value;
+ }
+ }
+ return map;
+ }();
+ auto found = map->find(tensorflow::str_util::Lowercase(name));
+ if (found == map->end()) {
+    return InvalidArgument("Unknown precision");
+ }
+ return found->second;
+}
+
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
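
PrecisionToString and StringToPrecision are inverses over valid enum values,
using the lowercase spelling that also appears in HLO text. A small sketch of
the expected round trip (assumed behavior, not a test from this patch):

    // "HIGHEST" is lowercased on the way out and looked up against the
    // lowercased name (via Lowercase) on the way back in.
    string s = xla::PrecisionToString(xla::PrecisionConfigProto::HIGHEST);
    CHECK_EQ(s, "highest");
    auto parsed = xla::StringToPrecision(s);
    CHECK(parsed.ok());
    CHECK_EQ(parsed.ValueOrDie(), xla::PrecisionConfigProto::HIGHEST);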
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 69397a4b37..21710bd31d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -102,6 +102,7 @@ class HloPrintOptions {
return HloPrintOptions()
.set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
.set_print_metadata(false)
+ .set_print_backend_config(false)
.set_compact_operands(true)
.set_print_operand_shape(true)
.set_print_program_shape(false)
@@ -183,7 +184,7 @@ class HloPrintOptions {
return print_subcomputation_mode_;
}
bool print_metadata() const { return print_metadata_; }
- bool print_backend_config() const { return print_metadata_; }
+ bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
@@ -858,6 +859,11 @@ class HloInstruction {
return false;
}
+ if (!ContainersEqual(precision_config_.operand_precision(),
+ other.precision_config_.operand_precision())) {
+ return false;
+ }
+
return IdenticalSlowPath(other, eq_computations);
}
@@ -1105,6 +1111,9 @@ class HloInstruction {
// Returns the dump string of the dot dimension numbers.
string DotDimensionNumbersToString() const;
+ // Returns the dump string of the precision configuration.
+ string PrecisionConfigToString() const;
+
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1248,6 +1257,20 @@ class HloInstruction {
static StatusOr<string> BackendConfigToRawString(
const tensorflow::protobuf::Message& proto);
+  // Returns the information used to tell the implementation what sort of
+  // precision is requested. The meaning of the field is backend specific. At
+  // the moment, it is only supported for kConvolution and kDot.
+  // Transformations of one kDot or kConvolution into another preserve this
+  // information. Transformations into other HLOs do not preserve it, on the
+  // presumption that the alternate lowering is strictly superior.
+ const PrecisionConfigProto& precision_config() const {
+ return precision_config_;
+ }
+ void set_precision_config(const PrecisionConfigProto& precision_config) {
+ precision_config_ = precision_config;
+ }
+
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }
@@ -1653,6 +1676,10 @@ class HloInstruction {
// HLO. See the documentation on backend_config().
string backend_config_;
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfigProto precision_config_;
+
// String identifier for instruction.
string name_;
@@ -1675,10 +1702,12 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
+string PrecisionToString(const PrecisionConfigProto::Precision& precision);
string ConvolutionDimensionNumbersToString(
const ConvolutionDimensionNumbers& dnums);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
+StatusOr<PrecisionConfigProto::Precision> StringToPrecision(const string& name);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index b4793998ec..ede55510d3 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -155,6 +155,7 @@ class HloParser {
kFusionKind,
kDistribution,
kDomain,
+ kPrecisionList,
};
struct AttrConfig {
@@ -220,6 +221,7 @@ class HloParser {
bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad);
bool ParseSliceRanges(SliceRanges* result);
+ bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
@@ -238,6 +240,7 @@ class HloParser {
bool ParseFftType(FftType* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
+ bool ParsePrecision(PrecisionConfigProto::Precision* result);
bool ParseInt64(tensorflow::int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
@@ -502,6 +505,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
+
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -1366,6 +1373,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
+ if (operand_precision) {
+ PrecisionConfigProto precision_config;
+ *precision_config.mutable_operand_precision() = {operand_precision->begin(),
+ operand_precision->end()};
+ instruction->set_precision_config(precision_config);
+ }
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
@@ -2343,6 +2356,16 @@ bool HloParser::ParseAttributeHelper(
case AttrTy::kDomain: {
return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
}
+ case AttrTy::kPrecisionList: {
+ std::vector<PrecisionConfigProto::Precision> result;
+ if (!ParsePrecisionList(&result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>(
+ attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
}
}();
if (!success) {
@@ -2615,6 +2638,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
}
+// precisionlist ::= start precision_elements end
+// precision_elements
+// ::= /*empty*/
+// ::= precision_val (delim precision_val)*
+bool HloParser::ParsePrecisionList(
+ std::vector<PrecisionConfigProto::Precision>* result) {
+ auto parse_and_add_item = [&]() {
+ PrecisionConfigProto::Precision item;
+ if (!ParsePrecision(&item)) {
+ return false;
+ }
+ result->push_back(item);
+ return true;
+ };
+ return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item);
+}
+
// int64list ::= start int64_elements end
// int64_elements
// ::= /*empty*/
@@ -2941,6 +2982,23 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
return true;
}
+bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) {
+ VLOG(1) << "ParsePrecision";
+ if (lexer_.GetKind() != TokKind::kIdent) {
+    return TokenError("expects precision");
+ }
+ string val = lexer_.GetStrVal();
+ auto status_or_result = StringToPrecision(val);
+ if (!status_or_result.ok()) {
+ return TokenError(
+ Printf("expects precision but sees: %s, error: %s", val.c_str(),
+ status_or_result.status().error_message().c_str()));
+ }
+ *result = status_or_result.ValueOrDie();
+ lexer_.Lex();
+ return true;
+}
+
bool HloParser::ParseInt64(tensorflow::int64* result) {
VLOG(1) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
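
With the grammar above, the attribute shows up in HLO text as a brace-wrapped,
comma-separated list, one element per operand. An illustrative module the
extended parser should accept (module, shapes, and names are invented for the
example):

    // operand_precision={high,highest} maps onto the instruction's
    // PrecisionConfigProto: lhs at HIGH, rhs at HIGHEST.
    const string hlo_text = R"(
    HloModule dot_precision_example

    ENTRY main {
      a = f32[2,3]{1,0} parameter(0)
      b = f32[3,4]{1,0} parameter(1)
      ROOT dot = f32[2,4]{1,0} dot(a, b), lhs_contracting_dims={1},
        rhs_contracting_dims={0}, operand_precision={high,highest}
    })";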
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 49e1f87319..530f40e4b2 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -109,6 +109,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
dot->shape(), new_lhs, new_rhs, new_dim_numbers);
+ new_dot->set_precision_config(dot->precision_config());
return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}
@@ -178,6 +179,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
auto new_conv = HloInstruction::CreateConvolve(
convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
+ new_conv->set_precision_config(convolution.precision_config());
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 27aa94c2cb..9451e0c315 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -569,3 +569,18 @@ message ReplicaGroup {
// ids matters in some op (e.g., all-to-all).
repeated int64 replica_ids = 1;
}
+
+// Used to indicate the precision configuration. It has backend-specific
+// meaning.
+message PrecisionConfigProto {
+ enum Precision {
+ DEFAULT = 0;
+ HIGH = 1;
+ HIGHEST = 2;
+
+ // Next: 3
+ }
+ repeated Precision operand_precision = 1;
+
+ // Next: 2
+}
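
A brief sketch of populating the new message from C++ (accessors follow the
standard generated-proto API for a repeated enum field):

    // One Precision entry per operand of the dot/convolution: here the lhs
    // keeps DEFAULT precision while the rhs is requested at HIGH.
    xla::PrecisionConfigProto config;
    config.add_operand_precision(xla::PrecisionConfigProto::DEFAULT);
    config.add_operand_precision(xla::PrecisionConfigProto::HIGH);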