aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-07-23 06:15:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 06:19:46 -0700
commit21d0205916eded7e2bf2f26e43dd41b2f86cba3f (patch)
tree2b6cd9be28ff6cbadc227aa9f7a26946005ede72
parent89e06304aad35bfb019a8c10f39fc1ead83e0f99 (diff)
[XLA:CPU,GPU] Implement more cases of convert
This adds support for {S32,U32,F32} -> PRED and adds test for several other cases as well. PiperOrigin-RevId: 205650630
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc18
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc70
2 files changed, 83 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 1eedd85363..b58b87a978 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -222,10 +222,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
PrimitiveType to_type = op->shape().element_type();
- CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED);
+ CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
+ << from_type;
if (from_type == to_type) {
return operand_value;
}
+ if (to_type == PRED) {
+ return b_->CreateZExt(
+ b_->CreateICmpNE(operand_value, llvm::ConstantInt::get(
+ operand_value->getType(), 0)),
+ llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ }
if (primitive_util::IsIntegralType(to_type)) {
return b_->CreateIntCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
@@ -342,7 +349,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
PrimitiveType to_type = op->shape().element_type();
- CHECK(primitive_util::IsFloatingPointType(from_type));
+ CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
if (from_type == to_type) {
return operand_value;
}
@@ -369,6 +376,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
if (from_type == F32 && to_type == BF16) {
return EmitF32ToBF16(operand_value, b_);
}
+ if (to_type == PRED) {
+ return b_->CreateZExt(
+ b_->CreateFCmpUNE(
+ operand_value,
+ llvm::ConstantFP::get(operand_value->getType(), 0.0)),
+ llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ }
if (primitive_util::IsFloatingPointType(to_type)) {
return b_->CreateFPCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index dca57fd1c7..0fb6853e3f 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <array>
#include <cstdint>
#include <limits>
#include <memory>
@@ -52,13 +53,67 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
ComputeAndCompareR1<int32>(&builder, expected, {});
}
+TEST_F(ConvertTest, ConvertR1S32ToR1U32) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<int32>(&builder, {42, 64});
+ ConvertElementType(a, U32);
+
+ std::vector<uint32> expected = {42, 64};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1S32ToR1PRED) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<int32>(&builder, {42, 0, -64});
+ ConvertElementType(a, PRED);
+
+ std::array<bool, 3> expected = {true, false, true};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1U32ToR1U32) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<uint32>(&builder, {42, 64});
+ ConvertElementType(a, U32);
+
+ std::vector<uint32> expected = {42, 64};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1U32ToR1S32) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<uint32>(&builder, {42, 64});
+ ConvertElementType(a, S32);
+
+ std::vector<int32> expected = {42, 64};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1U32ToR1PRED) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<uint32>(&builder, {42, 0, 64});
+ ConvertElementType(a, PRED);
+
+ std::array<bool, 3> expected = {true, false, true};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {42.0f, 64.0f});
ConvertElementType(a, F32);
std::vector<float> expected = {42.0f, 64.0f};
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1F32ToR1PRED) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<float>(&builder, {42.0f, 0.0f, 64.0f});
+ ConvertElementType(a, PRED);
+
+ std::array<bool, 3> expected = {true, false, true};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
}
TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
@@ -67,7 +122,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
ConvertElementType(a, F32);
std::vector<float> expected = {42.0f, 64.0f};
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareR1<float>(&builder, expected, {});
}
TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
@@ -79,6 +134,15 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
ComputeAndCompareR1<int32>(&builder, expected, {});
}
+TEST_F(ConvertTest, ConvertR1PREDToR1U32) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<bool>(&builder, {true, false, true});
+ ConvertElementType(a, U32);
+
+ std::vector<uint32> expected = {1, 0, 1};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
TEST_F(ConvertTest, ConvertR1PREDToR1F32) {
XlaBuilder builder(TestName());
auto a = ConstantR1<bool>(&builder, {true, false, true});
@@ -94,7 +158,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
ConvertElementType(a, F32);
std::vector<float> expected = {};
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareR1<float>(&builder, expected, {});
}
TEST_F(ConvertTest, ConvertR1F32ToR1S32) {