diff options
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 18 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/convert_test.cc | 70 |
2 files changed, 83 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 1eedd85363..b58b87a978 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -222,10 +222,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); - CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED); + CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED) + << from_type; if (from_type == to_type) { return operand_value; } + if (to_type == PRED) { + return b_->CreateZExt( + b_->CreateICmpNE(operand_value, llvm::ConstantInt::get( + operand_value->getType(), 0)), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + } if (primitive_util::IsIntegralType(to_type)) { return b_->CreateIntCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), @@ -342,7 +349,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); - CHECK(primitive_util::IsFloatingPointType(from_type)); + CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type; if (from_type == to_type) { return operand_value; } @@ -369,6 +376,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( if (from_type == F32 && to_type == BF16) { return EmitF32ToBF16(operand_value, b_); } + if (to_type == PRED) { + return b_->CreateZExt( + b_->CreateFCmpUNE( + operand_value, + llvm::ConstantFP::get(operand_value->getType(), 0.0)), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + } if (primitive_util::IsFloatingPointType(to_type)) { return b_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index dca57fd1c7..0fb6853e3f 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <array> #include <cstdint> #include <limits> #include <memory> @@ -52,13 +53,67 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) { ComputeAndCompareR1<int32>(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1S32ToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<int32>(&builder, {42, 64}); + ConvertElementType(a, U32); + + std::vector<uint32> expected = {42, 64}; + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1S32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<int32>(&builder, {42, 0, -64}); + ConvertElementType(a, PRED); + + std::array<bool, 3> expected = {true, false, true}; + ComputeAndCompareR1<bool>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<uint32>(&builder, {42, 64}); + ConvertElementType(a, U32); + + std::vector<uint32> expected = {42, 64}; + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1S32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<uint32>(&builder, {42, 64}); + ConvertElementType(a, S32); + + std::vector<int32> expected = {42, 64}; + ComputeAndCompareR1<int32>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<uint32>(&builder, {42, 0, 64}); + ConvertElementType(a, PRED); + + std::array<bool, 3> expected = {true, false, true}; + ComputeAndCompareR1<bool>(&builder, expected, {}); +} + TEST_F(ConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1<float>(&builder, {42.0f, 64.0f}); ConvertElementType(a, F32); std::vector<float> expected = {42.0f, 64.0f}; - ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1<float>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1F32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<float>(&builder, {42.0f, 0.0f, 64.0f}); + ConvertElementType(a, PRED); + + std::array<bool, 3> expected = {true, false, true}; + ComputeAndCompareR1<bool>(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1S32ToR1F32) { @@ -67,7 +122,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { ConvertElementType(a, F32); std::vector<float> expected = {42.0f, 64.0f}; - ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1<float>(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1PREDToR1S32) { @@ -79,6 +134,15 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) { ComputeAndCompareR1<int32>(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1PREDToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<bool>(&builder, {true, false, true}); + ConvertElementType(a, U32); + + std::vector<uint32> expected = {1, 0, 1}; + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1<bool>(&builder, {true, false, true}); @@ -94,7 +158,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { ConvertElementType(a, F32); std::vector<float> expected = {}; - ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1<float>(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1F32ToR1S32) { |