 tensorflow/compiler/xla/primitive_util.h                      |   7
 tensorflow/compiler/xla/service/BUILD                         |  16
 tensorflow/compiler/xla/service/cpu/BUILD                     |   1
 tensorflow/compiler/xla/service/cpu/cpu_compiler.cc           |   2
 tensorflow/compiler/xla/service/cpu/ir_emitter.cc             |   2
 tensorflow/compiler/xla/service/elemental_ir_emitter.cc       | 289
 tensorflow/compiler/xla/service/hlo_element_type_converter.cc | 137
 tensorflow/compiler/xla/service/hlo_element_type_converter.h  |  49
 tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc          |  12
 tensorflow/compiler/xla/tests/BUILD                           |   2
 tensorflow/compiler/xla/tests/reduce_window_test.cc           |   2
 tensorflow/compiler/xla/tests/test_utils.cc                   |  16
 12 files changed, 416 insertions(+), 119 deletions(-)
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index 19c6a13888..cb4583d198 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -26,6 +26,13 @@ limitations under the License.
namespace xla {
namespace primitive_util {
+// The number of exponent bits in a BF16 value.
+const int kBFloat16ExponentBits = 8;
+
+// The number of mantissa bits in a BF16 value. There is an implicit leading
+// 1, so there is an implicit additional bit of precision.
+const int kBFloat16MantissaBits = 7;
+
// Returns the XLA primitive type (eg, F32) corresponding to the given
// template parameter native type (eg, float).
template <typename NativeT>
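As a quick sanity check on these constants: 1 sign bit + 8 exponent bits + 7 stored mantissa bits = 16 bits, and the implicit leading 1 gives 8 bits of effective precision. Since BF16 keeps the full 8-bit F32 exponent, a BF16 value is simply the high 16 bits of the corresponding F32 bit pattern. A minimal compile-time check, illustrative only and not part of the patch:

  #include "tensorflow/compiler/xla/primitive_util.h"

  static_assert(1 + xla::primitive_util::kBFloat16ExponentBits +
                        xla::primitive_util::kBFloat16MantissaBits ==
                    16,
                "BF16 is 16 bits wide");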
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 1023d3e5dc..baa4afde2d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1892,6 +1892,22 @@ tf_cc_test(
)
cc_library(
+ name = "hlo_element_type_converter",
+ srcs = ["hlo_element_type_converter.cc"],
+ hdrs = ["hlo_element_type_converter.h"],
+ deps = [
+ ":hlo",
+ ":hlo_evaluator",
+ ":hlo_pass",
+ ":hlo_query",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "device_memory_allocator",
srcs = ["device_memory_allocator.cc"],
hdrs = ["device_memory_allocator.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 32abb1b559..fe537dfdf2 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -110,6 +110,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_dce",
+ "//tensorflow/compiler/xla/service:hlo_element_type_converter",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 6c72ef6849..a476a75027 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -68,6 +68,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -318,6 +319,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
[](const Shape&, const Shape&) { return true; },
/*enable_dot_strength_reduction=*/false);
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+ pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 85d9668f89..dd027986b2 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -516,7 +516,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
HloComputation* function = reduce_window->to_apply();
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*reduce_window, /*operands=*/{operand},
- /*supported_types=*/{F32}));
+ /*supported_types=*/{F32, BF16}));
// TODO(b/31410564): Implement dilation for reduce-window.
if (window_util::HasDilation(window)) {
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index b9407818cd..7e88bbd631 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -50,11 +50,161 @@ using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
using tensorflow::strings::StrCat;
+namespace {
+
+llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
+ int64 mantissa_bits,
+ llvm::IRBuilder<>* ir_builder) {
+ // Integer and float types for casting and constant generation.
+ llvm::Type* float_type = x->getType();
+ llvm::IntegerType* int_type = ir_builder->getInt32Ty();
+
+ // Cast the input value to an integer for bitwise manipulation.
+ llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type);
+
+ if (mantissa_bits < 23) {
+ // Last remaining mantissa bit.
+ const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
+
+ // Compute rounding bias for round-to-nearest with ties to even. This is
+ // equal to a base value of 0111... plus one bit if the last remaining
+ // mantissa bit is 1.
+ const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
+ llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr(
+ ir_builder->CreateAnd(
+ x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
+ (23 - mantissa_bits));
+ llvm::Value* x_rounding_bias = ir_builder->CreateAdd(
+ x_last_mantissa_bit,
+ llvm::ConstantInt::get(int_type, base_rounding_bias));
+
+ // Add rounding bias, and mask out truncated bits. Note that the case
+ // where adding the rounding bias overflows into the exponent bits is
+ // correct; the non-masked mantissa bits will all be zero, and the
+ // exponent will be incremented by one.
+ const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
+ x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias);
+ x_as_int = ir_builder->CreateAnd(
+ x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
+ }
+
+ if (exponent_bits < 8) {
+ // Masks for f32 values.
+ const uint32_t f32_sign_bit_mask = 1u << 31;
+ const uint32_t f32_exp_bits_mask = 0xffu << 23;
+
+ // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
+ // significant bit -- is equal to 1.0f for all exponent sizes. Adding
+ // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
+ // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
+ // exponent (corresponding to 0.0f).
+ //
+ // Thus, the f32 exponent corresponding to the highest non-infinite
+ // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
+ // exponent corresponding to the lowest exponent for a bit size of n is
+ // (2^7-1) - 2^(n-1)-1.
+ //
+ // Note that we have already checked that exponent_bits >= 1.
+ const uint32_t f32_exponent_bias = (1 << 7) - 1;
+ const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
+ const uint32_t reduced_max_exponent =
+ f32_exponent_bias + reduced_exponent_bias;
+ const uint32_t reduced_min_exponent =
+ f32_exponent_bias - reduced_exponent_bias;
+
+ // Do we overflow or underflow?
+ llvm::Value* x_exponent = ir_builder->CreateAnd(
+ x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
+ llvm::Value* x_overflows = ir_builder->CreateICmpUGT(
+ x_exponent,
+ llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
+ llvm::Value* x_underflows = ir_builder->CreateICmpULE(
+ x_exponent,
+ llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
+
+ // Compute appropriately-signed values of zero and infinity.
+ llvm::Value* x_signed_zero = ir_builder->CreateAnd(
+ x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
+ llvm::Value* x_signed_inf = ir_builder->CreateOr(
+ x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
+
+ // Force to zero or infinity if overflow or underflow. (Note that this
+ // truncates all denormal values to zero, rather than rounding them.)
+ x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int);
+ x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int);
+ }
+
+ // Cast the result back to a floating-point type.
+ llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type);
+
+ // Correct result for NaN inputs.
+ //
+ // The exponent handling will "normalize" NaN values to infinities, which is
+ // undesirable (except in the case with no mantissa bits, in which case it
+ // is mandatory). This logic also handles cases where mantissa-rounding
+ // causes a NaN's mantissa to overflow into the exponent bits, which would
+ // otherwise create an erroneous zero value.
+ //
+ // If the fast-math flags are set to assume no NaNs, the comparison is likely
+ // to be optimized away, so there's no point in even emitting it.
+ if (!ir_builder->getFastMathFlags().noNaNs()) {
+ llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x);
+
+ if (mantissa_bits > 0) {
+ result = ir_builder->CreateSelect(x_is_nan, x, result);
+ } else {
+ result = ir_builder->CreateSelect(
+ x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
+ }
+ }
+ return result;
+}
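To make the bit arithmetic above concrete, these are the values the helper computes for the BF16 mantissa case (mantissa_bits = 7), plus the exponent-range values for a hypothetical 5-bit exponent (BF16 keeps all 8 exponent bits, so its conversion never takes the exponent branch). A sketch under those assumptions, not part of the patch:

  #include <cstdint>

  // mantissa_bits = 7: round at bit 16 of the F32 representation.
  constexpr uint32_t kLastMantissaBitMask = 1u << (23 - 7);               // 0x00010000
  constexpr uint32_t kBaseRoundingBias = (kLastMantissaBitMask >> 1) - 1; // 0x00007FFF
  constexpr uint32_t kTruncationMask = ~(kLastMantissaBitMask - 1);       // 0xFFFF0000
  // Hypothetical exponent_bits = 5: reduced bias is 15, so the surviving F32
  // exponent field spans [127 - 15, 127 + 15] = [112, 142].
  constexpr uint32_t kReducedMaxExponent = 127 + ((1 << (5 - 1)) - 1);    // 142
  constexpr uint32_t kReducedMinExponent = 127 - ((1 << (5 - 1)) - 1);    // 112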
+
+llvm::Value* EmitF32ToBF16(llvm::Value* f32_value,
+ llvm::IRBuilder<>* ir_builder) {
+ auto reduced_precision = EmitReducePrecisionFloat(
+ f32_value,
+ /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
+ /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder);
+ auto as_int32 =
+ ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty());
+ auto shifted = ir_builder->CreateLShr(as_int32, 16);
+ auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty());
+ return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty());
+}
+
+llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value,
+ llvm::IRBuilder<>* ir_builder) {
+ auto as_int16 =
+ ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty());
+ auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty());
+ auto shifted = ir_builder->CreateShl(as_int32, 16);
+ return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy());
+}
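A host-side sketch of what the IR emitted by these two helpers computes (assuming the rounding step above has already been applied on the F32 side; the names are illustrative, not from the patch):

  #include <cstdint>
  #include <cstring>

  // F32 -> BF16: keep only the high 16 bits of the (already rounded) F32 bits.
  uint16_t F32ToBF16Bits(float rounded_f32) {
    uint32_t f32_bits;
    std::memcpy(&f32_bits, &rounded_f32, sizeof(f32_bits));
    return static_cast<uint16_t>(f32_bits >> 16);
  }

  // BF16 -> F32: zero-extend the 16 stored bits into the high half of an F32.
  float BF16BitsToF32(uint16_t bf16_bits) {
    uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
    float result;
    std::memcpy(&result, &f32_bits, sizeof(result));
    return result;
  }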
+
+llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
+ PrimitiveType from_type,
+ PrimitiveType to_type, llvm::Module* module,
+ llvm::IRBuilder<>* ir_builder) {
+ if (primitive_util::IsSignedIntegralType(from_type)) {
+ return ir_builder->CreateSIToFP(
+ integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
+ } else {
+ CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
+ from_type == PRED);
+ return ir_builder->CreateUIToFP(
+ integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
+ }
+}
+
+} // namespace
+
StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
if (op->opcode() == HloOpcode::kCopy) {
return operand_value;
- } else if (operand_value->getType()->isIntegerTy()) {
+ } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
+ op->operand(0)->shape().element_type() == PRED) {
return EmitIntegerUnaryOp(op, operand_value);
} else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
return EmitComplexUnaryOp(op, operand_value);
@@ -79,15 +229,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
primitive_util::IsSignedIntegralType(to_type));
}
if (primitive_util::IsFloatingPointType(to_type)) {
- if (primitive_util::IsSignedIntegralType(from_type)) {
- return ir_builder_->CreateSIToFP(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
- }
- if (primitive_util::IsUnsignedIntegralType(from_type) ||
- from_type == PRED) {
- return ir_builder_->CreateUIToFP(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ if (to_type == BF16) {
+ return EmitF32ToBF16(
+ EmitIntegralToFloating(operand_value, from_type, F32, module_,
+ ir_builder_),
+ ir_builder_);
}
+ return EmitIntegralToFloating(operand_value, from_type, to_type,
+ module_, ir_builder_);
}
if (primitive_util::IsComplexType(to_type)) {
auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
@@ -207,6 +356,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
nullptr);
}
+ if (from_type == BF16) {
+ TF_RET_CHECK(to_type != BF16);
+ operand_value = EmitBF16ToF32(operand_value, ir_builder_);
+ from_type = F32;
+ if (from_type == to_type) {
+ return operand_value;
+ }
+ }
+ if (from_type == F32 && to_type == BF16) {
+ return EmitF32ToBF16(operand_value, ir_builder_);
+ }
if (primitive_util::IsFloatingPointType(to_type)) {
return ir_builder_->CreateFPCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
@@ -449,7 +609,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
PrimitiveType operand_type = op->operand(0)->shape().element_type();
- if (lhs_value->getType()->isIntegerTy()) {
+ if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
+ operand_type == PRED) {
return EmitIntegerBinaryOp(
op, lhs_value, rhs_value,
primitive_util::IsSignedIntegralType(operand_type));
@@ -717,111 +878,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
if (hlo->operand(0)->shape().element_type() != F32) {
return Unimplemented("reduce-precision only implemented for F32");
}
-
- // Integer and float types for casting and constant generation.
- llvm::Type* float_type = x->getType();
- llvm::IntegerType* int_type = ir_builder_->getInt32Ty();
-
- // Cast the input value to an integer for bitwise manipulation.
- llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type);
-
- if (hlo->mantissa_bits() < 23) {
- // Last remaining mantissa bit.
- const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits());
-
- // Compute rounding bias for round-to-nearest with ties to even. This is
- // equal to a base value of 0111... plus one bit if the last remaining
- // mantissa bit is 1.
- const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
- llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr(
- ir_builder_->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
- (23 - hlo->mantissa_bits()));
- llvm::Value* x_rounding_bias = ir_builder_->CreateAdd(
- x_last_mantissa_bit,
- llvm::ConstantInt::get(int_type, base_rounding_bias));
-
- // Add rounding bias, and mask out truncated bits. Note that the case
- // where adding the rounding bias overflows into the exponent bits is
- // correct; the non-masked mantissa bits will all be zero, and the
- // exponent will be incremented by one.
- const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
- x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias);
- x_as_int = ir_builder_->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
- }
-
- if (hlo->exponent_bits() < 8) {
- // Masks for f32 values.
- const uint32_t f32_sign_bit_mask = 1u << 31;
- const uint32_t f32_exp_bits_mask = 0xffu << 23;
-
- // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
- // significant bit -- is equal to 1.0f for all exponent sizes. Adding
- // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
- // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
- // exponent (corresponding to 0.0f).
- //
- // Thus, the f32 exponent corresponding to the highest non-infinite
- // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
- // exponent corresponding to the lowest exponent for a bit size of n is
- // (2^7-1) - 2^(n-1)-1.
- //
- // Note that we have already checked that exponents_bits >= 1.
- const uint32_t f32_exponent_bias = (1 << 7) - 1;
- const uint32_t reduced_exponent_bias =
- (1 << (hlo->exponent_bits() - 1)) - 1;
- const uint32_t reduced_max_exponent =
- f32_exponent_bias + reduced_exponent_bias;
- const uint32_t reduced_min_exponent =
- f32_exponent_bias - reduced_exponent_bias;
-
- // Do we overflow or underflow?
- llvm::Value* x_exponent = ir_builder_->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
- llvm::Value* x_overflows = ir_builder_->CreateICmpUGT(
- x_exponent,
- llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
- llvm::Value* x_underflows = ir_builder_->CreateICmpULE(
- x_exponent,
- llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
-
- // Compute appropriately-signed values of zero and infinity.
- llvm::Value* x_signed_zero = ir_builder_->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
- llvm::Value* x_signed_inf = ir_builder_->CreateOr(
- x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
-
- // Force to zero or infinity if overflow or underflow. (Note that this
- // truncates all denormal values to zero, rather than rounding them.)
- x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int);
- x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int);
- }
-
- // Cast the result back to a floating-point type.
- llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type);
-
- // Correct result for NaN inputs.
- //
- // The exponent handling will "normalize" NaN values to infinities, which is
- // undesirable (except in the case with no mantissa bits, in which case it
- // is mandatory). This logic also handles cases where mantissa-rounding
- // causes a NaN's mantissa to overflow into the exponent bits, which would
- // otherwise create an erroneous zero value.
- //
- // If the fast-math flags are set to assume no NaNs, the comparison is likely
- // to be optimized away, so there's no point in even emitting it.
- if (!ir_builder_->getFastMathFlags().noNaNs()) {
- llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x);
-
- if (hlo->mantissa_bits() > 0) {
- result = ir_builder_->CreateSelect(x_is_nan, x, result);
- } else {
- result = ir_builder_->CreateSelect(
- x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
- }
- }
- return result;
+ return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
+ /*mantissa_bits=*/hlo->mantissa_bits(),
+ ir_builder_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
new file mode 100644
index 0000000000..1773bb401d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -0,0 +1,137 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace {
+
+HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
+ if (hlo->shape().element_type() != type) {
+ Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
+ hlo = hlo->parent()->AddInstruction(
+ HloInstruction::CreateConvert(shape, hlo));
+ }
+ CHECK_EQ(hlo->shape().element_type(), type);
+ return hlo;
+}
+
+bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
+ for (HloInstruction* operand : hlo->operands()) {
+ if (operand->shape().element_type() == type) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+HloElementTypeConverter::HloElementTypeConverter(
+ PrimitiveType eliminate_type, PrimitiveType replace_with_type)
+ : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
+
+StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
+ XLA_VLOG_LINES(
+ 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
+ bool changed = false;
+ for (auto* computation : module->computations()) {
+ for (auto* hlo : computation->MakeInstructionPostOrder()) {
+ // These are ops where it does not make sense to convert them.
+ if (hlo->opcode() == HloOpcode::kParameter ||
+ hlo->opcode() == HloOpcode::kConstant ||
+ hlo->opcode() == HloOpcode::kTuple ||
+ hlo->opcode() == HloOpcode::kConvert ||
+ hlo->opcode() == HloOpcode::kGetTupleElement ||
+ hlo->opcode() == HloOpcode::kInfeed ||
+ hlo->opcode() == HloOpcode::kOutfeed) {
+ continue;
+ }
+
+ // We cannot change a CustomCall since we have no way of adjusting the
+ // called binary to expect the updated type.
+ if (hlo->opcode() == HloOpcode::kCustomCall) {
+ continue;
+ }
+
+ // These are ops with embedded computations where it suffices to convert
+ // the embedded computations instead of converting the ops themselves.
+ if (hlo->opcode() == HloOpcode::kWhile ||
+ hlo->opcode() == HloOpcode::kCall ||
+ hlo->opcode() == HloOpcode::kFusion ||
+ hlo->opcode() == HloOpcode::kMap ||
+ hlo->opcode() == HloOpcode::kReduce ||
+ hlo->opcode() == HloOpcode::kReduceWindow ||
+ hlo->opcode() == HloOpcode::kSelectAndScatter ||
+ hlo->opcode() == HloOpcode::kConditional) {
+ continue;
+ }
+ TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
+
+ if (!HasOperandType(hlo, eliminate_type_)) {
+ // If this CHECK fires, then this was an instruction that does not take
+ // the elimination type as an operand but it does return it. This pass
+ // does not have a feature to change the output type in that case, so
+ // instead of silently failing to eliminate the type, it fails loudly.
+ TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
+ continue;
+ }
+
+ std::vector<HloInstruction*> new_operands;
+ for (HloInstruction* operand : hlo->operands()) {
+ if (operand->shape().element_type() == eliminate_type_) {
+ operand = ToElementType(operand, replace_with_type_);
+ }
+ new_operands.push_back(operand);
+ }
+
+ HloInstruction* new_hlo;
+ if (hlo->shape().element_type() == eliminate_type_) {
+ Shape shape =
+ ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
+ new_hlo = computation->AddInstruction(
+ hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule()));
+ new_hlo = ToElementType(new_hlo, eliminate_type_);
+ } else {
+ new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
+ hlo->shape(), new_operands, hlo->GetModule()));
+ }
+ TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo));
+ changed = true;
+ }
+ }
+ XLA_VLOG_LINES(
+ 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace xla
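In effect, Run() rewrites an elementwise op over the eliminated type by bracketing it with converts. Schematically, in HLO-like notation for HloElementTypeConverter(BF16, F32):

  %add = bf16[8] add(bf16[8] %x, bf16[8] %y)

becomes, roughly,

  %x.f32 = f32[8] convert(%x)
  %y.f32 = f32[8] convert(%y)
  %add.f32 = f32[8] add(%x.f32, %y.f32)
  %add = bf16[8] convert(%add.f32)

so downstream users still see the original type. Parameters, constants, tuples, converts, infeed/outfeed, custom calls, and ops with embedded computations are skipped, as the checks above show.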
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
new file mode 100644
index 0000000000..2b109225d0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// A pass that eliminates certain element types as the input or output of ops by
+// inserting Convert ops. This allows a backend to support an element type while
+// only actually implementing the Convert op for that element type. This is
+// generally not the fastest approach, but it works.
+class HloElementTypeConverter : public HloPassInterface {
+ public:
+ // eliminate_type is the type to eliminate as the input or output of ops,
+ // using Convert ops to replace it with replace_with_type.
+ HloElementTypeConverter(PrimitiveType eliminate_type,
+ PrimitiveType replace_with_type);
+
+ tensorflow::StringPiece name() const override {
+ return "element_type_converter";
+ }
+
+ // Runs the pass on the module and returns whether the module was modified.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ PrimitiveType eliminate_type_;
+ PrimitiveType replace_with_type_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_
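Usage mirrors the cpu_compiler.cc hunk earlier in this diff: the pass is added to a pipeline with the type to eliminate and the type to compute in instead, e.g.

  // Eliminate BF16 inputs/outputs by computing in F32 and bridging with Converts.
  pipeline.AddPass<HloElementTypeConverter>(BF16, F32);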
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index ef5b6ad90e..9a0c94b1c7 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -142,6 +142,13 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
return llvm::Type::getInt8Ty(module->getContext());
case S16:
case U16:
+ case BF16:
+ // For BF16 we just need some type that is 16 bits wide so that it will
+ // take up the right amount of space in memory. LLVM does not have a BF16
+ // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
+ // we can't map it directly to an LLVM type. We will not map a BF16
+ // addition to an addition on this type (int16) - this is just the type
+ // used for storage.
return llvm::Type::getInt16Ty(module->getContext());
case S32:
case U32:
@@ -280,6 +287,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
value = llvm::ConstantFP::get(ir_element_type,
literal.Get<float>(*multi_index));
break;
+ case BF16:
+ value = llvm::ConstantInt::get(
+ ir_element_type,
+ tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index)));
+ break;
case F64:
value = llvm::ConstantFP::get(ir_element_type,
literal.Get<double>(*multi_index));
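As a small illustration of the storage convention used above: the i16 constant emitted for a BF16 literal is its raw bit pattern, which for exactly representable values is the high half of the F32 pattern. A sketch, assuming the usual F32 layout:

  // 1.0f is 0x3F800000 as an F32, so bfloat16(1.0f) stores 0x3F80.
  uint16 one_bf16_bits = tensorflow::bit_cast<uint16>(static_cast<bfloat16>(1.0f));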
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 6f03f1a4e0..6af01ae80d 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -802,8 +802,6 @@ xla_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.cc"],
blacklisted_backends = [
- "cpu",
- "cpu_parallel",
"gpu",
],
shard_count = 40,
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 330575a02e..b32df74312 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -53,7 +53,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase {
public:
ErrorSpec DefaultErrorSpec() const {
if (use_bfloat16()) {
- return ErrorSpec(1e-1, 3e-2);
+ return ErrorSpec(1e-1, 5e-2);
} else {
return ErrorSpec(1e-3, 1e-3);
}
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 93bce97a3e..780b292d1a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -35,6 +35,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) {
}));
}
+// The standard library does not have a case for bfloat16, unsurprisingly, so we
+// handle that one specially.
+template <>
+void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
+ CHECK_EQ(literal->shape().element_type(), BF16);
+ std::minstd_rand0 engine;
+ std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+ TF_CHECK_OK(literal->Populate<bfloat16>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return static_cast<bfloat16>(generator(engine));
+ }));
+}
+
template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal) {
CHECK_EQ(literal->shape().element_type(),
@@ -171,6 +184,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
switch (shape.element_type()) {
+ case BF16:
+ PopulateWithRandomFloatingPointData<bfloat16>(literal.get());
+ break;
case F32:
PopulateWithRandomFloatingPointData<float>(literal.get());
break;