 tensorflow/compiler/tests/binary_ops_test.py                 |  65
 tensorflow/compiler/tests/unary_ops_test.py                  |  85
 tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc  |  33
 tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h   |   2
 tensorflow/compiler/xla/service/elemental_ir_emitter.cc      | 203
 tensorflow/compiler/xla/service/elemental_ir_emitter.h       |  22
 tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc  | 267
 tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h   |  25
 8 files changed, 330 insertions(+), 372 deletions(-)
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 905dd9fc7b..65706b35d6 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -94,14 +94,12 @@ class BinaryOpsTest(XLATestCase):
dtype(4),
expected=np.array([[16], [81]], dtype=dtype))
- atan2_supported = self.device == "XLA_GPU"
- if atan2_supported:
- self._testBinary(
- math_ops.atan2,
- np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype),
- np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype),
- expected=np.array(
- [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype))
+ self._testBinary(
+ math_ops.atan2,
+ np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype),
+ np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype),
+ expected=np.array(
+ [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype))
self._testBinary(
gen_math_ops._reciprocal_grad,
@@ -388,30 +386,28 @@ class BinaryOpsTest(XLATestCase):
],
dtype=dtype))
- atan2_supported = self.device == "XLA_GPU"
- if atan2_supported:
- self._testBinary(
- math_ops.pow,
- dtype(3 + 2j),
- dtype(4 - 5j),
- expected=np.power(dtype(3 + 2j), dtype(4 - 5j)))
- self._testBinary( # empty rhs
- math_ops.pow,
- np.array([1 + 2j, 2 - 3j], dtype=dtype),
- np.zeros(shape=[0, 2], dtype=dtype),
- expected=np.zeros(shape=[0, 2], dtype=dtype))
- self._testBinary( # to zero power
- math_ops.pow,
- np.array([1 + 2j, 2 - 3j], dtype=dtype),
- np.zeros(shape=[1, 2], dtype=dtype),
- expected=np.ones(shape=[1, 2], dtype=dtype))
- lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype)
- rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype)
- scalar = dtype(2 + 2j)
- self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs))
- self._testBinary(
- math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs))
- self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar))
+ self._testBinary(
+ math_ops.pow,
+ dtype(3 + 2j),
+ dtype(4 - 5j),
+ expected=np.power(dtype(3 + 2j), dtype(4 - 5j)))
+ self._testBinary( # empty rhs
+ math_ops.pow,
+ np.array([1 + 2j, 2 - 3j], dtype=dtype),
+ np.zeros(shape=[0, 2], dtype=dtype),
+ expected=np.zeros(shape=[0, 2], dtype=dtype))
+ self._testBinary( # to zero power
+ math_ops.pow,
+ np.array([1 + 2j, 2 - 3j], dtype=dtype),
+ np.zeros(shape=[1, 2], dtype=dtype),
+ expected=np.ones(shape=[1, 2], dtype=dtype))
+ lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype)
+ rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype)
+ scalar = dtype(2 + 2j)
+ self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs))
+ self._testBinary(
+ math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs))
+ self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar))
lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
@@ -421,9 +417,8 @@ class BinaryOpsTest(XLATestCase):
self._testBinary(
gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))
- if atan2_supported:
- self._testBinary(
- gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2)
+ self._testBinary(
+ gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2)
self._testBinary(
gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))
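
For reference, the expected values in the now-unconditional atan2 and complex pow
cases above can be reproduced with plain NumPy (a standalone sketch, not part of
the patch; np.arctan2 and np.power are the references the tests compare against):

    import numpy as np

    # math_ops.atan2 expectations: atan2(y, x) across the quadrant boundaries.
    y = np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype=np.float32)
    x = np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype=np.float32)
    assert np.allclose(np.arctan2(y, x),
                       [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi])

    # Complex pow to the zero power yields ones, as the "to zero power" case expects.
    assert np.allclose(
        np.power(np.array([1 + 2j, 2 - 3j], dtype=np.complex64), 0), 1)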
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index ecba5a4fb0..0a6fe04d3c 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -363,26 +363,23 @@ class UnaryOpsTest(XLATestCase):
def testComplexOps(self):
for dtype in self.complex_types:
- # TODO(b/65408531): Wider support for log (needs atan2).
- atan2_supported = self.device == "XLA_GPU"
- if atan2_supported:
- self._assertOpOutputMatchesExpected(
- math_ops.acosh,
- np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
- expected=np.arccosh(
- np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
+ self._assertOpOutputMatchesExpected(
+ math_ops.acosh,
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
+ expected=np.arccosh(
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
- self._assertOpOutputMatchesExpected(
- math_ops.asinh,
- np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
- expected=np.arcsinh(
- np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
+ self._assertOpOutputMatchesExpected(
+ math_ops.asinh,
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
+ expected=np.arcsinh(
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
- self._assertOpOutputMatchesExpected(
- math_ops.atanh,
- np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
- expected=np.arctanh(
- np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
+ self._assertOpOutputMatchesExpected(
+ math_ops.atanh,
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
+ expected=np.arctanh(
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.cosh,
@@ -409,11 +406,10 @@ class UnaryOpsTest(XLATestCase):
np.array([[1, 2j, 2 + 3j]], dtype=dtype),
expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype))
- if atan2_supported:
- self._assertOpOutputMatchesExpected(
- math_ops.log,
- np.array([[5j, 3 - 2j]], dtype=dtype),
- expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype)))
+ self._assertOpOutputMatchesExpected(
+ math_ops.log,
+ np.array([[5j, 3 - 2j]], dtype=dtype),
+ expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.sin,
@@ -427,27 +423,26 @@ class UnaryOpsTest(XLATestCase):
# TODO(b/34703906): improve log1p implementation and make tolerance
# tighter.
- if atan2_supported: # TODO(b/34703906): log support
- self._assertOpOutputMatchesExpected(
- math_ops.log1p,
- np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype),
- expected=np.log1p(
- np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
+ self._assertOpOutputMatchesExpected(
+ math_ops.log1p,
+ np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype),
+ expected=np.log1p(
+ np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
- val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)
- self._assertOpOutputMatchesExpected(
- math_ops.rsqrt, val, expected=1 / np.sqrt(val))
+ val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)
+ self._assertOpOutputMatchesExpected(
+ math_ops.rsqrt, val, expected=1 / np.sqrt(val))
- self._assertOpOutputMatchesExpected(
- math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val)))
+ self._assertOpOutputMatchesExpected(
+ math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val)))
- self._assertOpOutputMatchesExpected(
- math_ops.sqrt, val, expected=np.sqrt(val))
+ self._assertOpOutputMatchesExpected(
+ math_ops.sqrt, val, expected=np.sqrt(val))
- self._assertOpOutputMatchesExpected(
- math_ops.tanh,
- np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
- expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
+ self._assertOpOutputMatchesExpected(
+ math_ops.tanh,
+ np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
+ expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.tan,
@@ -480,12 +475,10 @@ class UnaryOpsTest(XLATestCase):
np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype),
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
- if atan2_supported: # TODO(b/34703906): atan2 support
- self._assertOpOutputMatchesExpected(
- math_ops.angle,
- np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
- expected=np.angle(
- np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype)))
+ self._assertOpOutputMatchesExpected(
+ math_ops.angle,
+ np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
+ expected=np.angle(np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.conj,
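
The complex unary expectations enabled above (log, log1p, angle) all come down to
an atan2 call for the argument of the complex number; a minimal NumPy sketch of
that relationship, outside the patch itself:

    import numpy as np

    z = np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=np.complex64)
    # np.angle(z) is exactly atan2(imag, real), the quantity XLA now emits.
    assert np.allclose(np.angle(z), np.arctan2(z.imag, z.real))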
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index ba693ec89a..ebd96c4c42 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -44,15 +44,11 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
default:
return Unimplemented("tanh");
}
- // Create function type for the function.
- llvm::FunctionType* function_type = llvm::FunctionType::get(
- llvm_ir::PrimitiveTypeToIrType(element_type, module_),
- llvm_ir::PrimitiveTypeToIrType(element_type, module_),
- /*isVarArg=*/false);
// Create function declaration for 'tanhf'.
llvm::Function* function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- llvm_ir::AsStringRef(function_name), function_type));
+ llvm_ir::AsStringRef(function_name), operand_value->getType(),
+ operand_value->getType()));
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
@@ -64,6 +60,31 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
}
}
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
+ PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+ string function_name;
+ switch (prim_type) {
+ case F32:
+ function_name = "atan2f";
+ break;
+ case F64:
+ function_name = "atan2";
+ break;
+ default:
+ return Unimplemented("atan2");
+ }
+ // Create function declaration for 'atan2'.
+ llvm::Function* function =
+ llvm::cast<llvm::Function>(module_->getOrInsertFunction(
+ llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(),
+ rhs->getType()));
+ function->setCallingConv(llvm::CallingConv::C);
+ function->setDoesNotThrow();
+ function->setDoesNotAccessMemory();
+ // Create instruction to call 'atan2'.
+ return ir_builder_->CreateCall(function, {lhs, rhs});
+}
+
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) const {
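
The CPU emitter above lowers kAtan2 to a plain C library call: atan2f for F32 and
atan2 for F64. A rough ctypes illustration of the same call, assuming a Linux libm
at "libm.so.6" (an assumption, not something the patch itself relies on):

    import ctypes, math

    libm = ctypes.CDLL("libm.so.6")  # assumption: glibc's math library
    libm.atan2.restype = ctypes.c_double
    libm.atan2.argtypes = [ctypes.c_double, ctypes.c_double]
    # Same result the emitted IR's call to atan2 produces for F64 inputs.
    assert abs(libm.atan2(1.0, 1.0) - math.pi / 4) < 1e-12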
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 7e9f27befb..4446dfd282 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -41,6 +41,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
protected:
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
+ StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
+ llvm::Value* rhs) const override;
IrEmitter* ir_emitter_;
};
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 7e88bbd631..3792929432 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -404,21 +404,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
primitive_util::BitWidth(to_type));
}
case HloOpcode::kExp:
- return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value},
- {operand_value->getType()},
- ir_builder_);
+ return EmitExp(op->shape().element_type(), operand_value);
case HloOpcode::kLog:
- return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value},
- {operand_value->getType()},
- ir_builder_);
+ return EmitLog(op->shape().element_type(), operand_value);
case HloOpcode::kCos:
- return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value},
- {operand_value->getType()},
- ir_builder_);
+ return EmitCos(op->shape().element_type(), operand_value);
case HloOpcode::kSin:
- return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {operand_value},
- {operand_value->getType()},
- ir_builder_);
+ return EmitSin(op->shape().element_type(), operand_value);
case HloOpcode::kFloor:
return llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()},
@@ -469,9 +461,25 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
+ PrimitiveType input_type = op->operand(0)->shape().element_type();
+ PrimitiveType component_type =
+ primitive_util::IsComplexType(input_type)
+ ? primitive_util::ComplexComponentType(input_type)
+ : input_type;
switch (op->opcode()) {
- // TODO(b/65209142): Angle/Log require atan2.
- // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
+ case HloOpcode::kLog: {
+ // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
+ auto a = EmitExtractReal(operand_value);
+ auto b = EmitExtractImag(operand_value);
+ llvm::Type* llvm_ty = a->getType();
+ auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
+ ir_builder_->CreateFMul(b, b));
+ TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
+ TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
+ auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
+ return EmitComposeComplex(
+ op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
+ }
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
TF_RET_CHECK(primitive_util::IsComplexType(from_type));
@@ -493,15 +501,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
}
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
- auto exp_a = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::exp, {EmitExtractReal(operand_value)},
- {EmitExtractReal(operand_value)->getType()}, ir_builder_);
- auto cos_b = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::cos, {EmitExtractImag(operand_value)},
- {EmitExtractImag(operand_value)->getType()}, ir_builder_);
- auto sin_b = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::sin, {EmitExtractImag(operand_value)},
- {EmitExtractImag(operand_value)->getType()}, ir_builder_);
+ TF_ASSIGN_OR_RETURN(
+ auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
+ TF_ASSIGN_OR_RETURN(
+ auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
+ TF_ASSIGN_OR_RETURN(
+ auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
ir_builder_->CreateFMul(exp_a, sin_b));
}
@@ -516,16 +521,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
- auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
- {type}, ir_builder_);
+ TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
auto half_exp_b =
ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b =
ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
- auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a},
- {type}, ir_builder_);
- auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
- {type}, ir_builder_);
+ TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
+ TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
return EmitComposeComplex(
op,
ir_builder_->CreateFMul(
@@ -546,16 +548,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
- auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
- {type}, ir_builder_);
+ TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
auto half_exp_b =
ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b =
ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
- auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a},
- {type}, ir_builder_);
- auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
- {type}, ir_builder_);
+ TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
+ TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
return EmitComposeComplex(
op,
ir_builder_->CreateFMul(
@@ -563,6 +562,58 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
ir_builder_->CreateFMul(
cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
}
+ case HloOpcode::kTanh: {
+ /*
+ tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
+ e^(a+bi) = e^a*(cos(b)+sin(b)i)
+ so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
+ (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
+ cos(b)=cos(-b), sin(-b)=-sin(b)
+ so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
+ (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
+ =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
+ (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
+ =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
+ (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
+ This is a complex division, so we can multiply by denom_conj/denom_conj
+ =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
+ (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
+ ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
+ =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
+ i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
+ ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
+ */
+ auto a = EmitExtractReal(operand_value);
+ auto b = EmitExtractImag(operand_value);
+ TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
+ TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
+ TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
+ auto exp_neg_a = ir_builder_->CreateFDiv(
+ llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
+ auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub(
+ ir_builder_->CreateFMul(exp_a, exp_a),
+ ir_builder_->CreateFMul(exp_neg_a, exp_neg_a));
+ auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b);
+ auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b);
+ auto real_num = ir_builder_->CreateFAdd(
+ ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
+ ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
+ auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b);
+ auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a);
+ auto exp_a_plus_exp_neg_a_sq =
+ ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
+ auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a);
+ auto exp_a_minus_exp_neg_a_sq =
+ ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
+ auto imag_num = ir_builder_->CreateFMul(
+ cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq,
+ exp_a_minus_exp_neg_a_sq));
+ auto denom = ir_builder_->CreateFAdd(
+ ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
+ ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
+ return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom),
+ ir_builder_->CreateFDiv(imag_num, denom));
+ }
case HloOpcode::kAbs: {
auto sum_sq = ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(EmitExtractReal(operand_value),
@@ -625,7 +676,6 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
switch (op->opcode()) {
- // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
@@ -669,10 +719,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
case HloOpcode::kMinimum:
return EmitFloatMin(lhs_value, rhs_value);
case HloOpcode::kPower:
- return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow,
- {lhs_value, rhs_value},
- {lhs_value->getType()}, ir_builder_);
-
+ return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
+ case HloOpcode::kAtan2:
+ return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
default:
return Unimplemented("binary floating point op '%s'",
HloOpcodeString(op->opcode()).c_str());
@@ -768,9 +817,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
EmitExtractImag(lhs_value),
EmitExtractImag(rhs_value), ir_builder_));
- // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic
- // case HloOpcode::kPower:
- // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2)
+ case HloOpcode::kPower: {
+ // (a+bi)^(c+di) =
+ // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
+ // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
+ PrimitiveType component_type =
+ primitive_util::ComplexComponentType(op->shape().element_type());
+ auto a = EmitExtractReal(lhs_value);
+ auto b = EmitExtractImag(lhs_value);
+ auto c = EmitExtractReal(rhs_value);
+ auto d = EmitExtractImag(rhs_value);
+ auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
+ ir_builder_->CreateFMul(b, b));
+ auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
+ auto half_c = ir_builder_->CreateFMul(one_half, c);
+
+ TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
+ EmitPow(component_type, aa_p_bb, half_c));
+ auto neg_d = ir_builder_->CreateFNeg(d);
+ TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
+ auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs);
+ TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
+ EmitExp(component_type, neg_d_arg_lhs));
+ auto coeff =
+ ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
+ TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
+ auto half_d = ir_builder_->CreateFMul(one_half, d);
+ auto q =
+ ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs),
+ ir_builder_->CreateFMul(half_d, ln_aa_p_bb));
+ TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
+ TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
+ return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q),
+ ir_builder_->CreateFMul(coeff, sin_q));
+ }
default:
return Unimplemented("binary complex op '%s'",
HloOpcodeString(op->opcode()).c_str());
@@ -873,6 +953,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
}
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
+ llvm::Value* value) const {
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
+ {value->getType()}, ir_builder_);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
+ llvm::Value* value) const {
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
+ {value->getType()}, ir_builder_);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
+ llvm::Value* value) const {
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
+ {value->getType()}, ir_builder_);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
+ llvm::Value* value) const {
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
+ {value->getType()}, ir_builder_);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) const {
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
+ {lhs->getType()}, ir_builder_);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) const {
+ return Unimplemented("atan2");
+}
+
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
const HloInstruction* hlo, llvm::Value* x) const {
if (hlo->operand(0)->shape().element_type() != F32) {
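
The closed forms used by EmitComplexUnaryOp (kLog, kTanh) and EmitComplexBinaryOp
(kPower) above can be checked numerically against NumPy; a standalone sketch of
the same formulas, not part of the patch:

    import numpy as np

    # log(a+bi) = 0.5*log(a^2+b^2) + i*atan2(b, a)
    z = np.array([5j, 3 - 2j], dtype=np.complex64)
    a, b = z.real, z.imag
    assert np.allclose(0.5 * np.log(a * a + b * b) + 1j * np.arctan2(b, a),
                       np.log(z))

    # (a+bi)^(c+di) = (a^2+b^2)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
    # where q = c*atan2(b,a) + 0.5*d*log(a^2+b^2)
    zz, w = 3 + 2j, 4 - 5j
    a, b, c, d = zz.real, zz.imag, w.real, w.imag
    arg = np.arctan2(b, a)
    q = c * arg + 0.5 * d * np.log(a * a + b * b)
    coeff = (a * a + b * b) ** (0.5 * c) * np.exp(-d * arg)
    assert np.allclose(coeff * (np.cos(q) + 1j * np.sin(q)), zz ** w)

    # kTanh: same real/imag numerators and denominator as the IR above.
    x = 0.3 + 0.7j
    a, b = x.real, x.imag
    ea, ena = np.exp(a), np.exp(-a)
    denom = (np.cos(b) * (ea + ena)) ** 2 + (np.sin(b) * (ea - ena)) ** 2
    real_num = (np.cos(b) ** 2 * (ea * ea - ena * ena) +
                np.sin(b) ** 2 * (ea * ea - ena * ena))
    imag_num = np.cos(b) * np.sin(b) * ((ea + ena) ** 2 - (ea - ena) ** 2)
    assert np.allclose((real_num + 1j * imag_num) / denom, np.tanh(x))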
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index cccb498f82..1a48eb5fcb 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -39,7 +39,7 @@ class ElementalIrEmitter {
module_(module),
hlo_module_config_(hlo_module_config) {}
- virtual ~ElementalIrEmitter() {}
+ virtual ~ElementalIrEmitter() = default;
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
llvm::Value* operand_value) const;
@@ -92,6 +92,26 @@ class ElementalIrEmitter {
virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
llvm::Value* value) const;
+ virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) const;
+
+ virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
+ llvm::Value* value) const;
+
+ virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
+ llvm::Value* value) const;
+
+ virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
+ llvm::Value* value) const;
+
+ virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
+ llvm::Value* value) const;
+
+ virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) const;
+
virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
llvm::Value* x) const;
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 6bf00cfb8a..4b511cb4bb 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -135,10 +135,6 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
PrimitiveType output_type = op->shape().element_type();
switch (op->opcode()) {
- case HloOpcode::kAtan2:
- return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value},
- {lhs_input_type, rhs_input_type},
- output_type);
case HloOpcode::kRemainder: {
return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value},
{lhs_input_type, rhs_input_type},
@@ -199,29 +195,50 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv(
return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type);
}
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(
+ PrimitiveType prim_type, llvm::Value* value) const {
+ return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type);
+}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(
+ PrimitiveType prim_type, llvm::Value* value) const {
+ return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type);
+}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(
+ PrimitiveType prim_type, llvm::Value* value) const {
+ return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type);
+}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
+ PrimitiveType prim_type, llvm::Value* value) const {
+ return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type);
+}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) const {
+ return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type},
+ prim_type);
+}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
+ PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+ return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type},
+ prim_type);
+}
+
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
PrimitiveType input_type = op->operand(0)->shape().element_type();
PrimitiveType output_type = op->shape().element_type();
switch (op->opcode()) {
- case HloOpcode::kExp:
- return EmitLibdeviceMathCall("__nv_exp", {operand_value}, {input_type},
- output_type);
case HloOpcode::kFloor:
return EmitLibdeviceMathCall("__nv_floor", {operand_value}, {input_type},
output_type);
case HloOpcode::kCeil:
return EmitLibdeviceMathCall("__nv_ceil", {operand_value}, {input_type},
output_type);
- case HloOpcode::kLog:
- return EmitLibdeviceMathCall("__nv_log", {operand_value}, {input_type},
- output_type);
- case HloOpcode::kCos:
- return EmitLibdeviceMathCall("__nv_cos", {operand_value}, {input_type},
- output_type);
- case HloOpcode::kSin:
- return EmitLibdeviceMathCall("__nv_sin", {operand_value}, {input_type},
- output_type);
case HloOpcode::kTanh:
return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type},
output_type);
@@ -230,224 +247,6 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
}
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
- PrimitiveType input_type = op->operand(0)->shape().element_type();
- TF_RET_CHECK(primitive_util::IsComplexType(input_type));
- PrimitiveType component_type =
- primitive_util::ComplexComponentType(input_type);
- switch (op->opcode()) {
- case HloOpcode::kPower: {
- // (a+bi)^(c+di) =
- // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
- // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
- auto a = EmitExtractReal(lhs_value);
- auto b = EmitExtractImag(lhs_value);
- auto c = EmitExtractReal(rhs_value);
- auto d = EmitExtractImag(rhs_value);
- auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
- ir_builder_->CreateFMul(b, b));
- auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
- auto half_c = ir_builder_->CreateFMul(one_half, c);
-
- TF_ASSIGN_OR_RETURN(
- auto aa_p_bb_to_half_c,
- EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c},
- {component_type, component_type},
- component_type));
- auto neg_d = ir_builder_->CreateFNeg(d);
- TF_ASSIGN_OR_RETURN(
- auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a},
- {component_type, component_type},
- component_type));
- auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs);
- TF_ASSIGN_OR_RETURN(
- auto e_to_neg_d_arg_lhs,
- EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type},
- component_type));
- auto coeff =
- ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
- TF_ASSIGN_OR_RETURN(
- auto ln_aa_p_bb,
- EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type},
- component_type));
- auto half_d = ir_builder_->CreateFMul(one_half, d);
- auto q =
- ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs),
- ir_builder_->CreateFMul(half_d, ln_aa_p_bb));
- TF_ASSIGN_OR_RETURN(
- auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type},
- component_type));
- TF_ASSIGN_OR_RETURN(
- auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type},
- component_type));
- return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q),
- ir_builder_->CreateFMul(coeff, sin_q));
- }
- default:
- return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value);
- }
-}
-
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
- PrimitiveType input_type = op->operand(0)->shape().element_type();
- PrimitiveType component_type =
- primitive_util::IsComplexType(input_type)
- ? primitive_util::ComplexComponentType(input_type)
- : input_type;
-
- switch (op->opcode()) {
- case HloOpcode::kLog: {
- // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
- auto a = EmitExtractReal(operand_value);
- auto b = EmitExtractImag(operand_value);
- llvm::Type* llvm_ty = a->getType();
- auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
- ir_builder_->CreateFMul(b, b));
- TF_ASSIGN_OR_RETURN(
- auto log_sum_sq,
- EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type},
- component_type));
- TF_ASSIGN_OR_RETURN(
- auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a},
- {component_type, component_type},
- component_type));
- auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(
- op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
- }
- case HloOpcode::kExp: {
- // e^(a+bi) = e^a*(cos(b)+sin(b)i)
- auto b = EmitExtractImag(operand_value);
- TF_ASSIGN_OR_RETURN(
- auto exp_a,
- EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)},
- {component_type}, component_type));
- TF_ASSIGN_OR_RETURN(
- auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type},
- component_type));
- TF_ASSIGN_OR_RETURN(
- auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type},
- component_type));
- return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
- ir_builder_->CreateFMul(exp_a, sin_b));
- }
- case HloOpcode::kCos: {
- // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
- auto a = EmitExtractReal(operand_value);
- auto llvm_ty = a->getType();
- TF_ASSIGN_OR_RETURN(
- auto exp_b,
- EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)},
- {component_type}, component_type));
- TF_ASSIGN_OR_RETURN(
- auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
- component_type));
- TF_ASSIGN_OR_RETURN(
- auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type},
- component_type));
- auto half_exp_b =
- ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
- auto half_exp_neg_b =
- ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
- return EmitComposeComplex(
- op,
- ir_builder_->CreateFMul(
- cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
- ir_builder_->CreateFMul(
- sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
- }
-
- case HloOpcode::kSin: {
- // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
- auto a = EmitExtractReal(operand_value);
- auto llvm_ty = a->getType();
- TF_ASSIGN_OR_RETURN(
- auto exp_b,
- EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)},
- {component_type}, component_type));
- TF_ASSIGN_OR_RETURN(
- auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
- component_type));
- TF_ASSIGN_OR_RETURN(
- auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type},
- component_type));
- auto half_exp_b =
- ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
- auto half_exp_neg_b =
- ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
- return EmitComposeComplex(
- op,
- ir_builder_->CreateFMul(
- sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
- ir_builder_->CreateFMul(
- cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
- }
- case HloOpcode::kTanh: {
- /*
- tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
- e^(a+bi) = e^a*(cos(b)+sin(b)i)
- so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
- (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
- cos(b)=cos(-b), sin(-b)=-sin(b)
- so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
- (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
- =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
- (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
- =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
- (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
- This is a complex division, so we can multiply by denom_conj/denom_conj
- =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
- (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
- ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
- =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
- i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
- ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
- */
- auto a = EmitExtractReal(operand_value);
- auto b = EmitExtractImag(operand_value);
- TF_ASSIGN_OR_RETURN(
- auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type},
- component_type));
- TF_ASSIGN_OR_RETURN(
- auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type},
- component_type));
- TF_ASSIGN_OR_RETURN(
- auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type},
- component_type));
- auto exp_neg_a = ir_builder_->CreateFDiv(
- llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
- auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(exp_a, exp_a),
- ir_builder_->CreateFMul(exp_neg_a, exp_neg_a));
- auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b);
- auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b);
- auto real_num = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
- ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
- auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b);
- auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a);
- auto exp_a_plus_exp_neg_a_sq =
- ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
- auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a);
- auto exp_a_minus_exp_neg_a_sq =
- ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
- auto imag_num = ir_builder_->CreateFMul(
- cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq,
- exp_a_minus_exp_neg_a_sq));
- auto denom = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
- ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
- return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom),
- ir_builder_->CreateFDiv(imag_num, denom));
- }
- default:
- return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value);
- }
-}
-
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 6a537d0152..77d4569b1e 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -54,20 +54,31 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
- StatusOr<llvm::Value*> EmitComplexUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const override;
-
StatusOr<llvm::Value*> EmitFloatBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const override;
- StatusOr<llvm::Value*> EmitComplexBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const override;
-
StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
llvm::Value* value) const override;
+ StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
+ llvm::Value* value) const override;
+
+ StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
+ llvm::Value* value) const override;
+
+ StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
+ llvm::Value* value) const override;
+
+ StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
+ llvm::Value* value) const override;
+
+ StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
+ llvm::Value* rhs) const override;
+
+ StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
+ llvm::Value* rhs) const override;
+
llvm::Value* EmitThreadId() const override;
private: