diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-12 17:41:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-12 17:45:09 -0700 |
commit | 915a8ac568f0a67d6000ab70a665817deff7888c (patch) | |
tree | cbce859a80fba83b959798d5734fd302d7f96b36 | |
parent | 4b178957917d95fbe6305381764e39453f6bb8d0 (diff) |
[TF:XLA] Implement BitwiseAnd, BitwiseOr, and Invert operators.
PiperOrigin-RevId: 172038787
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tests/binary_ops_test.py | 15 | ||||
-rw-r--r-- | tensorflow/compiler/tests/randomized_tests.cc | 42 | ||||
-rw-r--r-- | tensorflow/compiler/tests/unary_ops_test.py | 9 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 23 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 203 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/scalar_computations_test.cc | 68 |
10 files changed, 349 insertions, 30 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 72a0360de2..0eed475140 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -103,6 +103,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", + "//tensorflow/python:bitwise_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 792c01327c..44b32b1668 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops @@ -45,6 +46,10 @@ class BinaryOpsTest(XLATestCase): equality_test = self.assertAllClose equality_test(result, expected, rtol=1e-3) + def _testSymmetricBinary(self, op, a, b, expected, equality_test=None): + self._testBinary(op, a, b, expected, equality_test) + self._testBinary(op, b, a, expected, equality_test) + def ListsAreClose(self, result, expected, rtol): """Tests closeness of two lists of floats.""" self.assertEqual(len(result), len(expected)) @@ -193,6 +198,16 @@ class BinaryOpsTest(XLATestCase): np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_and, + np.array([0b1, 0b101, 0b1000], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b0, 0b101, 0b1000], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_or, + np.array([0b1, 0b101, 0b1000], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b1, 0b101, 0b1001], dtype=dtype)) def testNumericOps(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index fef12d9397..56e10a1587 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -1168,6 +1168,28 @@ TEST_F(OpTest, BiasAddV1) { }); } +TEST_F(OpTest, BitwiseAnd) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseAnd") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, BitwiseOr) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseOr") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + TEST_F(OpTest, BroadcastArgs) { Repeatedly([this]() { // TODO(phawkins): only int32 seems to be implemented in Tensorflow. @@ -1729,6 +1751,14 @@ TEST_F(OpTest, GreaterEqual) { }); } +TEST_F(OpTest, Invert) { + Repeatedly([this]() { + DataType type = DT_INT32; + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Invert").RandomInput(type).Attr("T", type)); + }); +} + TEST_F(OpTest, L2Loss) { Repeatedly([this]() { DataType type = DT_FLOAT; @@ -1791,28 +1821,28 @@ TEST_F(OpTest, Log1p) { }); } -TEST_F(OpTest, BooleanAnd) { +TEST_F(OpTest, LogicalAnd) { Repeatedly([this]() { auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BooleanAnd") + OpTestBuilder("LogicalAnd") .RandomInput(DT_BOOL, dims.first) .RandomInput(DT_BOOL, dims.second)); }); } -TEST_F(OpTest, BooleanNot) { +TEST_F(OpTest, LogicalNot) { Repeatedly([this]() { return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BooleanNot").RandomInput(DT_BOOL)); + OpTestBuilder("LogicalNot").RandomInput(DT_BOOL)); }); } -TEST_F(OpTest, BooleanOr) { +TEST_F(OpTest, LogicalOr) { Repeatedly([this]() { auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BooleanOr") + OpTestBuilder("LogicalOr") .RandomInput(DT_BOOL, dims.first) .RandomInput(DT_BOOL, dims.second)); }); diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 6f19834160..71221b284d 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -26,6 +26,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -327,6 +328,13 @@ class UnaryOpsTest(XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype)) + def testIntOps(self): + for dtype in self.int_types: + self._assertOpOutputMatchesExpected( + bitwise_ops.invert, + np.array([0, -1, 1, 16, 42], dtype=dtype), + expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) + def testNumericOps(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( @@ -558,5 +566,6 @@ class UnaryOpsTest(XLATestCase): log_eps + ten, -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten, -log_eps + ten], dtype) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index a180f1e4d9..d635507989 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -96,6 +96,8 @@ static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b, XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 8f04fc94be..651bbe2b40 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -87,6 +87,7 @@ XLAJIT_MAKE_UNARY(Log, b->Log(x)); // TODO(b/34703906): use a more accurate implementation of log1p. XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); +XLAJIT_MAKE_UNARY(Invert, b->Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 3a8f70a8ef..fb4d233d04 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -126,14 +126,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( } case HloOpcode::kNegate: return ir_builder_->CreateNeg(operand_value); - case HloOpcode::kNot: - // It is not sufficient to just call CreateNot() here because a PRED is - // represented as an i8 and the truth value is stored only in the bottom - // bit. - return ir_builder_->CreateZExt( - ir_builder_->CreateNot(ir_builder_->CreateTrunc( - operand_value, ir_builder_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + case HloOpcode::kNot: { + auto type = op->shape().element_type(); + if (type == PRED) { + // It is not sufficient to just call CreateNot() here because a PRED + // is represented as an i8 and the truth value is stored only in the + // bottom bit. + return ir_builder_->CreateZExt( + ir_builder_->CreateNot(ir_builder_->CreateTrunc( + operand_value, ir_builder_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + } else if (primitive_util::IsIntegralType(type)) { + return ir_builder_->CreateNot(operand_value); + } + return Unimplemented("unary op Not is not defined for type '%d'", type); + } default: return Unimplemented("unary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a9f65331e2..a091a067c1 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -323,10 +323,11 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return arg; case UNOP_NOT: - if (arg.element_type() != PRED) { + if (arg.element_type() != PRED && + !primitive_util::IsIntegralType(arg.element_type())) { return InvalidArgument( - "expected pred element type in argument to logical-not operation; " - "got %s", + "expected pred or an integral element type in argument to not " + "operation; got %s", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -752,15 +753,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case BINOP_AND: case BINOP_OR: - if (lhs.element_type() != PRED) { + if (lhs.element_type() != PRED && + !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( - "expected pred element type in argument to logical and/or " - "operation; got %s", + "expected pred or integral type in argument to and/or operation; " + "got %s", PrimitiveType_Name(lhs.element_type()).c_str()); } return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: case BINOP_GE: case BINOP_GT: diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 08b39b6379..eb931dcff3 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -496,7 +496,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { ComputeAndCompareR1<uint32>(&builder, expected, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, BooleanAnd) { +XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<bool>({false, false, true, true}); auto b = builder.ConstantR1<bool>({false, true, false, true}); @@ -505,7 +505,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanAnd) { ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, BooleanAndZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<bool>({{false, false}, {true, true}}); + auto b = builder.ConstantR2<bool>({{false, true}, {false, true}}); + auto out = builder.And(a, b); + + Array2D<bool> expected_array({{false, false}, {false, true}}); + ComputeAndCompareR2<bool>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<bool>({}); auto b = builder.ConstantR1<bool>({}); @@ -514,7 +524,63 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanAndZeroElement) { ComputeAndCompareR1<bool>(&builder, {}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, BooleanOr) { +XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<int32>({0, -1, -8}); + auto b = builder.ConstantR1<int32>({5, -7, 12}); + auto out = builder.And(a, b); + + ComputeAndCompareR1<int32>(&builder, {0, -7, 8}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<int32>({{0, -5}, {-1, 5}}); + auto b = builder.ConstantR2<int32>({{1, -6}, {4, 5}}); + auto out = builder.And(a, b); + + Array2D<int32> expected_array({{0, -6}, {4, 5}}); + ComputeAndCompareR2<int32>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<int32>({}); + auto b = builder.ConstantR1<int32>({}); + auto out = builder.And(a, b); + + ComputeAndCompareR1<int32>(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<int32>({0, 1, 8}); + auto b = builder.ConstantR1<int32>({5, 7, 12}); + auto out = builder.And(a, b); + + ComputeAndCompareR1<int32>(&builder, {0, 1, 8}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<uint32>({{0, 1}, {3, 8}}); + auto b = builder.ConstantR2<uint32>({{1, 0}, {7, 6}}); + auto out = builder.And(a, b); + + Array2D<uint32> expected_array({{0, 0}, {3, 0}}); + ComputeAndCompareR2<uint32>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<uint32>({}); + auto b = builder.ConstantR1<uint32>({}); + auto out = builder.And(a, b); + + ComputeAndCompareR1<uint32>(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<bool>({false, false, true, true}); auto b = builder.ConstantR1<bool>({false, true, false, true}); @@ -523,7 +589,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanOr) { ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, BooleanOrZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<bool>({{false, false}, {true, true}}); + auto b = builder.ConstantR2<bool>({{false, true}, {false, true}}); + auto out = builder.Or(a, b); + + Array2D<bool> expected_array({{false, true}, {true, true}}); + ComputeAndCompareR2<bool>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<bool>({}); auto b = builder.ConstantR1<bool>({}); @@ -532,7 +608,63 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanOrZeroElement) { ComputeAndCompareR1<bool>(&builder, {}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, BooleanNot) { +XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<int32>({0, -1, 8}); + auto b = builder.ConstantR1<int32>({5, -7, 4}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<int32>({{0, -1}, {8, 8}}); + auto b = builder.ConstantR2<int32>({{5, -7}, {4, 1}}); + auto out = builder.Or(a, b); + + Array2D<int32> expected_array({{5, -1}, {12, 9}}); + ComputeAndCompareR2<int32>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<int32>({}); + auto b = builder.ConstantR1<int32>({}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1<int32>(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<uint32>({0, 1, 8}); + auto b = builder.ConstantR1<uint32>({5, 7, 4}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<uint32>({{0, 1}, {8, 8}}); + auto b = builder.ConstantR2<uint32>({{5, 7}, {4, 1}}); + auto out = builder.Or(a, b); + + Array2D<uint32> expected_array({{5, 7}, {12, 9}}); + ComputeAndCompareR2<uint32>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<uint32>({}); + auto b = builder.ConstantR1<uint32>({}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1<uint32>(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<bool>({false, true, true, false}); auto out = builder.Not(a); @@ -540,7 +672,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanNot) { ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, BooleanNotZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<bool>({{false, true}, {true, false}}); + auto out = builder.Not(a); + + Array2D<bool> expected_array({{true, false}, {false, true}}); + ComputeAndCompareR2<bool>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<bool>({}); auto out = builder.Not(a); @@ -548,6 +689,56 @@ XLA_TEST_F(ArrayElementwiseOpTest, BooleanNotZeroElement) { ComputeAndCompareR1<bool>(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<int32>({-1, 0, 1}); + auto out = builder.Not(a); + + ComputeAndCompareR1<int32>(&builder, {0, -1, -2}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<int32>({{-1, 0}, {1, 8}}); + auto out = builder.Not(a); + + Array2D<int32> expected_array({{0, -1}, {-2, -9}}); + ComputeAndCompareR2<int32>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<int32>({}); + auto out = builder.Not(a); + + ComputeAndCompareR1<int32>(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<uint32>({0, 4294967295}); + auto out = builder.Not(a); + + ComputeAndCompareR1<uint32>(&builder, {4294967295, 0}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2<uint32>({{0, 4294967295}, {1, 4294967294}}); + auto out = builder.Not(a); + + Array2D<uint32> expected_array({{4294967295, 0}, {4294967294, 1}}); + ComputeAndCompareR2<uint32>(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<uint32>({}); + auto out = builder.Not(a); + + ComputeAndCompareR1<uint32>(&builder, {}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index da84d185ca..b5e7570778 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -459,7 +459,7 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { ComputeAndCompareR0<uint32>(&builder, 2, {}); } -XLA_TEST_F(ScalarComputationsTest, BooleanAnd) { +XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { ComputationBuilder builder(client_, TestName()); @@ -470,7 +470,29 @@ XLA_TEST_F(ScalarComputationsTest, BooleanAnd) { } } -XLA_TEST_F(ScalarComputationsTest, BooleanOr) { +XLA_TEST_F(ScalarComputationsTest, AndS32) { + for (int32 x : {0, 8}) { + for (int32 y : {1, -16}) { + ComputationBuilder builder(client_, TestName()); + builder.And(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y)); + + ComputeAndCompareR0<int32>(&builder, x & y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, AndU32) { + for (uint32 x : {0, 8}) { + for (uint32 y : {1, 16}) { + ComputationBuilder builder(client_, TestName()); + builder.And(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y)); + + ComputeAndCompareR0<uint32>(&builder, x & y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { ComputationBuilder builder(client_, TestName()); @@ -481,7 +503,29 @@ XLA_TEST_F(ScalarComputationsTest, BooleanOr) { } } -XLA_TEST_F(ScalarComputationsTest, BooleanNot) { +XLA_TEST_F(ScalarComputationsTest, OrS32) { + for (int32 x : {0, 8}) { + for (int32 y : {1, -16}) { + ComputationBuilder builder(client_, TestName()); + builder.Or(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y)); + + ComputeAndCompareR0<int32>(&builder, x | y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, OrU32) { + for (uint32 x : {0, 8}) { + for (uint32 y : {1, 16}) { + ComputationBuilder builder(client_, TestName()); + builder.Or(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y)); + + ComputeAndCompareR0<uint32>(&builder, x | y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { ComputationBuilder builder(client_, TestName()); builder.Not(builder.ConstantR0<bool>(x)); @@ -490,6 +534,24 @@ XLA_TEST_F(ScalarComputationsTest, BooleanNot) { } } +XLA_TEST_F(ScalarComputationsTest, NotS32) { + for (int32 x : {-1, 0, 1}) { + ComputationBuilder builder(client_, TestName()); + builder.Not(builder.ConstantR0<int32>(x)); + + ComputeAndCompareR0<int32>(&builder, ~x, {}); + } +} + +XLA_TEST_F(ScalarComputationsTest, NotU32) { + for (uint32 x : {0, 1, 2}) { + ComputationBuilder builder(client_, TestName()); + builder.Not(builder.ConstantR0<uint32>(x)); + + ComputeAndCompareR0<uint32>(&builder, ~x, {}); + } +} + XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { ComputationBuilder builder(client_, TestName()); builder.Select(builder.ConstantR0<bool>(true), // The predicate. |