aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-12 17:41:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-12 17:45:09 -0700
commit915a8ac568f0a67d6000ab70a665817deff7888c (patch)
treecbce859a80fba83b959798d5734fd302d7f96b36
parent4b178957917d95fbe6305381764e39453f6bb8d0 (diff)
[TF:XLA] Implement BitwiseAnd, BitwiseOr, and Invert operators.
PiperOrigin-RevId: 172038787
-rw-r--r--tensorflow/compiler/tests/BUILD1
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py15
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc42
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py9
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc1
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc23
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc15
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc203
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc68
10 files changed, 349 insertions, 30 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 72a0360de2..0eed475140 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -103,6 +103,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:bitwise_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 792c01327c..44b32b1668 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -24,6 +24,7 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -45,6 +46,10 @@ class BinaryOpsTest(XLATestCase):
equality_test = self.assertAllClose
equality_test(result, expected, rtol=1e-3)
+ def _testSymmetricBinary(self, op, a, b, expected, equality_test=None):
+ self._testBinary(op, a, b, expected, equality_test)
+ self._testBinary(op, b, a, expected, equality_test)
+
def ListsAreClose(self, result, expected, rtol):
"""Tests closeness of two lists of floats."""
self.assertEqual(len(result), len(expected))
@@ -193,6 +198,16 @@ class BinaryOpsTest(XLATestCase):
np.array([3, 3, -1, -9, -8], dtype=dtype),
np.array([2, -2, 7, 2, -4], dtype=dtype),
expected=np.array([1, -1, 0, -4, 2], dtype=dtype))
+ self._testSymmetricBinary(
+ bitwise_ops.bitwise_and,
+ np.array([0b1, 0b101, 0b1000], dtype=dtype),
+ np.array([0b0, 0b101, 0b1001], dtype=dtype),
+ expected=np.array([0b0, 0b101, 0b1000], dtype=dtype))
+ self._testSymmetricBinary(
+ bitwise_ops.bitwise_or,
+ np.array([0b1, 0b101, 0b1000], dtype=dtype),
+ np.array([0b0, 0b101, 0b1001], dtype=dtype),
+ expected=np.array([0b1, 0b101, 0b1001], dtype=dtype))
def testNumericOps(self):
for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index fef12d9397..56e10a1587 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -1168,6 +1168,28 @@ TEST_F(OpTest, BiasAddV1) {
});
}
+TEST_F(OpTest, BitwiseAnd) {
+ Repeatedly([this]() {
+ DataType type = DT_INT32;
+ auto dims = BroadcastableDims();
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseAnd")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, BitwiseOr) {
+ Repeatedly([this]() {
+ DataType type = DT_INT32;
+ auto dims = BroadcastableDims();
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseOr")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
+ });
+}
+
TEST_F(OpTest, BroadcastArgs) {
Repeatedly([this]() {
// TODO(phawkins): only int32 seems to be implemented in Tensorflow.
@@ -1729,6 +1751,14 @@ TEST_F(OpTest, GreaterEqual) {
});
}
+TEST_F(OpTest, Invert) {
+ Repeatedly([this]() {
+ DataType type = DT_INT32;
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Invert").RandomInput(type).Attr("T", type));
+ });
+}
+
TEST_F(OpTest, L2Loss) {
Repeatedly([this]() {
DataType type = DT_FLOAT;
@@ -1791,28 +1821,28 @@ TEST_F(OpTest, Log1p) {
});
}
-TEST_F(OpTest, BooleanAnd) {
+TEST_F(OpTest, LogicalAnd) {
Repeatedly([this]() {
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("BooleanAnd")
+ OpTestBuilder("LogicalAnd")
.RandomInput(DT_BOOL, dims.first)
.RandomInput(DT_BOOL, dims.second));
});
}
-TEST_F(OpTest, BooleanNot) {
+TEST_F(OpTest, LogicalNot) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("BooleanNot").RandomInput(DT_BOOL));
+ OpTestBuilder("LogicalNot").RandomInput(DT_BOOL));
});
}
-TEST_F(OpTest, BooleanOr) {
+TEST_F(OpTest, LogicalOr) {
Repeatedly([this]() {
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("BooleanOr")
+ OpTestBuilder("LogicalOr")
.RandomInput(DT_BOOL, dims.first)
.RandomInput(DT_BOOL, dims.second));
});
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 6f19834160..71221b284d 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -26,6 +26,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -327,6 +328,13 @@ class UnaryOpsTest(XLATestCase):
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype))
+ def testIntOps(self):
+ for dtype in self.int_types:
+ self._assertOpOutputMatchesExpected(
+ bitwise_ops.invert,
+ np.array([0, -1, 1, 16, 42], dtype=dtype),
+ expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
+
def testNumericOps(self):
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
@@ -558,5 +566,6 @@ class UnaryOpsTest(XLATestCase):
log_eps + ten, -log_eps, -log_eps - one, -log_eps + one,
-log_eps - ten, -log_eps + ten], dtype)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index a180f1e4d9..d635507989 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -96,6 +96,8 @@ static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b,
XLA_MAKE_BINARY(FloorMod,
FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions));
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 8f04fc94be..651bbe2b40 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -87,6 +87,7 @@ XLAJIT_MAKE_UNARY(Log, b->Log(x));
// TODO(b/34703906): use a more accurate implementation of log1p.
XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x)));
+XLAJIT_MAKE_UNARY(Invert, b->Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x));
XLAJIT_MAKE_UNARY(Neg, b->Neg(x));
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 3a8f70a8ef..fb4d233d04 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -126,14 +126,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
case HloOpcode::kNegate:
return ir_builder_->CreateNeg(operand_value);
- case HloOpcode::kNot:
- // It is not sufficient to just call CreateNot() here because a PRED is
- // represented as an i8 and the truth value is stored only in the bottom
- // bit.
- return ir_builder_->CreateZExt(
- ir_builder_->CreateNot(ir_builder_->CreateTrunc(
- operand_value, ir_builder_->getInt1Ty())),
- llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
+ case HloOpcode::kNot: {
+ auto type = op->shape().element_type();
+ if (type == PRED) {
+ // It is not sufficient to just call CreateNot() here because a PRED
+ // is represented as an i8 and the truth value is stored only in the
+ // bottom bit.
+ return ir_builder_->CreateZExt(
+ ir_builder_->CreateNot(ir_builder_->CreateTrunc(
+ operand_value, ir_builder_->getInt1Ty())),
+ llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
+ } else if (primitive_util::IsIntegralType(type)) {
+ return ir_builder_->CreateNot(operand_value);
+ }
+ return Unimplemented("unary op Not is not defined for type '%d'", type);
+ }
default:
return Unimplemented("unary integer op '%s'",
HloOpcodeString(op->opcode()).c_str());
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index a9f65331e2..a091a067c1 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -323,10 +323,11 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return arg;
case UNOP_NOT:
- if (arg.element_type() != PRED) {
+ if (arg.element_type() != PRED &&
+ !primitive_util::IsIntegralType(arg.element_type())) {
return InvalidArgument(
- "expected pred element type in argument to logical-not operation; "
- "got %s",
+ "expected pred or an integral element type in argument to not "
+ "operation; got %s",
PrimitiveType_Name(arg.element_type()).c_str());
}
return arg;
@@ -752,15 +753,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
case BINOP_AND:
case BINOP_OR:
- if (lhs.element_type() != PRED) {
+ if (lhs.element_type() != PRED &&
+ !primitive_util::IsIntegralType(lhs.element_type())) {
return InvalidArgument(
- "expected pred element type in argument to logical and/or "
- "operation; got %s",
+ "expected pred or integral type in argument to and/or operation; "
+ "got %s",
PrimitiveType_Name(lhs.element_type()).c_str());
}
return InferElementwiseBinaryOpShape(operation, lhs, rhs,
broadcast_dimensions);
-
case BINOP_EQ:
case BINOP_GE:
case BINOP_GT:
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 08b39b6379..eb931dcff3 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -496,7 +496,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
ComputeAndCompareR1<uint32>(&builder, expected, {});
}
-XLA_TEST_F(ArrayElementwiseOpTest, BooleanAnd) {
+XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({false, false, true, true});
auto b = builder.ConstantR1<bool>({false, true, false, true});
@@ -505,7 +505,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanAnd) {
ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
}
-XLA_TEST_F(ArrayElementwiseOpTest, BooleanAndZeroElement) {
+XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<bool>({{false, false}, {true, true}});
+ auto b = builder.ConstantR2<bool>({{false, true}, {false, true}});
+ auto out = builder.And(a, b);
+
+ Array2D<bool> expected_array({{false, false}, {false, true}});
+ ComputeAndCompareR2<bool>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({});
auto b = builder.ConstantR1<bool>({});
@@ -514,7 +524,63 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanAndZeroElement) {
ComputeAndCompareR1<bool>(&builder, {}, {});
}
-XLA_TEST_F(ArrayElementwiseOpTest, BooleanOr) {
+XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({0, -1, -8});
+ auto b = builder.ConstantR1<int32>({5, -7, 12});
+ auto out = builder.And(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {0, -7, 8}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<int32>({{0, -5}, {-1, 5}});
+ auto b = builder.ConstantR2<int32>({{1, -6}, {4, 5}});
+ auto out = builder.And(a, b);
+
+ Array2D<int32> expected_array({{0, -6}, {4, 5}});
+ ComputeAndCompareR2<int32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ auto b = builder.ConstantR1<int32>({});
+ auto out = builder.And(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({0, 1, 8});
+ auto b = builder.ConstantR1<int32>({5, 7, 12});
+ auto out = builder.And(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {0, 1, 8}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<uint32>({{0, 1}, {3, 8}});
+ auto b = builder.ConstantR2<uint32>({{1, 0}, {7, 6}});
+ auto out = builder.And(a, b);
+
+ Array2D<uint32> expected_array({{0, 0}, {3, 0}});
+ ComputeAndCompareR2<uint32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>({});
+ auto b = builder.ConstantR1<uint32>({});
+ auto out = builder.And(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, {}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({false, false, true, true});
auto b = builder.ConstantR1<bool>({false, true, false, true});
@@ -523,7 +589,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanOr) {
ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
}
-XLA_TEST_F(ArrayElementwiseOpTest, BooleanOrZeroElement) {
+XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<bool>({{false, false}, {true, true}});
+ auto b = builder.ConstantR2<bool>({{false, true}, {false, true}});
+ auto out = builder.Or(a, b);
+
+ Array2D<bool> expected_array({{false, true}, {true, true}});
+ ComputeAndCompareR2<bool>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({});
auto b = builder.ConstantR1<bool>({});
@@ -532,7 +608,63 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanOrZeroElement) {
ComputeAndCompareR1<bool>(&builder, {}, {});
}
-XLA_TEST_F(ArrayElementwiseOpTest, BooleanNot) {
+XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({0, -1, 8});
+ auto b = builder.ConstantR1<int32>({5, -7, 4});
+ auto out = builder.Or(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<int32>({{0, -1}, {8, 8}});
+ auto b = builder.ConstantR2<int32>({{5, -7}, {4, 1}});
+ auto out = builder.Or(a, b);
+
+ Array2D<int32> expected_array({{5, -1}, {12, 9}});
+ ComputeAndCompareR2<int32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ auto b = builder.ConstantR1<int32>({});
+ auto out = builder.Or(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>({0, 1, 8});
+ auto b = builder.ConstantR1<uint32>({5, 7, 4});
+ auto out = builder.Or(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<uint32>({{0, 1}, {8, 8}});
+ auto b = builder.ConstantR2<uint32>({{5, 7}, {4, 1}});
+ auto out = builder.Or(a, b);
+
+ Array2D<uint32> expected_array({{5, 7}, {12, 9}});
+ ComputeAndCompareR2<uint32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>({});
+ auto b = builder.ConstantR1<uint32>({});
+ auto out = builder.Or(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, {}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({false, true, true, false});
auto out = builder.Not(a);
@@ -540,7 +672,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanNot) {
ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}
-XLA_TEST_F(ArrayElementwiseOpTest, BooleanNotZeroElement) {
+XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<bool>({{false, true}, {true, false}});
+ auto out = builder.Not(a);
+
+ Array2D<bool> expected_array({{true, false}, {false, true}});
+ ComputeAndCompareR2<bool>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({});
auto out = builder.Not(a);
@@ -548,6 +689,56 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanNotZeroElement) {
ComputeAndCompareR1<bool>(&builder, {}, {});
}
+XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({-1, 0, 1});
+ auto out = builder.Not(a);
+
+ ComputeAndCompareR1<int32>(&builder, {0, -1, -2}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<int32>({{-1, 0}, {1, 8}});
+ auto out = builder.Not(a);
+
+ Array2D<int32> expected_array({{0, -1}, {-2, -9}});
+ ComputeAndCompareR2<int32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ auto out = builder.Not(a);
+
+ ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>({0, 4294967295});
+ auto out = builder.Not(a);
+
+ ComputeAndCompareR1<uint32>(&builder, {4294967295, 0}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<uint32>({{0, 4294967295}, {1, 4294967294}});
+ auto out = builder.Not(a);
+
+ Array2D<uint32> expected_array({{4294967295, 0}, {4294967294, 1}});
+ ComputeAndCompareR2<uint32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>({});
+ auto out = builder.Not(a);
+
+ ComputeAndCompareR1<uint32>(&builder, {}, {});
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index da84d185ca..b5e7570778 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -459,7 +459,7 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) {
ComputeAndCompareR0<uint32>(&builder, 2, {});
}
-XLA_TEST_F(ScalarComputationsTest, BooleanAnd) {
+XLA_TEST_F(ScalarComputationsTest, AndBool) {
for (bool x : {false, true}) {
for (bool y : {false, true}) {
ComputationBuilder builder(client_, TestName());
@@ -470,7 +470,29 @@ XLA_TEST_F(ScalarComputationsTest, BooleanAnd) {
}
}
-XLA_TEST_F(ScalarComputationsTest, BooleanOr) {
+XLA_TEST_F(ScalarComputationsTest, AndS32) {
+ for (int32 x : {0, 8}) {
+ for (int32 y : {1, -16}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.And(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y));
+
+ ComputeAndCompareR0<int32>(&builder, x & y, {});
+ }
+ }
+}
+
+XLA_TEST_F(ScalarComputationsTest, AndU32) {
+ for (uint32 x : {0, 8}) {
+ for (uint32 y : {1, 16}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.And(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y));
+
+ ComputeAndCompareR0<uint32>(&builder, x & y, {});
+ }
+ }
+}
+
+XLA_TEST_F(ScalarComputationsTest, OrBool) {
for (bool x : {false, true}) {
for (bool y : {false, true}) {
ComputationBuilder builder(client_, TestName());
@@ -481,7 +503,29 @@ XLA_TEST_F(ScalarComputationsTest, BooleanOr) {
}
}
-XLA_TEST_F(ScalarComputationsTest, BooleanNot) {
+XLA_TEST_F(ScalarComputationsTest, OrS32) {
+ for (int32 x : {0, 8}) {
+ for (int32 y : {1, -16}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Or(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y));
+
+ ComputeAndCompareR0<int32>(&builder, x | y, {});
+ }
+ }
+}
+
+XLA_TEST_F(ScalarComputationsTest, OrU32) {
+ for (uint32 x : {0, 8}) {
+ for (uint32 y : {1, 16}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Or(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y));
+
+ ComputeAndCompareR0<uint32>(&builder, x | y, {});
+ }
+ }
+}
+
+XLA_TEST_F(ScalarComputationsTest, NotBool) {
for (bool x : {false, true}) {
ComputationBuilder builder(client_, TestName());
builder.Not(builder.ConstantR0<bool>(x));
@@ -490,6 +534,24 @@ XLA_TEST_F(ScalarComputationsTest, BooleanNot) {
}
}
+XLA_TEST_F(ScalarComputationsTest, NotS32) {
+ for (int32 x : {-1, 0, 1}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Not(builder.ConstantR0<int32>(x));
+
+ ComputeAndCompareR0<int32>(&builder, ~x, {});
+ }
+}
+
+XLA_TEST_F(ScalarComputationsTest, NotU32) {
+ for (uint32 x : {0, 1, 2}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Not(builder.ConstantR0<uint32>(x));
+
+ ComputeAndCompareR0<uint32>(&builder, ~x, {});
+ }
+}
+
XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
ComputationBuilder builder(client_, TestName());
builder.Select(builder.ConstantR0<bool>(true), // The predicate.