diff options
-rw-r--r-- | tensorflow/compiler/xla/primitive_util.h | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 289 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_element_type_converter.cc | 137 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_element_type_converter.h | 49 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/reduce_window_test.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 16 |
12 files changed, 416 insertions, 119 deletions
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 19c6a13888..cb4583d198 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -26,6 +26,13 @@ limitations under the License. namespace xla { namespace primitive_util { +// The number of exponent bits in a BF16 value. +const int kBFloat16ExponentBits = 8; + +// The number of mantissa bits in a BF16 value. There is an implicit leading +// 1, so there is an implicit additional bit of precision. +const int kBFloat16MantissaBits = 7; + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). template <typename NativeT> diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 1023d3e5dc..baa4afde2d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1892,6 +1892,22 @@ tf_cc_test( ) cc_library( + name = "hlo_element_type_converter", + srcs = ["hlo_element_type_converter.cc"], + hdrs = ["hlo_element_type_converter.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +cc_library( name = "device_memory_allocator", srcs = ["device_memory_allocator.cc"], hdrs = ["device_memory_allocator.h"], diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 32abb1b559..fe537dfdf2 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -110,6 +110,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 6c72ef6849..a476a75027 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -68,6 +68,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -318,6 +319,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { [](const Shape&, const Shape&) { return true; }, /*enable_dot_strength_reduction=*/false); pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true); + pipeline.AddPass<HloElementTypeConverter>(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 85d9668f89..dd027986b2 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -516,7 +516,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32})); + /*supported_types=*/{F32, BF16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index b9407818cd..7e88bbd631 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -50,11 +50,161 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrCat; +namespace { + +llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, + int64 mantissa_bits, + llvm::IRBuilder<>* ir_builder) { + // Integer and float types for casting and constant generation. + llvm::Type* float_type = x->getType(); + llvm::IntegerType* int_type = ir_builder->getInt32Ty(); + + // Cast the input value to an integer for bitwise manipulation. + llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type); + + if (mantissa_bits < 23) { + // Last remaining mantissa bit. + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. + const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; + llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr( + ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + (23 - mantissa_bits)); + llvm::Value* x_rounding_bias = ir_builder->CreateAdd( + x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + } + + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + llvm::Value* x_exponent = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + llvm::Value* x_overflows = ir_builder->CreateICmpUGT( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); + llvm::Value* x_underflows = ir_builder->CreateICmpULE( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); + + // Compute appropriately-signed values of zero and infinity. + llvm::Value* x_signed_zero = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); + llvm::Value* x_signed_inf = ir_builder->CreateOr( + x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int); + } + + // Cast the result back to a floating-point type. + llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type); + + // Correct result for NaN inputs. + // + // The exponent handling will "normalize" NaN values to infinities, which is + // undesirable (except in the case with no mantissa bits, in which case it + // is mandatory). This logic also handles cases where mantissa-rounding + // causes a NaN's mantissa to overflow into the exponent bits, which would + // otherwise create an erroneous zero value. + // + // If the fast-math flags are set to assume no NaNs, the comparison is likely + // to be optimized away, so there's no point in even emitting it. + if (!ir_builder->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x); + + if (mantissa_bits > 0) { + result = ir_builder->CreateSelect(x_is_nan, x, result); + } else { + result = ir_builder->CreateSelect( + x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); + } + } + return result; +} + +llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, + llvm::IRBuilder<>* ir_builder) { + auto reduced_precision = EmitReducePrecisionFloat( + f32_value, + /*exponent_bits=*/primitive_util::kBFloat16ExponentBits, + /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder); + auto as_int32 = + ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateLShr(as_int32, 16); + auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty()); + return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty()); +} + +llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, + llvm::IRBuilder<>* ir_builder) { + auto as_int16 = + ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty()); + auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateShl(as_int32, 16); + return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy()); +} + +llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, + PrimitiveType from_type, + PrimitiveType to_type, llvm::Module* module, + llvm::IRBuilder<>* ir_builder) { + if (primitive_util::IsSignedIntegralType(from_type)) { + return ir_builder->CreateSIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } else { + CHECK(primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED); + return ir_builder->CreateUIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } +} + +} // namespace + StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { if (op->opcode() == HloOpcode::kCopy) { return operand_value; - } else if (operand_value->getType()->isIntegerTy()) { + } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + op->operand(0)->shape().element_type() == PRED) { return EmitIntegerUnaryOp(op, operand_value); } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { return EmitComplexUnaryOp(op, operand_value); @@ -79,15 +229,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::IsSignedIntegralType(to_type)); } if (primitive_util::IsFloatingPointType(to_type)) { - if (primitive_util::IsSignedIntegralType(from_type)) { - return ir_builder_->CreateSIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); - } - if (primitive_util::IsUnsignedIntegralType(from_type) || - from_type == PRED) { - return ir_builder_->CreateUIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + if (to_type == BF16) { + return EmitF32ToBF16( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + ir_builder_), + ir_builder_); } + return EmitIntegralToFloating(operand_value, from_type, to_type, + module_, ir_builder_); } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( @@ -207,6 +356,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } + if (from_type == BF16) { + TF_RET_CHECK(to_type != BF16); + operand_value = EmitBF16ToF32(operand_value, ir_builder_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F32 && to_type == BF16) { + return EmitF32ToBF16(operand_value, ir_builder_); + } if (primitive_util::IsFloatingPointType(to_type)) { return ir_builder_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); @@ -449,7 +609,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { PrimitiveType operand_type = op->operand(0)->shape().element_type(); - if (lhs_value->getType()->isIntegerTy()) { + if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + operand_type == PRED) { return EmitIntegerBinaryOp( op, lhs_value, rhs_value, primitive_util::IsSignedIntegralType(operand_type)); @@ -717,111 +878,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision( if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } - - // Integer and float types for casting and constant generation. - llvm::Type* float_type = x->getType(); - llvm::IntegerType* int_type = ir_builder_->getInt32Ty(); - - // Cast the input value to an integer for bitwise manipulation. - llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type); - - if (hlo->mantissa_bits() < 23) { - // Last remaining mantissa bit. - const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits()); - - // Compute rounding bias for round-to-nearest with ties to even. This is - // equal to a base value of 0111... plus one bit if the last remaining - // mantissa bit is 1. - const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; - llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr( - ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), - (23 - hlo->mantissa_bits())); - llvm::Value* x_rounding_bias = ir_builder_->CreateAdd( - x_last_mantissa_bit, - llvm::ConstantInt::get(int_type, base_rounding_bias)); - - // Add rounding bias, and mask out truncated bits. Note that the case - // where adding the rounding bias overflows into the exponent bits is - // correct; the non-masked mantissa bits will all be zero, and the - // exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias); - x_as_int = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); - } - - if (hlo->exponent_bits() < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- - // significant bit -- is equal to 1.0f for all exponent sizes. Adding - // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- - // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' - // exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n is - // (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (hlo->exponent_bits() - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; - - // Do we overflow or underflow? - llvm::Value* x_exponent = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - llvm::Value* x_overflows = ir_builder_->CreateICmpUGT( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); - llvm::Value* x_underflows = ir_builder_->CreateICmpULE( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); - - // Compute appropriately-signed values of zero and infinity. - llvm::Value* x_signed_zero = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); - llvm::Value* x_signed_inf = ir_builder_->CreateOr( - x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - - // Force to zero or infinity if overflow or underflow. (Note that this - // truncates all denormal values to zero, rather than rounding them.) - x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int); - x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int); - } - - // Cast the result back to a floating-point type. - llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type); - - // Correct result for NaN inputs. - // - // The exponent handling will "normalize" NaN values to infinities, which is - // undesirable (except in the case with no mantissa bits, in which case it - // is mandatory). This logic also handles cases where mantissa-rounding - // causes a NaN's mantissa to overflow into the exponent bits, which would - // otherwise create an erroneous zero value. - // - // If the fast-math flags are set to assume no NaNs, the comparison is likely - // to be optimized away, so there's no point in even emitting it. - if (!ir_builder_->getFastMathFlags().noNaNs()) { - llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x); - - if (hlo->mantissa_bits() > 0) { - result = ir_builder_->CreateSelect(x_is_nan, x, result); - } else { - result = ir_builder_->CreateSelect( - x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); - } - } - return result; + return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), + /*mantissa_bits=*/hlo->mantissa_bits(), + ir_builder_); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc new file mode 100644 index 0000000000..1773bb401d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -0,0 +1,137 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace { + +HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) { + if (hlo->shape().element_type() != type) { + Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); + hlo = hlo->parent()->AddInstruction( + HloInstruction::CreateConvert(shape, hlo)); + } + CHECK_EQ(hlo->shape().element_type(), type); + return hlo; +} + +bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == type) { + return true; + } + } + return false; +} + +} // namespace + +HloElementTypeConverter::HloElementTypeConverter( + PrimitiveType eliminate_type, PrimitiveType replace_with_type) + : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {} + +StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* hlo : computation->MakeInstructionPostOrder()) { + // These are ops where it does not make sense to convert them. + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant || + hlo->opcode() == HloOpcode::kTuple || + hlo->opcode() == HloOpcode::kConvert || + hlo->opcode() == HloOpcode::kGetTupleElement || + hlo->opcode() == HloOpcode::kInfeed || + hlo->opcode() == HloOpcode::kOutfeed) { + continue; + } + + // We cannot change a CustomCall since we have no way of adjusting the + // called binary to expect the updated type. + if (hlo->opcode() == HloOpcode::kCustomCall) { + continue; + } + + // These are ops with embedded computations where it suffices to convert + // the embedded computations instead of converting the ops themselves. + if (hlo->opcode() == HloOpcode::kWhile || + hlo->opcode() == HloOpcode::kCall || + hlo->opcode() == HloOpcode::kFusion || + hlo->opcode() == HloOpcode::kMap || + hlo->opcode() == HloOpcode::kReduce || + hlo->opcode() == HloOpcode::kReduceWindow || + hlo->opcode() == HloOpcode::kSelectAndScatter || + hlo->opcode() == HloOpcode::kConditional) { + continue; + } + TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); + + if (!HasOperandType(hlo, eliminate_type_)) { + // If this CHECK fires, then this was an instruction that does not take + // the elimination type as an operand but it does return it. This pass + // does not have a feature to change the output type in that case, so + // instead of silently failing to eliminate the type, it fails loudly. + TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_); + continue; + } + + std::vector<HloInstruction*> new_operands; + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == eliminate_type_) { + operand = ToElementType(operand, replace_with_type_); + } + new_operands.push_back(operand); + } + + HloInstruction* new_hlo; + if (hlo->shape().element_type() == eliminate_type_) { + Shape shape = + ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + new_hlo = ToElementType(new_hlo, eliminate_type_); + } else { + new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( + hlo->shape(), new_operands, hlo->GetModule())); + } + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); + changed = true; + } + } + XLA_VLOG_LINES( + 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h new file mode 100644 index 0000000000..2b109225d0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass that eliminates certain element types as the input or output of ops by +// inserting Convert ops. This allows a backend to support an element type while +// only actually implementing the Convert op for that element type. This is +// generally not the fastest approach, but it works. +class HloElementTypeConverter : public HloPassInterface { + public: + // eliminate_type is the type to eliminate as the input or output of ops, + // using Convert ops to replace it with replace_with_type. + HloElementTypeConverter(PrimitiveType eliminate_type, + PrimitiveType replace_with_type); + + tensorflow::StringPiece name() const override { + return "element_type_converter"; + } + + // Returns the pass on the module and returns whether the module was modified. + StatusOr<bool> Run(HloModule* module) override; + + private: + PrimitiveType eliminate_type_; + PrimitiveType replace_with_type_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ef5b6ad90e..9a0c94b1c7 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -142,6 +142,13 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: + case BF16: + // For BF16 we just need some type that is 16 bits wide so that it will + // take up the right amount of space in memory. LLVM does not have a BF16 + // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so + // we can't map it directly to an LLVM type. We will not map a BF16 + // addition to an addition on this type (int16) - this is just the type + // used for storage. return llvm::Type::getInt16Ty(module->getContext()); case S32: case U32: @@ -280,6 +287,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, value = llvm::ConstantFP::get(ir_element_type, literal.Get<float>(*multi_index)); break; + case BF16: + value = llvm::ConstantInt::get( + ir_element_type, + tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index))); + break; case F64: value = llvm::ConstantFP::get(ir_element_type, literal.Get<double>(*multi_index)); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 6f03f1a4e0..6af01ae80d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -802,8 +802,6 @@ xla_test( name = "bfloat16_test", srcs = ["bfloat16_test.cc"], blacklisted_backends = [ - "cpu", - "cpu_parallel", "gpu", ], shard_count = 40, diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 330575a02e..b32df74312 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -53,7 +53,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { public: ErrorSpec DefaultErrorSpec() const { if (use_bfloat16()) { - return ErrorSpec(1e-1, 3e-2); + return ErrorSpec(1e-1, 5e-2); } else { return ErrorSpec(1e-3, 1e-3); } diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 93bce97a3e..780b292d1a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -35,6 +35,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) { })); } +// The standard library does not have a case for bfloat16, unsurprisingly, so we +// handle that one specially. +template <> +void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), BF16); + std::minstd_rand0 engine; + std::uniform_real_distribution<float> generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate<bfloat16>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return static_cast<bfloat16>(generator(engine)); + })); +} + template <typename IntT> void PopulateWithRandomIntegralData(Literal* literal) { CHECK_EQ(literal->shape().element_type(), @@ -171,6 +184,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { } std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData<bfloat16>(literal.get()); + break; case F32: PopulateWithRandomFloatingPointData<float>(literal.get()); break; |