diff options
author | 2017-11-13 11:34:15 -0800 | |
---|---|---|
committer | 2017-11-13 11:40:57 -0800 | |
commit | 58f7858601b72aa3c5854571f2152b91d1795e29 (patch) | |
tree | 214e1ff498ecc21573dbe444fc5a0142915152af /tensorflow/compiler | |
parent | 659d8cbc3aaffc0249afee1ec437639beda8d243 (diff) |
[TF:XLA] Adding test coverage for more C64 operations, and ensuring they pass.
Included here:
- reduction ops (reduce_sum, reduce_prod)
- unaries: tanh, sigmoid (currently GPU only)
- binaries: pow (currently GPU only)
PiperOrigin-RevId: 175562417
Diffstat (limited to 'tensorflow/compiler')
16 files changed, 623 insertions, 180 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index d412c572ae..654dc15e86 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -366,16 +366,52 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._real_div, - np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype), - np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype), + np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), + np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), + expected=np.array( + [1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2], + dtype=dtype)) + + # Test inf/nan scenarios. + self._testBinary( + gen_math_ops._real_div, + np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype), + np.array([0, 0, 0, 0, 0, 0], dtype=dtype), expected=np.array( [ - 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2, - float("inf") + dtype(1 + 1j) / 0, + dtype(1) / 0, + dtype(1j) / 0, + dtype(-1) / 0, + dtype(-1j) / 0, + dtype(1 - 1j) / 0 ], dtype=dtype)) - # TODO(b/65408531): support+test pow for cplx + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._testBinary( + math_ops.pow, + dtype(3 + 2j), + dtype(4 - 5j), + expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) + self._testBinary( # empty rhs + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], dtype=dtype)) + self._testBinary( # to zero power + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[1, 2], dtype=dtype), + expected=np.ones(shape=[1, 2], dtype=dtype)) + lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) + rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) + scalar = dtype(2 + 2j) + self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) + self._testBinary( + math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) + self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) @@ -385,7 +421,9 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) - # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow) + if atan2_supported: + self._testBinary( + gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index efda2cc207..965fdf684b 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -67,25 +67,37 @@ class ReduceOpsTest(XLATestCase): np.arange(-10, -4).reshape(2, 3), np.arange(-4, 2).reshape(2, 3), ] - NONEMPTY_FLOAT_DATA = [ - np.arange(1, 7).reshape(2, 3), - np.arange(-10, -4).reshape(2, 3), - np.arange(-4, 2).reshape(2, 3), + COMPLEX_DATA = [ + np.zeros(shape=(2, 0)).astype(np.complex64), + np.zeros(shape=(0, 30)).astype(np.complex64), + np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3), + np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), + np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), ] + NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0] + NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] BOOL_DATA = [ np.array([], dtype=np.bool).reshape(2, 0), np.array([], dtype=np.bool).reshape(0, 3), np.array([[False, True, False], [True, True, False]]), ] - def testReduceSum(self): + def testReduceSumF32(self): self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.FLOAT_DATA) - def testReduceProd(self): + def testReduceSumC64(self): + self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, + self.COMPLEX_DATA) + + def testReduceProdF32(self): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, self.FLOAT_DATA) + def testReduceProdC64(self): + self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, + self.COMPLEX_DATA) + def testReduceMin(self): def reference_min(inp, axis): @@ -108,12 +120,16 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_max, reference_max, np.float32, self.FLOAT_DATA) - def testReduceMean(self): + def testReduceMeanF32(self): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, self.NONEMPTY_FLOAT_DATA) + def testReduceMeanC64(self): + self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, + self.NONEMPTY_COMPLEX_DATA) + def testReduceAll(self): self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 76644380bd..a9a3f4f97f 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -330,13 +330,23 @@ class UnaryOpsTest(XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - # TODO(b/65408531): math_ops.acosh (needs pow) - # TODO(b/65408531): math_ops.asinh (needs pow) # TODO(b/65408531): Wider support for log (needs atan2). atan2_supported = self.device == "XLA_GPU" if atan2_supported: self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arccosh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arcsinh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( math_ops.atanh, np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), expected=np.arctanh( @@ -392,19 +402,26 @@ class UnaryOpsTest(XLATestCase): expected=np.log1p( np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) - # TODO(b/34703906): math_ops.rsqrt (needs pow) + val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.rsqrt, val, expected=1 / np.sqrt(val)) - # TODO(b/34703906): math_ops.sigmoid (needs tanh) + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) - # TODO(b/34703906): math_ops.sqrt (needs pow) + self._assertOpOutputMatchesExpected( + math_ops.sqrt, val, expected=np.sqrt(val)) + + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.tan, np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) - # TODO(b/34703906): math_ops.tanh (as itself) - ctypes = {np.complex64: np.float32} self._assertOpOutputMatchesExpected( math_ops.abs, diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 35fe0d1a51..5c9b29f6e2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -135,7 +135,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override; + Status HandleComplex(HloInstruction* complex) override; + Status HandleReal(HloInstruction* real) override; + Status HandleImag(HloInstruction* imag) override; Status HandleConvolution(HloInstruction* convolution) override; @@ -947,6 +950,18 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { return Status::OK(); } +// Complex(Real(c), Imag(c)) -> c +Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { + auto real = complex->mutable_operand(0); + auto imag = complex->mutable_operand(1); + if (real->opcode() == HloOpcode::kReal && + imag->opcode() == HloOpcode::kImag && + real->operand(0) == imag->operand(0)) { + return ReplaceInstruction(complex, real->mutable_operand(0)); + } + return Status::OK(); +} + // Real(Complex(r, i)) -> r Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { auto operand = real->mutable_operand(0); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index c06e330bc1..620f0a54fa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -371,6 +371,31 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { EXPECT_EQ(root, param0); } +// Test that complex(real(c), imag(c)) is simplified to c. +TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2c64, "param0")); + HloInstruction* real = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0)); + HloInstruction* imag = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0)); + HloInstruction* cplx = builder.AddInstruction( + HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, cplx); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Test that real(complex(r,i)) is simplified to r. TEST_F(AlgebraicSimplifierTest, RealOfComplex) { Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index a945657712..606868034a 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -93,14 +93,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateSIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateUIToFP(operand_value, to_ir_component_type), nullptr); @@ -178,9 +178,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); if (from_type == to_component_type) { - return ComposeComplex(op, operand_value, nullptr); + return EmitComposeComplex(op, operand_value, nullptr); } - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFPCast( operand_value, @@ -269,15 +269,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { // TODO(b/65209142): Angle/Log require atan2. - // case HloOpcode::kAngle: // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -291,24 +284,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return ComposeComplex( + return EmitComposeComplex( op, - ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type), - ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type)); + ir_builder_->CreateFPCast(EmitExtractReal(operand_value), + to_ir_component_type), + ir_builder_->CreateFPCast(EmitExtractImag(operand_value), + to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) auto exp_a = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::exp, {real(operand_value)}, - {real(operand_value)->getType()}, ir_builder_); + llvm::Intrinsic::exp, {EmitExtractReal(operand_value)}, + {EmitExtractReal(operand_value)->getType()}, ir_builder_); auto cos_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::cos, {imag(operand_value)}, - {imag(operand_value)->getType()}, ir_builder_); + llvm::Intrinsic::cos, {EmitExtractImag(operand_value)}, + {EmitExtractImag(operand_value)->getType()}, ir_builder_); auto sin_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sin, {imag(operand_value)}, - {imag(operand_value)->getType()}, ir_builder_); - return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + llvm::Intrinsic::sin, {EmitExtractImag(operand_value)}, + {EmitExtractImag(operand_value)->getType()}, ir_builder_); + return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) @@ -318,8 +313,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( // cos(-x) = cos(x) and sin(-x) = -sin(x), so // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i)) // = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); auto type = a->getType(); auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, {type}, ir_builder_); @@ -331,7 +326,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( {type}, ir_builder_); auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, {type}, ir_builder_); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), @@ -348,8 +343,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( // cos(-x) = cos(x) and sin(-x) = -sin(x), so // = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a))) // = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); auto type = a->getType(); auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, {type}, ir_builder_); @@ -361,7 +356,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( {type}, ir_builder_); auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, {type}, ir_builder_); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), @@ -370,33 +365,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kAbs: { auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(operand_value), real(operand_value)), - ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + ir_builder_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + ir_builder_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); } case HloOpcode::kSign: { // Sign(c) = c / |c| auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(operand_value), real(operand_value)), - ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + ir_builder_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + ir_builder_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero); return ir_builder_->CreateSelect( - oeq, ComposeComplex(op, zero, zero), - ComposeComplex( - op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs), - ir_builder_->CreateFDiv(imag(operand_value), cplx_abs))); + oeq, EmitComposeComplex(op, zero, zero), + EmitComposeComplex( + op, + ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), + ir_builder_->CreateFDiv(EmitExtractImag(operand_value), + cplx_abs))); } case HloOpcode::kNegate: - return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)), - ir_builder_->CreateFNeg(imag(operand_value))); + return EmitComposeComplex( + op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)), + ir_builder_->CreateFNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: - return real(operand_value); + return EmitExtractReal(operand_value); case HloOpcode::kImag: - return imag(operand_value); + return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -424,7 +426,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp( switch (op->opcode()) { // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support case HloOpcode::kComplex: - return ComposeComplex(op, lhs_value, rhs_value); + return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: return ir_builder_->CreateFAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: @@ -479,54 +481,66 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp( StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { case HloOpcode::kAdd: - return ComposeComplex( - op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value))); + return EmitComposeComplex( + op, + ir_builder_->CreateFAdd(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFAdd(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return ComposeComplex( - op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value))); + return EmitComposeComplex( + op, + ir_builder_->CreateFSub(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFSub(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(rhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(rhs_value), + EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = + ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = + ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero); return ir_builder_->CreateSelect( - oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero), - ComposeComplex( + oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), + EmitComposeComplex( op, ir_builder_->CreateFDiv( ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), - imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq), ir_builder_->CreateFDiv( ir_builder_->CreateFSub( - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(real(lhs_value), - imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered @@ -538,16 +552,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp( // matches C++'s semantics. case HloOpcode::kEq: return ir_builder_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value), - real(rhs_value), ir_builder_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value), - imag(rhs_value), ir_builder_)); + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), ir_builder_)); case HloOpcode::kNe: return ir_builder_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value), - real(rhs_value), ir_builder_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value), - imag(rhs_value), ir_builder_)); + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), ir_builder_)); // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic // case HloOpcode::kPower: @@ -1565,25 +1583,25 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value))); next_accumulator = ir_builder_->CreateInsertValue( current_accumulator, - ir_builder_->CreateFAdd(real(current_accumulator), product_real), + ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), + product_real), {0}); next_accumulator = ir_builder_->CreateInsertValue( next_accumulator, - ir_builder_->CreateFAdd(imag(current_accumulator), product_imag), + ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), + product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { next_accumulator = ir_builder_->CreateFAdd( @@ -1607,9 +1625,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } } -llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op, - llvm::Value* real, - llvm::Value* imag) const { +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { + return ir_builder_->CreateExtractValue(value, {0}); +} + +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { + return ir_builder_->CreateExtractValue(value, {1}); +} + +llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, + llvm::Value* real, + llvm::Value* imag) const { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto complex = ir_builder_->CreateInsertValue( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 9d32436e38..cccb498f82 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -95,6 +95,13 @@ class ElementalIrEmitter { virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value) const; + virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + + // Composes a complex struct. imag may be nullptr for simple cast operations. + llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, + llvm::Value* imag) const; + // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its // `operand_no`-th operand. @@ -117,11 +124,6 @@ class ElementalIrEmitter { // compiled executable outside of the HLO code itself. const HloModuleConfig& hlo_module_config_; - protected: - // Composes a complex struct. imag may be nullptr for simple cast operations. - llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; - private: // Returns a ElementGenerator for a RNG HloInstruction. llvm_ir::ElementGenerator MakeRngElementGenerator( diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1b94499bc6..6bf00cfb8a 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -230,6 +230,66 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + TF_RET_CHECK(primitive_util::IsComplexType(input_type)); + PrimitiveType component_type = + primitive_util::ComplexComponentType(input_type); + switch (op->opcode()) { + case HloOpcode::kPower: { + // (a+bi)^(c+di) = + // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), + // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); + auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto half_c = ir_builder_->CreateFMul(one_half, c); + + TF_ASSIGN_OR_RETURN( + auto aa_p_bb_to_half_c, + EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c}, + {component_type, component_type}, + component_type)); + auto neg_d = ir_builder_->CreateFNeg(d); + TF_ASSIGN_OR_RETURN( + auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a}, + {component_type, component_type}, + component_type)); + auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN( + auto e_to_neg_d_arg_lhs, + EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type}, + component_type)); + auto coeff = + ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN( + auto ln_aa_p_bb, + EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type}, + component_type)); + auto half_d = ir_builder_->CreateFMul(one_half, d); + auto q = + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), + ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN( + auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type}, + component_type)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), + ir_builder_->CreateFMul(coeff, sin_q)); + } + default: + return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value); + } +} + StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { PrimitiveType input_type = op->operand(0)->shape().element_type(); @@ -237,18 +297,12 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp( primitive_util::IsComplexType(input_type) ? primitive_util::ComplexComponentType(input_type) : input_type; - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { case HloOpcode::kLog: { // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), ir_builder_->CreateFMul(b, b)); @@ -261,34 +315,33 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp( {component_type, component_type}, component_type)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } - // TODO(b/65408531): Implement kPower on GPU, where atan2 is available. - // case HloOpcode::kPower: - // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di)) case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto b = imag(operand_value); + auto b = EmitExtractImag(operand_value); TF_ASSIGN_OR_RETURN( - auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)}, - {component_type}, component_type)); + auto exp_a, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, component_type)); - return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kCos: { // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = real(operand_value); + auto a = EmitExtractReal(operand_value); auto llvm_ty = a->getType(); TF_ASSIGN_OR_RETURN( - auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, - {component_type}, component_type)); + auto exp_b, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, component_type)); @@ -299,7 +352,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), @@ -309,11 +362,12 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp( case HloOpcode::kSin: { // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = real(operand_value); + auto a = EmitExtractReal(operand_value); auto llvm_ty = a->getType(); TF_ASSIGN_OR_RETURN( - auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, - {component_type}, component_type)); + auto exp_b, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, component_type)); @@ -324,13 +378,71 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); } + case HloOpcode::kTanh: { + /* + tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) + e^(a+bi) = e^a*(cos(b)+sin(b)i) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) + cos(b)=cos(-b), sin(-b)=-sin(b) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) + =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / + (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / + (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) + This is a complex division, so we can multiply by denom_conj/denom_conj + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * + (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + */ + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, + component_type)); + auto exp_neg_a = ir_builder_->CreateFDiv( + llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(exp_a, exp_a), + ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); + auto real_num = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_a_plus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + auto exp_a_minus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = ir_builder_->CreateFMul( + cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, + exp_a_minus_exp_neg_a_sq)); + auto denom = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), + ir_builder_->CreateFDiv(imag_num, denom)); + } default: return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 3defa1b696..6a537d0152 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -61,6 +61,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; + StatusOr<llvm::Value*> EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const override; + StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 9d55c7859d..af2a92e11e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -293,29 +293,30 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( computation, {old_output_location, source_address}, new_output_location)); // (old_output, success) = atomicCAS(output_address, old_output, new_output); - llvm::Type* element_int_ir_type = - ir_builder_.getIntNTy(element_ir_type->getScalarSizeInBits()); - // cmpxchg accetps integer only, so we bitcast the operands (old_output and - // new_output) to integers of the same bit width, and bitcast the result - // back to the original element type. - llvm::Value* old_output = - ir_builder_.CreateLoad(old_output_location, "old_output"); - llvm::Value* new_output = - ir_builder_.CreateLoad(new_output_location, "new_output"); + int num_bits = llvm_ir::GetSizeInBits(element_ir_type); + llvm::Type* element_int_ir_type = ir_builder_.getIntNTy(num_bits); + // cmpxchg accepts integer only, and bitcast refuses to operate on aggregate + // types, so we bitcast load and store addresses to intN* of the same bit + // width. + llvm::Value* old_output = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(old_output_location, + element_int_ir_type->getPointerTo()), + "old_output"); + llvm::Value* new_output = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(new_output_location, + element_int_ir_type->getPointerTo()), + "new_output"); llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( ir_builder_.CreateBitCast(output_address, element_int_ir_type->getPointerTo()), - ir_builder_.CreateBitCast(old_output, element_int_ir_type), - ir_builder_.CreateBitCast(new_output, element_int_ir_type), - llvm::AtomicOrdering::SequentiallyConsistent, + old_output, new_output, llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); // cmpxchg returns a pair. The first element is the original value at // output_address and the second element is whether the swap is successful. ir_builder_.CreateStore( - ir_builder_.CreateBitCast( - ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), - element_ir_type), - old_output_location); + ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), + ir_builder_.CreateBitCast(old_output_location, + element_int_ir_type->getPointerTo())); ir_builder_.CreateCondBr( ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 7b4662fc80..db78f4b84d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1081,16 +1081,25 @@ Status IrEmitterUnnested::EmitRowReduction( // from the warp. llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &ir_builder_); + int bit_width = llvm_ir::GetSizeInBits(element_ir_type); + // bitcast cannot be applied to aggregate types (even packed ones), so we + // instead bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() + ? ir_builder_.getIntNTy(bit_width) + : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - partial_reduction_result_address, "partial_reduction_result"); + ir_builder_.CreateBitCast(partial_reduction_result_address, + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); ir_builder_.CreateStore( EmitShuffleDown(partial_reduction_result, ir_builder_.getInt32(shuffle_distance), &ir_builder_), - result_from_other_lane); + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducer, {partial_reduction_result_address, result_from_other_lane}, partial_reduction_result_address)); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index d95409e399..086c8dae9e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -163,8 +163,9 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the // imaginary part of z. return llvm::StructType::create( - "complex64", llvm::Type::getFloatTy(module->getContext()), - llvm::Type::getFloatTy(module->getContext())); + {llvm::Type::getFloatTy(module->getContext()), + llvm::Type::getFloatTy(module->getContext())}, + "complex64", /*isPacked=*/true); } return cplx_t; } @@ -178,6 +179,21 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } } +int GetSizeInBits(llvm::Type* type) { + const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type); + if (struct_ty) { + CHECK(struct_ty->isPacked()); + int bits = 0; + for (auto element_type : struct_ty->elements()) { + bits += GetSizeInBits(element_type); + } + return bits; + } + int bits = type->getPrimitiveSizeInBits(); + CHECK_GT(bits, 0) << "type is not sized"; + return bits; +} + llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); if (ShapeUtil::IsTuple(shape)) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index f70d9f88b3..063ead2b64 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -129,6 +129,9 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, llvm::Module* module); +// Returns the type size in bits. If "type" is a struct, it must be packed. +int GetSizeInBits(llvm::Type* type); + // Returns the LLVM type which represents the given XLA shape. For example, // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]]. llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0b700fbb6f..c6e8b24d12 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -82,6 +82,25 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { {}); } +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>({}); + auto result = builder.Neg(a); + + ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>( + {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); + auto result = builder.Neg(a); + + ComputeAndCompareR1<complex64>( + &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<float>({}); @@ -145,6 +164,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>( + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); + auto b = builder.ConstantR1<complex64>( + {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); + auto add = builder.Add(a, b); + + ComputeAndCompareR1<complex64>( + &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>({}); + auto b = builder.ConstantR1<complex64>({}); + auto add = builder.Add(a, b); + + ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); @@ -222,6 +263,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { ComputeAndCompareR1<int32>(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>( + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); + auto b = builder.ConstantR1<complex64>( + {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); + auto add = builder.Sub(a, b); + + ComputeAndCompareR1<complex64>( + &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>({}); + auto b = builder.ConstantR1<complex64>({}); + auto add = builder.Sub(a, b); + + ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); @@ -385,6 +448,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } } +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>( + {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); + auto b = builder.ConstantR1<complex64>( + {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); + auto div = builder.Div(a, b); + + ComputeAndCompareR1<complex64>( + &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>({}); + auto b = builder.ConstantR1<complex64>({}); + auto div = builder.Div(a, b); + + ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<float>( @@ -496,6 +580,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { ComputeAndCompareR1<uint32>(&builder, expected, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>( + {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); + auto b = builder.ConstantR1<complex64>( + {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1<complex64>( + &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<complex64>({}); + auto b = builder.ConstantR1<complex64>({}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<bool>({false, false, true, true}); @@ -886,6 +992,53 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { ComputeAndCompareR1<bool>(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { + SetFastMathDisabled(true); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1<complex64>({}); + auto rhs = builder.ConstantR1<complex64>({}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1<bool>(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { + // Disable fast-math because we're operating on NaNs. + SetFastMathDisabled(true); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + auto compare = builder.Ne(lhs, rhs); + + ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index b578667735..1dc274c591 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -332,8 +332,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( ComputationBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { static_assert(std::is_same<NativeT, float>::value || - std::is_same<NativeT, double>::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same<NativeT, double>::value || + std::is_same<NativeT, complex64>::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = Literal::CreateR0<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -355,8 +356,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { static_assert(std::is_same<NativeT, float>::value || - std::is_same<NativeT, double>::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same<NativeT, double>::value || + std::is_same<NativeT, complex64>::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = Literal::CreateR1<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index b72dd2707c..bfb04fd9f9 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -386,35 +386,39 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot<float>(false, false); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = true; - TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot<float>(false, true); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = false; - TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot<float>(true, false); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot<float>(true, true); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { TestNonsquareMatrixDot<double>(); } -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) { - TestNonsquareMatrixDot<complex64>(); +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { + TestNonsquareMatrixDot<complex64>(false, false); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { + TestNonsquareMatrixDot<complex64>(false, true); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { + TestNonsquareMatrixDot<complex64>(true, false); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { + TestNonsquareMatrixDot<complex64>(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { |