aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Bjarke Hammersholt Roune <broune@google.com>2017-12-08 13:37:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-08 13:41:13 -0800
commit2f16f3afdcde16cf0de2f051c57b32cd61a12ec0 (patch)
tree016e5f89025746fed9d6643d9bfde209cc7ce4ee /tensorflow
parentdc04e89bc6f0421bf77ac69f21c1f2f57618f53c (diff)
Add bfloat16 support to the CPU backend.
* A few ops, in particular Convert, directly support bfloat16. * Added an HLO pass HloElementTypeConverter which converts graphs away from bfloat16 without changing the numerics, using Convert ops. This can be improved in many ways, but the feature here is that one can run XLA graphs that use bfloat16 on the CPU backend and get the correct result. PiperOrigin-RevId: 178419829
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/primitive_util.h7
-rw-r--r--tensorflow/compiler/xla/service/BUILD16
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc289
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.cc137
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.h49
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc12
-rw-r--r--tensorflow/compiler/xla/tests/BUILD2
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc16
12 files changed, 416 insertions, 119 deletions
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index 19c6a13888..cb4583d198 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -26,6 +26,13 @@ limitations under the License.
namespace xla {
namespace primitive_util {
+// The number of exponent bits in a BF16 value.
+const int kBFloat16ExponentBits = 8;
+
+// The number of mantissa bits in a BF16 value. There is an implicit leading
+// 1, so there is an implicit additional bit of precision.
+const int kBFloat16MantissaBits = 7;
+
// Returns the XLA primitive type (eg, F32) corresponding to the given
// template parameter native type (eg, float).
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 1023d3e5dc..baa4afde2d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1892,6 +1892,22 @@ tf_cc_test(
)
cc_library(
+ name = "hlo_element_type_converter",
+ srcs = ["hlo_element_type_converter.cc"],
+ hdrs = ["hlo_element_type_converter.h"],
+ deps = [
+ ":hlo",
+ ":hlo_evaluator",
+ ":hlo_pass",
+ ":hlo_query",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "device_memory_allocator",
srcs = ["device_memory_allocator.cc"],
hdrs = ["device_memory_allocator.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 32abb1b559..fe537dfdf2 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -110,6 +110,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_dce",
+ "//tensorflow/compiler/xla/service:hlo_element_type_converter",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 6c72ef6849..a476a75027 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -68,6 +68,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -318,6 +319,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
[](const Shape&, const Shape&) { return true; },
/*enable_dot_strength_reduction=*/false);
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+ pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 85d9668f89..dd027986b2 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -516,7 +516,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
HloComputation* function = reduce_window->to_apply();
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*reduce_window, /*operands=*/{operand},
- /*supported_types=*/{F32}));
+ /*supported_types=*/{F32, BF16}));
// TODO(b/31410564): Implement dilation for reduce-window.
if (window_util::HasDilation(window)) {
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index b9407818cd..7e88bbd631 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -50,11 +50,161 @@ using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
using tensorflow::strings::StrCat;
+namespace {
+
+llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
+ int64 mantissa_bits,
+ llvm::IRBuilder<>* ir_builder) {
+ // Integer and float types for casting and constant generation.
+ llvm::Type* float_type = x->getType();
+ llvm::IntegerType* int_type = ir_builder->getInt32Ty();
+
+ // Cast the input value to an integer for bitwise manipulation.
+ llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type);
+
+ if (mantissa_bits < 23) {
+ // Last remaining mantissa bit.
+ const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
+
+ // Compute rounding bias for round-to-nearest with ties to even. This is
+ // equal to a base value of 0111... plus one bit if the last remaining
+ // mantissa bit is 1.
+ const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
+ llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr(
+ ir_builder->CreateAnd(
+ x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
+ (23 - mantissa_bits));
+ llvm::Value* x_rounding_bias = ir_builder->CreateAdd(
+ x_last_mantissa_bit,
+ llvm::ConstantInt::get(int_type, base_rounding_bias));
+
+ // Add rounding bias, and mask out truncated bits. Note that the case
+ // where adding the rounding bias overflows into the exponent bits is
+ // correct; the non-masked mantissa bits will all be zero, and the
+ // exponent will be incremented by one.
+ const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
+ x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias);
+ x_as_int = ir_builder->CreateAnd(
+ x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
+ }
+
+ if (exponent_bits < 8) {
+ // Masks for f32 values.
+ const uint32_t f32_sign_bit_mask = 1u << 31;
+ const uint32_t f32_exp_bits_mask = 0xffu << 23;
+
+ // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
+ // significant bit -- is equal to 1.0f for all exponent sizes. Adding
+ // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
+ // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
+ // exponent (corresponding to 0.0f).
+ //
+ // Thus, the f32 exponent corresponding to the highest non-infinite
+ // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
+ // exponent corresponding to the lowest exponent for a bit size of n is
+ // (2^7-1) - 2^(n-1)-1.
+ //
+ // Note that we have already checked that exponents_bits >= 1.
+ const uint32_t f32_exponent_bias = (1 << 7) - 1;
+ const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
+ const uint32_t reduced_max_exponent =
+ f32_exponent_bias + reduced_exponent_bias;
+ const uint32_t reduced_min_exponent =
+ f32_exponent_bias - reduced_exponent_bias;
+
+ // Do we overflow or underflow?
+ llvm::Value* x_exponent = ir_builder->CreateAnd(
+ x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
+ llvm::Value* x_overflows = ir_builder->CreateICmpUGT(
+ x_exponent,
+ llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
+ llvm::Value* x_underflows = ir_builder->CreateICmpULE(
+ x_exponent,
+ llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
+
+ // Compute appropriately-signed values of zero and infinity.
+ llvm::Value* x_signed_zero = ir_builder->CreateAnd(
+ x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
+ llvm::Value* x_signed_inf = ir_builder->CreateOr(
+ x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
+
+ // Force to zero or infinity if overflow or underflow. (Note that this
+ // truncates all denormal values to zero, rather than rounding them.)
+ x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int);
+ x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int);
+ }
+
+ // Cast the result back to a floating-point type.
+ llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type);
+
+ // Correct result for NaN inputs.
+ //
+ // The exponent handling will "normalize" NaN values to infinities, which is
+ // undesirable (except in the case with no mantissa bits, in which case it
+ // is mandatory). This logic also handles cases where mantissa-rounding
+ // causes a NaN's mantissa to overflow into the exponent bits, which would
+ // otherwise create an erroneous zero value.
+ //
+ // If the fast-math flags are set to assume no NaNs, the comparison is likely
+ // to be optimized away, so there's no point in even emitting it.
+ if (!ir_builder->getFastMathFlags().noNaNs()) {
+ llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x);
+
+ if (mantissa_bits > 0) {
+ result = ir_builder->CreateSelect(x_is_nan, x, result);
+ } else {
+ result = ir_builder->CreateSelect(
+ x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
+ }
+ }
+ return result;
+}
+
+llvm::Value* EmitF32ToBF16(llvm::Value* f32_value,
+ llvm::IRBuilder<>* ir_builder) {
+ auto reduced_precision = EmitReducePrecisionFloat(
+ f32_value,
+ /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
+ /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder);
+ auto as_int32 =
+ ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty());
+ auto shifted = ir_builder->CreateLShr(as_int32, 16);
+ auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty());
+ return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty());
+}
+
+llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value,
+ llvm::IRBuilder<>* ir_builder) {
+ auto as_int16 =
+ ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty());
+ auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty());
+ auto shifted = ir_builder->CreateShl(as_int32, 16);
+ return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy());
+}
+
+llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
+ PrimitiveType from_type,
+ PrimitiveType to_type, llvm::Module* module,
+ llvm::IRBuilder<>* ir_builder) {
+ if (primitive_util::IsSignedIntegralType(from_type)) {
+ return ir_builder->CreateSIToFP(
+ integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
+ } else {
+ CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
+ from_type == PRED);
+ return ir_builder->CreateUIToFP(
+ integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
+ }
+}
+
+} // namespace
+
StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
if (op->opcode() == HloOpcode::kCopy) {
return operand_value;
- } else if (operand_value->getType()->isIntegerTy()) {
+ } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
+ op->operand(0)->shape().element_type() == PRED) {
return EmitIntegerUnaryOp(op, operand_value);
} else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
return EmitComplexUnaryOp(op, operand_value);
@@ -79,15 +229,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
primitive_util::IsSignedIntegralType(to_type));
}
if (primitive_util::IsFloatingPointType(to_type)) {
- if (primitive_util::IsSignedIntegralType(from_type)) {
- return ir_builder_->CreateSIToFP(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
- }
- if (primitive_util::IsUnsignedIntegralType(from_type) ||
- from_type == PRED) {
- return ir_builder_->CreateUIToFP(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ if (to_type == BF16) {
+ return EmitF32ToBF16(
+ EmitIntegralToFloating(operand_value, from_type, F32, module_,
+ ir_builder_),
+ ir_builder_);
}
+ return EmitIntegralToFloating(operand_value, from_type, to_type,
+ module_, ir_builder_);
}
if (primitive_util::IsComplexType(to_type)) {
auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
@@ -207,6 +356,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
nullptr);
}
+ if (from_type == BF16) {
+ TF_RET_CHECK(to_type != BF16);
+ operand_value = EmitBF16ToF32(operand_value, ir_builder_);
+ from_type = F32;
+ if (from_type == to_type) {
+ return operand_value;
+ }
+ }
+ if (from_type == F32 && to_type == BF16) {
+ return EmitF32ToBF16(operand_value, ir_builder_);
+ }
if (primitive_util::IsFloatingPointType(to_type)) {
return ir_builder_->CreateFPCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
@@ -449,7 +609,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
PrimitiveType operand_type = op->operand(0)->shape().element_type();
- if (lhs_value->getType()->isIntegerTy()) {
+ if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
+ operand_type == PRED) {
return EmitIntegerBinaryOp(
op, lhs_value, rhs_value,
primitive_util::IsSignedIntegralType(operand_type));
@@ -717,111 +878,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
if (hlo->operand(0)->shape().element_type() != F32) {
return Unimplemented("reduce-precision only implemented for F32");
}
-
- // Integer and float types for casting and constant generation.
- llvm::Type* float_type = x->getType();
- llvm::IntegerType* int_type = ir_builder_->getInt32Ty();
-
- // Cast the input value to an integer for bitwise manipulation.
- llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type);
-
- if (hlo->mantissa_bits() < 23) {
- // Last remaining mantissa bit.
- const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits());
-
- // Compute rounding bias for round-to-nearest with ties to even. This is
- // equal to a base value of 0111... plus one bit if the last remaining
- // mantissa bit is 1.
- const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
- llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr(
- ir_builder_->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
- (23 - hlo->mantissa_bits()));
- llvm::Value* x_rounding_bias = ir_builder_->CreateAdd(
- x_last_mantissa_bit,
- llvm::ConstantInt::get(int_type, base_rounding_bias));
-
- // Add rounding bias, and mask out truncated bits. Note that the case
- // where adding the rounding bias overflows into the exponent bits is
- // correct; the non-masked mantissa bits will all be zero, and the
- // exponent will be incremented by one.
- const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
- x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias);
- x_as_int = ir_builder_->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
- }
-
- if (hlo->exponent_bits() < 8) {
- // Masks for f32 values.
- const uint32_t f32_sign_bit_mask = 1u << 31;
- const uint32_t f32_exp_bits_mask = 0xffu << 23;
-
- // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
- // significant bit -- is equal to 1.0f for all exponent sizes. Adding
- // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
- // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
- // exponent (corresponding to 0.0f).
- //
- // Thus, the f32 exponent corresponding to the highest non-infinite
- // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
- // exponent corresponding to the lowest exponent for a bit size of n is
- // (2^7-1) - 2^(n-1)-1.
- //
- // Note that we have already checked that exponents_bits >= 1.
- const uint32_t f32_exponent_bias = (1 << 7) - 1;
- const uint32_t reduced_exponent_bias =
- (1 << (hlo->exponent_bits() - 1)) - 1;
- const uint32_t reduced_max_exponent =
- f32_exponent_bias + reduced_exponent_bias;
- const uint32_t reduced_min_exponent =
- f32_exponent_bias - reduced_exponent_bias;
-
- // Do we overflow or underflow?
- llvm::Value* x_exponent = ir_builder_->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
- llvm::Value* x_overflows = ir_builder_->CreateICmpUGT(
- x_exponent,
- llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
- llvm::Value* x_underflows = ir_builder_->CreateICmpULE(
- x_exponent,
- llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
-
- // Compute appropriately-signed values of zero and infinity.
- llvm::Value* x_signed_zero = ir_builder_->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
- llvm::Value* x_signed_inf = ir_builder_->CreateOr(
- x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
-
- // Force to zero or infinity if overflow or underflow. (Note that this
- // truncates all denormal values to zero, rather than rounding them.)
- x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int);
- x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int);
- }
-
- // Cast the result back to a floating-point type.
- llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type);
-
- // Correct result for NaN inputs.
- //
- // The exponent handling will "normalize" NaN values to infinities, which is
- // undesirable (except in the case with no mantissa bits, in which case it
- // is mandatory). This logic also handles cases where mantissa-rounding
- // causes a NaN's mantissa to overflow into the exponent bits, which would
- // otherwise create an erroneous zero value.
- //
- // If the fast-math flags are set to assume no NaNs, the comparison is likely
- // to be optimized away, so there's no point in even emitting it.
- if (!ir_builder_->getFastMathFlags().noNaNs()) {
- llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x);
-
- if (hlo->mantissa_bits() > 0) {
- result = ir_builder_->CreateSelect(x_is_nan, x, result);
- } else {
- result = ir_builder_->CreateSelect(
- x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
- }
- }
- return result;
+ return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
+ /*mantissa_bits=*/hlo->mantissa_bits(),
+ ir_builder_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
new file mode 100644
index 0000000000..1773bb401d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -0,0 +1,137 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace {
+
+HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
+ if (hlo->shape().element_type() != type) {
+ Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
+ hlo = hlo->parent()->AddInstruction(
+ HloInstruction::CreateConvert(shape, hlo));
+ }
+ CHECK_EQ(hlo->shape().element_type(), type);
+ return hlo;
+}
+
+bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
+ for (HloInstruction* operand : hlo->operands()) {
+ if (operand->shape().element_type() == type) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+HloElementTypeConverter::HloElementTypeConverter(
+ PrimitiveType eliminate_type, PrimitiveType replace_with_type)
+ : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
+
+StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
+ XLA_VLOG_LINES(
+ 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ for (auto* computation : module->computations()) {
+ for (auto* hlo : computation->MakeInstructionPostOrder()) {
+ // These are ops where it does not make sense to convert them.
+ if (hlo->opcode() == HloOpcode::kParameter ||
+ hlo->opcode() == HloOpcode::kConstant ||
+ hlo->opcode() == HloOpcode::kTuple ||
+ hlo->opcode() == HloOpcode::kConvert ||
+ hlo->opcode() == HloOpcode::kGetTupleElement ||
+ hlo->opcode() == HloOpcode::kInfeed ||
+ hlo->opcode() == HloOpcode::kOutfeed) {
+ continue;
+ }
+
+ // We cannot change a CustomCall since we have no way of adjusting the
+ // called binary to expect the updated type.
+ if (hlo->opcode() == HloOpcode::kCustomCall) {
+ continue;
+ }
+
+ // These are ops with embedded computations where it suffices to convert
+ // the embedded computations instead of converting the ops themselves.
+ if (hlo->opcode() == HloOpcode::kWhile ||
+ hlo->opcode() == HloOpcode::kCall ||
+ hlo->opcode() == HloOpcode::kFusion ||
+ hlo->opcode() == HloOpcode::kMap ||
+ hlo->opcode() == HloOpcode::kReduce ||
+ hlo->opcode() == HloOpcode::kReduceWindow ||
+ hlo->opcode() == HloOpcode::kSelectAndScatter ||
+ hlo->opcode() == HloOpcode::kConditional) {
+ continue;
+ }
+ TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
+
+ if (!HasOperandType(hlo, eliminate_type_)) {
+ // If this CHECK fires, then this was an instruction that does not take
+ // the elimination type as an operand but it does return it. This pass
+ // does not have a feature to change the output type in that case, so
+ // instead of silently failing to eliminate the type, it fails loudly.
+ TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
+ continue;
+ }
+
+ std::vector<HloInstruction*> new_operands;
+ for (HloInstruction* operand : hlo->operands()) {
+ if (operand->shape().element_type() == eliminate_type_) {
+ operand = ToElementType(operand, replace_with_type_);
+ }
+ new_operands.push_back(operand);
+ }
+
+ HloInstruction* new_hlo;
+ if (hlo->shape().element_type() == eliminate_type_) {
+ Shape shape =
+ ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
+ new_hlo = computation->AddInstruction(
+ hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule()));
+ new_hlo = ToElementType(new_hlo, eliminate_type_);
+ } else {
+ new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
+ hlo->shape(), new_operands, hlo->GetModule()));
+ }
+ TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo));
+ changed = true;
+ }
+ }
+ XLA_VLOG_LINES(
+ 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
new file mode 100644
index 0000000000..2b109225d0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// A pass that eliminates certain element types as the input or output of ops by
+// inserting Convert ops. This allows a backend to support an element type while
+// only actually implementing the Convert op for that element type. This is
+// generally not the fastest approach, but it works.
+class HloElementTypeConverter : public HloPassInterface {
+ public:
+ // eliminate_type is the type to eliminate as the input or output of ops,
+ // using Convert ops to replace it with replace_with_type.
+ HloElementTypeConverter(PrimitiveType eliminate_type,
+ PrimitiveType replace_with_type);
+
+ tensorflow::StringPiece name() const override {
+ return "element_type_converter";
+ }
+
+ // Returns the pass on the module and returns whether the module was modified.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ PrimitiveType eliminate_type_;
+ PrimitiveType replace_with_type_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index ef5b6ad90e..9a0c94b1c7 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -142,6 +142,13 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
return llvm::Type::getInt8Ty(module->getContext());
case S16:
case U16:
+ case BF16:
+ // For BF16 we just need some type that is 16 bits wide so that it will
+ // take up the right amount of space in memory. LLVM does not have a BF16
+ // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
+ // we can't map it directly to an LLVM type. We will not map a BF16
+ // addition to an addition on this type (int16) - this is just the type
+ // used for storage.
return llvm::Type::getInt16Ty(module->getContext());
case S32:
case U32:
@@ -280,6 +287,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
value = llvm::ConstantFP::get(ir_element_type,
literal.Get<float>(*multi_index));
break;
+ case BF16:
+ value = llvm::ConstantInt::get(
+ ir_element_type,
+ tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index)));
+ break;
case F64:
value = llvm::ConstantFP::get(ir_element_type,
literal.Get<double>(*multi_index));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 6f03f1a4e0..6af01ae80d 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -802,8 +802,6 @@ xla_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.cc"],
blacklisted_backends = [
- "cpu",
- "cpu_parallel",
"gpu",
],
shard_count = 40,
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 330575a02e..b32df74312 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -53,7 +53,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase {
public:
ErrorSpec DefaultErrorSpec() const {
if (use_bfloat16()) {
- return ErrorSpec(1e-1, 3e-2);
+ return ErrorSpec(1e-1, 5e-2);
} else {
return ErrorSpec(1e-3, 1e-3);
}
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 93bce97a3e..780b292d1a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -35,6 +35,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) {
}));
}
+// The standard library does not have a case for bfloat16, unsurprisingly, so we
+// handle that one specially.
+template <>
+void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
+ CHECK_EQ(literal->shape().element_type(), BF16);
+ std::minstd_rand0 engine;
+ std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+ TF_CHECK_OK(literal->Populate<bfloat16>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return static_cast<bfloat16>(generator(engine));
+ }));
+}
+
template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal) {
CHECK_EQ(literal->shape().element_type(),
@@ -171,6 +184,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
switch (shape.element_type()) {
+ case BF16:
+ PopulateWithRandomFloatingPointData<bfloat16>(literal.get());
+ break;
case F32:
PopulateWithRandomFloatingPointData<float>(literal.get());
break;