diff options
8 files changed, 330 insertions, 372 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 905dd9fc7b..65706b35d6 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -94,14 +94,12 @@ class BinaryOpsTest(XLATestCase): dtype(4), expected=np.array([[16], [81]], dtype=dtype)) - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._testBinary( - math_ops.atan2, - np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), - np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), - expected=np.array( - [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) + self._testBinary( + math_ops.atan2, + np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), + np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), + expected=np.array( + [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) self._testBinary( gen_math_ops._reciprocal_grad, @@ -388,30 +386,28 @@ class BinaryOpsTest(XLATestCase): ], dtype=dtype)) - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._testBinary( - math_ops.pow, - dtype(3 + 2j), - dtype(4 - 5j), - expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) - self._testBinary( # empty rhs - math_ops.pow, - np.array([1 + 2j, 2 - 3j], dtype=dtype), - np.zeros(shape=[0, 2], dtype=dtype), - expected=np.zeros(shape=[0, 2], dtype=dtype)) - self._testBinary( # to zero power - math_ops.pow, - np.array([1 + 2j, 2 - 3j], dtype=dtype), - np.zeros(shape=[1, 2], dtype=dtype), - expected=np.ones(shape=[1, 2], dtype=dtype)) - lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) - rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) - scalar = dtype(2 + 2j) - self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) - self._testBinary( - math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) - self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) + self._testBinary( + math_ops.pow, + dtype(3 + 2j), + dtype(4 - 5j), + expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) + self._testBinary( # empty rhs + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], dtype=dtype)) + self._testBinary( # to zero power + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[1, 2], dtype=dtype), + expected=np.ones(shape=[1, 2], dtype=dtype)) + lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) + rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) + scalar = dtype(2 + 2j) + self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) + self._testBinary( + math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) + self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) @@ -421,9 +417,8 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) - if atan2_supported: - self._testBinary( - gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) + self._testBinary( + gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ecba5a4fb0..0a6fe04d3c 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -363,26 +363,23 @@ class UnaryOpsTest(XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - # TODO(b/65408531): Wider support for log (needs atan2). - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._assertOpOutputMatchesExpected( - math_ops.acosh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arccosh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arccosh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) - self._assertOpOutputMatchesExpected( - math_ops.asinh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arcsinh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arcsinh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) - self._assertOpOutputMatchesExpected( - math_ops.atanh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arctanh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.atanh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arctanh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.cosh, @@ -409,11 +406,10 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2j, 2 + 3j]], dtype=dtype), expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype)) - if atan2_supported: - self._assertOpOutputMatchesExpected( - math_ops.log, - np.array([[5j, 3 - 2j]], dtype=dtype), - expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.log, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.sin, @@ -427,27 +423,26 @@ class UnaryOpsTest(XLATestCase): # TODO(b/34703906): improve log1p implementation and make tolerance # tighter. - if atan2_supported: # TODO(b/34703906): log support - self._assertOpOutputMatchesExpected( - math_ops.log1p, - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), - expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.log1p, + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), + expected=np.log1p( + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) - val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) - self._assertOpOutputMatchesExpected( - math_ops.rsqrt, val, expected=1 / np.sqrt(val)) + val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.rsqrt, val, expected=1 / np.sqrt(val)) - self._assertOpOutputMatchesExpected( - math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) - self._assertOpOutputMatchesExpected( - math_ops.sqrt, val, expected=np.sqrt(val)) + self._assertOpOutputMatchesExpected( + math_ops.sqrt, val, expected=np.sqrt(val)) - self._assertOpOutputMatchesExpected( - math_ops.tanh, - np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), - expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.tan, @@ -480,12 +475,10 @@ class UnaryOpsTest(XLATestCase): np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype), expected=np.array([[1, 1], [1, 1]], dtype=dtype)) - if atan2_supported: # TODO(b/34703906): atan2 support - self._assertOpOutputMatchesExpected( - math_ops.angle, - np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), - expected=np.angle( - np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.angle, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.angle(np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.conj, diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index ba693ec89a..ebd96c4c42 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -44,15 +44,11 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp( default: return Unimplemented("tanh"); } - // Create function type for the function. - llvm::FunctionType* function_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - /*isVarArg=*/false); // Create function declaration for 'tanhf'. llvm::Function* function = llvm::cast<llvm::Function>(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), function_type)); + llvm_ir::AsStringRef(function_name), operand_value->getType(), + operand_value->getType())); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); @@ -64,6 +60,31 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { + string function_name; + switch (prim_type) { + case F32: + function_name = "atan2f"; + break; + case F64: + function_name = "atan2"; + break; + default: + return Unimplemented("atan2"); + } + // Create function declaration for 'atan2'. + llvm::Function* function = + llvm::cast<llvm::Function>(module_->getOrInsertFunction( + llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(), + rhs->getType())); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create instruction to call 'atan2'. + return ir_builder_->CreateCall(function, {lhs, rhs}); +} + llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 7e9f27befb..4446dfd282 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -41,6 +41,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { protected: StatusOr<llvm::Value*> EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; + StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 7e88bbd631..3792929432 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -404,21 +404,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( primitive_util::BitWidth(to_type)); } case HloOpcode::kExp: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitExp(op->shape().element_type(), operand_value); case HloOpcode::kLog: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitLog(op->shape().element_type(), operand_value); case HloOpcode::kCos: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitSin(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, @@ -469,9 +461,25 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + PrimitiveType component_type = + primitive_util::IsComplexType(input_type) + ? primitive_util::ComplexComponentType(input_type) + : input_type; switch (op->opcode()) { - // TODO(b/65209142): Angle/Log require atan2. - // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + case HloOpcode::kLog: { + // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -493,15 +501,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto exp_a = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::exp, {EmitExtractReal(operand_value)}, - {EmitExtractReal(operand_value)->getType()}, ir_builder_); - auto cos_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::cos, {EmitExtractImag(operand_value)}, - {EmitExtractImag(operand_value)->getType()}, ir_builder_); - auto sin_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sin, {EmitExtractImag(operand_value)}, - {EmitExtractImag(operand_value)->getType()}, ir_builder_); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } @@ -516,16 +521,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); auto type = a->getType(); - auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); auto half_exp_b = ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); - auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, - {type}, ir_builder_); - auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); + TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( op, ir_builder_->CreateFMul( @@ -546,16 +548,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); auto type = a->getType(); - auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); auto half_exp_b = ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); - auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, - {type}, ir_builder_); - auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); + TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( op, ir_builder_->CreateFMul( @@ -563,6 +562,58 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); } + case HloOpcode::kTanh: { + /* + tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) + e^(a+bi) = e^a*(cos(b)+sin(b)i) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) + cos(b)=cos(-b), sin(-b)=-sin(b) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) + =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / + (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / + (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) + This is a complex division, so we can multiply by denom_conj/denom_conj + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * + (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + */ + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); + TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); + TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); + auto exp_neg_a = ir_builder_->CreateFDiv( + llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(exp_a, exp_a), + ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); + auto real_num = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_a_plus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + auto exp_a_minus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = ir_builder_->CreateFMul( + cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, + exp_a_minus_exp_neg_a_sq)); + auto denom = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), + ir_builder_->CreateFDiv(imag_num, denom)); + } case HloOpcode::kAbs: { auto sum_sq = ir_builder_->CreateFAdd( ir_builder_->CreateFMul(EmitExtractReal(operand_value), @@ -625,7 +676,6 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { switch (op->opcode()) { - // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: @@ -669,10 +719,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp( case HloOpcode::kMinimum: return EmitFloatMin(lhs_value, rhs_value); case HloOpcode::kPower: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, - {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder_); - + return EmitPow(op->shape().element_type(), lhs_value, rhs_value); + case HloOpcode::kAtan2: + return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -768,9 +817,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp( EmitExtractImag(lhs_value), EmitExtractImag(rhs_value), ir_builder_)); - // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic - // case HloOpcode::kPower: - // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2) + case HloOpcode::kPower: { + // (a+bi)^(c+di) = + // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), + // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) + PrimitiveType component_type = + primitive_util::ComplexComponentType(op->shape().element_type()); + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); + auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto half_c = ir_builder_->CreateFMul(one_half, c); + + TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, + EmitPow(component_type, aa_p_bb, half_c)); + auto neg_d = ir_builder_->CreateFNeg(d); + TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); + auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, + EmitExp(component_type, neg_d_arg_lhs)); + auto coeff = + ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); + auto half_d = ir_builder_->CreateFMul(one_half, d); + auto q = + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), + ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); + TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), + ir_builder_->CreateFMul(coeff, sin_q)); + } default: return Unimplemented("binary complex op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -873,6 +953,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv( return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } +StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, + {value->getType()}, ir_builder_); +} + +StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, + {value->getType()}, ir_builder_); +} + +StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, + {value->getType()}, ir_builder_); +} + +StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, + {value->getType()}, ir_builder_); +} + +StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, + {lhs->getType()}, ir_builder_); +} + +StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return Unimplemented("atan2"); +} + StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision( const HloInstruction* hlo, llvm::Value* x) const { if (hlo->operand(0)->shape().element_type() != F32) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index cccb498f82..1a48eb5fcb 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -39,7 +39,7 @@ class ElementalIrEmitter { module_(module), hlo_module_config_(hlo_module_config) {} - virtual ~ElementalIrEmitter() {} + virtual ~ElementalIrEmitter() = default; virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, llvm::Value* operand_value) const; @@ -92,6 +92,26 @@ class ElementalIrEmitter { virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + + virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 6bf00cfb8a..4b511cb4bb 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -135,10 +135,6 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kAtan2: - return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value}, - {lhs_input_type, rhs_input_type}, - output_type); case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, {lhs_input_type, rhs_input_type}, @@ -199,29 +195,50 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv( return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); +} + +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); +} + +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); +} + +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); +} + +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, + prim_type); +} + +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { + return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, + prim_type); +} + StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kExp: - return EmitLibdeviceMathCall("__nv_exp", {operand_value}, {input_type}, - output_type); case HloOpcode::kFloor: return EmitLibdeviceMathCall("__nv_floor", {operand_value}, {input_type}, output_type); case HloOpcode::kCeil: return EmitLibdeviceMathCall("__nv_ceil", {operand_value}, {input_type}, output_type); - case HloOpcode::kLog: - return EmitLibdeviceMathCall("__nv_log", {operand_value}, {input_type}, - output_type); - case HloOpcode::kCos: - return EmitLibdeviceMathCall("__nv_cos", {operand_value}, {input_type}, - output_type); - case HloOpcode::kSin: - return EmitLibdeviceMathCall("__nv_sin", {operand_value}, {input_type}, - output_type); case HloOpcode::kTanh: return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); @@ -230,224 +247,6 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp( } } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - TF_RET_CHECK(primitive_util::IsComplexType(input_type)); - PrimitiveType component_type = - primitive_util::ComplexComponentType(input_type); - switch (op->opcode()) { - case HloOpcode::kPower: { - // (a+bi)^(c+di) = - // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), - // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) - auto a = EmitExtractReal(lhs_value); - auto b = EmitExtractImag(lhs_value); - auto c = EmitExtractReal(rhs_value); - auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); - auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = ir_builder_->CreateFMul(one_half, c); - - TF_ASSIGN_OR_RETURN( - auto aa_p_bb_to_half_c, - EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c}, - {component_type, component_type}, - component_type)); - auto neg_d = ir_builder_->CreateFNeg(d); - TF_ASSIGN_OR_RETURN( - auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a}, - {component_type, component_type}, - component_type)); - auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); - TF_ASSIGN_OR_RETURN( - auto e_to_neg_d_arg_lhs, - EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type}, - component_type)); - auto coeff = - ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); - TF_ASSIGN_OR_RETURN( - auto ln_aa_p_bb, - EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type}, - component_type)); - auto half_d = ir_builder_->CreateFMul(one_half, d); - auto q = - ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), - ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); - TF_ASSIGN_OR_RETURN( - auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type}, - component_type)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), - ir_builder_->CreateFMul(coeff, sin_q)); - } - default: - return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value); - } -} - -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - PrimitiveType component_type = - primitive_util::IsComplexType(input_type) - ? primitive_util::ComplexComponentType(input_type) - : input_type; - - switch (op->opcode()) { - case HloOpcode::kLog: { - // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - llvm::Type* llvm_ty = a->getType(); - auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); - TF_ASSIGN_OR_RETURN( - auto log_sum_sq, - EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a}, - {component_type, component_type}, - component_type)); - auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex( - op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); - } - case HloOpcode::kExp: { - // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( - auto exp_a, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, - component_type)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); - } - case HloOpcode::kCos: { - // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = EmitExtractReal(operand_value); - auto llvm_ty = a->getType(); - TF_ASSIGN_OR_RETURN( - auto exp_b, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, - component_type)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); - } - - case HloOpcode::kSin: { - // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = EmitExtractReal(operand_value); - auto llvm_ty = a->getType(); - TF_ASSIGN_OR_RETURN( - auto exp_b, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, - component_type)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); - } - case HloOpcode::kTanh: { - /* - tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) - e^(a+bi) = e^a*(cos(b)+sin(b)i) - so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / - (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) - cos(b)=cos(-b), sin(-b)=-sin(b) - so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / - (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) - =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / - (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) - =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / - (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) - This is a complex division, so we can multiply by denom_conj/denom_conj - =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * - (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / - ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) - =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + - i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / - ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) - */ - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( - auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, - component_type)); - auto exp_neg_a = ir_builder_->CreateFDiv( - llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(exp_a, exp_a), - ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); - auto real_num = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); - auto exp_a_plus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); - auto exp_a_minus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = ir_builder_->CreateFMul( - cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, - exp_a_minus_exp_neg_a_sq)); - auto denom = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), - ir_builder_->CreateFDiv(imag_num, denom)); - } - default: - return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); - } -} - llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( const string& callee_name, tensorflow::gtl::ArraySlice<llvm::Value*> operands, diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 6a537d0152..77d4569b1e 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -54,20 +54,31 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr<llvm::Value*> EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr<llvm::Value*> EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr<llvm::Value*> EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; - StatusOr<llvm::Value*> EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; - StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; + + StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; + llvm::Value* EmitThreadId() const override; private: |