author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-07-25 00:05:41 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-07-25 00:09:39 -0700
commit    d1a9ea61ef8271b3d2fe273a68ff5940fcba7ccd (patch)
tree      fb4527dc7de04f4a2a1a0bc101334c31d2cbef38 /tensorflow/compiler/xla/service
parent    73b120ea3b517b6af2267ca078bf571f966fd606 (diff)
[XLA] Teach CPU and GPU compilers to optionally invoke the HLO insert-reduce-precision-operations pass.
This also required a few additions and fixups. We add pieces to
ReducePrecisionInsertion to translate between the protocol-buffer
representation of the pass options and the predicate function actually used
in the pass. To facilitate this translation, we also add a function to
HloOpcode that returns the number of opcodes, so that we can easily iterate
over the whole set.

PiperOrigin-RevId: 163037250
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                           |  1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD                       |  1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc             | 18
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                       |  1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc             | 31
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h                    |  5
-rw-r--r--  tensorflow/compiler/xla/service/reduce_precision_insertion.cc   | 39
-rw-r--r--  tensorflow/compiler/xla/service/reduce_precision_insertion.h    | 22
8 files changed, 112 insertions, 6 deletions
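
For reference, a minimal sketch (not part of this commit) of how a client
might populate these options so that the compilers pick the pass up. It uses
the field and enum names visible in the diff below and assumes the standard
protobuf adder for the repeated hlo_reduce_precision_options field; the
driver function and the bfloat16-like parameter choice are hypothetical.

#include <cstdint>

#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla.pb.h"

xla::DebugOptions MakeBf16LikeDebugOptions() {
  xla::DebugOptions debug_options;
  auto* options = debug_options.add_hlo_reduce_precision_options();
  // Run the pass once fusion has already happened (see the compiler diffs).
  options->set_pass_timing(xla::HloReducePrecisionOptions::AFTER_OP_FUSION);
  options->set_exponent_bits(8);  // bfloat16-like: 8 exponent bits...
  options->set_mantissa_bits(7);  // ...and 7 mantissa bits.
  // Opcodes whose F32 outputs should be followed by a kReducePrecision op.
  options->add_opcodes_to_suffix(static_cast<uint32_t>(xla::HloOpcode::kAdd));
  options->add_opcodes_to_suffix(
      static_cast<uint32_t>(xla::HloOpcode::kMultiply));
  return debug_options;
}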
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 696dc28564..a4612bb6c1 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1945,6 +1945,7 @@ cc_library(
":buffer_liveness",
":hlo",
":hlo_pass",
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 7248cb5f4c..2ca4af67cd 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -72,6 +72,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:inliner",
+ "//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 6d819355c4..b86342d0b3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -74,6 +74,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -253,6 +254,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
HloPassPipeline pipeline("CPU");
pipeline.AddInvariantChecker<HloVerifier>();
+ for (const auto& reduce_precision_options :
+ module->config().debug_options().hlo_reduce_precision_options()) {
+ if (reduce_precision_options.pass_timing() ==
+ HloReducePrecisionOptions::BEFORE_OP_FUSION) {
+ pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
+ }
+ }
+
// TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding
// where we will take this pass in future.
// pipeline.AddPass<Inliner>();
@@ -278,6 +287,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
TransposeFolding::NeverFoldTranspose);
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
+
+ for (const auto& reduce_precision_options :
+ module->config().debug_options().hlo_reduce_precision_options()) {
+ if (reduce_precision_options.pass_timing() ==
+ HloReducePrecisionOptions::AFTER_OP_FUSION) {
+ pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
+ }
+ }
+
pipeline.AddPass<CpuLayoutAssignment>(
module->mutable_entry_computation_layout());
// The LayoutAssignment pass may leave behind kCopy instructions which are
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index fa95e23499..cdd7c8187c 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -432,6 +432,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto_util",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index d60c45a5c3..2acf95084a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -56,6 +56,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -123,6 +124,15 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>();
+
+ for (const auto& reduce_precision_options :
+ hlo_module->config().debug_options().hlo_reduce_precision_options()) {
+ if (reduce_precision_options.pass_timing() ==
+ HloReducePrecisionOptions::BEFORE_OP_FUSION) {
+ pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
+ }
+ }
+
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
@@ -149,8 +159,27 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
- return fusion.Run(hlo_module).status();
+ TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
+
+ HloPassPipeline reduce_pipeline("reduce-precision");
+ for (const auto& reduce_precision_options :
+ hlo_module->config().debug_options().hlo_reduce_precision_options()) {
+ if (reduce_precision_options.pass_timing() ==
+ HloReducePrecisionOptions::AFTER_OP_FUSION) {
+ reduce_pipeline.AddPass<ReducePrecisionInsertion>(
+ reduce_precision_options);
+ }
+ }
+ StatusOr<bool> reduce_result = reduce_pipeline.Run(hlo_module);
+ TF_RETURN_IF_ERROR(reduce_result.status());
+
+ if (reduce_result.ValueOrDie()) {
+ // Do another fusion pass, with the expectation that we may be able to
+ // fuse the new ReducePrecision operations.
+ TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
+ }
}
+ return tensorflow::Status::OK();
}
// Modifies the given HLO module so that it will be accepted by IrEmitter.
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 358e611d57..8a6376b2d1 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -112,6 +112,11 @@ bool HloOpcodeIsComparison(HloOpcode opcode);
// Returns true iff the given opcode has variadic operands.
bool HloOpcodeIsVariadic(HloOpcode opcode);
+// Returns the number of HloOpcode values.
+inline const uint32_t HloOpcodeCount() {
+ return static_cast<uint32_t>(HloOpcode::kWhile) + 1;
+}
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
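
The new HloOpcodeCount() exists so callers can build dense, opcode-indexed
lookup tables, as make_filter_function does in the pass below. A minimal
sketch of that iteration pattern (hypothetical driver code; like the
function above, it assumes kWhile is the last enumerator, and it uses
HloOpcodeString from this header):

#include <cstdint>
#include <iostream>

#include "tensorflow/compiler/xla/service/hlo_opcode.h"

void PrintAllOpcodes() {
  for (uint32_t i = 0; i < xla::HloOpcodeCount(); ++i) {
    const auto opcode = static_cast<xla::HloOpcode>(i);
    std::cout << i << ": " << xla::HloOpcodeString(opcode) << "\n";
  }
}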
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
index dafefdc491..e083226b14 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -30,14 +31,15 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
for (auto& instruction : computation->instructions()) {
VLOG(3) << "Visited instruction: " << instruction->ToString();
- // For now, ReducePrecision is only implemented for F32 data, so this
+ // For now, ReducePrecision is only implemented for F32 arrays, so this
// ignores instructions that produce other data. In particular, this
// currently ignores instructions producing tuples, even if those tuples
- // contain F32 data inside them. The assumption is that in most cases
+ // contain F32 arrays inside them. The assumption is that in most cases
// equivalent behavior can be obtained by adding ReducePrecision
- // instructions after the instructions that pull the F32 data out of the
- // tuples.
+ // instructions after the instructions that pull the F32 arrays out of
+ // the tuples.
if (instruction->shape().element_type() == PrimitiveType::F32 &&
+ !ShapeUtil::IsScalar(instruction->shape()) &&
should_reduce_output_precision_(instruction->opcode())) {
instructions_to_suffix.push_back(instruction.get());
}
@@ -58,4 +60,33 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
return changed;
}
+ReducePrecisionInsertion::OpcodeFilterFunction
+ReducePrecisionInsertion::make_filter_function(
+ const HloReducePrecisionOptions& reduce_precision_options) {
+ // Implement the filter function with a lookup table.
+ std::vector<bool> filter(HloOpcodeCount(), false);
+ for (const auto& opcode : reduce_precision_options.opcodes_to_suffix()) {
+ filter[opcode] = true;
+ }
+ return [filter](const HloOpcode opcode) {
+ return filter[static_cast<unsigned int>(opcode)];
+ };
+}
+
+HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto(
+ const HloReducePrecisionOptions::PassTiming pass_timing,
+ const int exponent_bits, const int mantissa_bits,
+ const OpcodeFilterFunction& should_reduce_output_precision) {
+ HloReducePrecisionOptions options;
+ options.set_pass_timing(pass_timing);
+ options.set_exponent_bits(exponent_bits);
+ options.set_mantissa_bits(mantissa_bits);
+ for (uint32_t opcode = 0; opcode < HloOpcodeCount(); opcode++) {
+ if (should_reduce_output_precision(static_cast<HloOpcode>(opcode))) {
+ options.add_opcodes_to_suffix(opcode);
+ }
+ }
+ return options;
+}
+
} // namespace xla
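
A sketch of how the two new static helpers round-trip between a predicate
and its proto form (hypothetical usage; in practice the compilers construct
the pass directly from DebugOptions, as in the compiler diffs above):

#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"

void RoundTripExample() {
  // Start from an arbitrary predicate over opcodes...
  auto is_add = [](const xla::HloOpcode opcode) {
    return opcode == xla::HloOpcode::kAdd;
  };
  // ...encode it, together with the numeric parameters, as a proto...
  xla::HloReducePrecisionOptions options =
      xla::ReducePrecisionInsertion::make_options_proto(
          xla::HloReducePrecisionOptions::AFTER_OP_FUSION,
          /*exponent_bits=*/5, /*mantissa_bits=*/10, is_add);
  // ...then recover an equivalent lookup-table-backed predicate.
  auto filter = xla::ReducePrecisionInsertion::make_filter_function(options);
  // filter(xla::HloOpcode::kAdd) is true; all other opcodes map to false.
}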
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index e9c8bba031..34b865b9ce 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -42,6 +42,17 @@ class ReducePrecisionInsertion : public HloPassInterface {
: exponent_bits_(exponent_bits),
mantissa_bits_(mantissa_bits),
should_reduce_output_precision_(should_reduce_output_precision) {}
+
+ // Version of the constructor that takes an HloReducePrecisionOptions proto
+ // rather than explicitly-enumerated parameters, for convenience when
+ // creating passes based on DebugOptions.
+ explicit ReducePrecisionInsertion(
+ const HloReducePrecisionOptions& reduce_precision_options)
+ : exponent_bits_(reduce_precision_options.exponent_bits()),
+ mantissa_bits_(reduce_precision_options.mantissa_bits()),
+ should_reduce_output_precision_(
+ make_filter_function(reduce_precision_options)) {}
+
~ReducePrecisionInsertion() override{};
tensorflow::StringPiece name() const override {
@@ -52,6 +63,15 @@ class ReducePrecisionInsertion : public HloPassInterface {
// (reduce-precision instructions were inserted).
StatusOr<bool> Run(HloModule* module) override;
+ // Convert between the (inconvenient) xla.proto HloReducePrecisionOptions
+ // representation and OpcodeFilterFunction functions.
+ static OpcodeFilterFunction make_filter_function(
+ const HloReducePrecisionOptions& reduce_precision_options);
+ static HloReducePrecisionOptions make_options_proto(
+ const HloReducePrecisionOptions::PassTiming pass_timing,
+ const int exponent_bits, const int mantissa_bits,
+ const OpcodeFilterFunction& should_reduce_output_precision);
+
private:
// Parameters for the precision reduction to be added.
const int exponent_bits_;
@@ -59,7 +79,7 @@ class ReducePrecisionInsertion : public HloPassInterface {
// Function to determine (from the opcode) whether a given instruction should
// have a reduce-precision instruction inserted in its output stream.
- const OpcodeFilterFunction& should_reduce_output_precision_;
+ const OpcodeFilterFunction should_reduce_output_precision_;
};
} // namespace xla