aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-13 11:34:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-13 11:40:57 -0800
commit58f7858601b72aa3c5854571f2152b91d1795e29 (patch)
tree214e1ff498ecc21573dbe444fc5a0142915152af /tensorflow/compiler
parent659d8cbc3aaffc0249afee1ec437639beda8d243 (diff)
[TF:XLA] Adding test coverage for more C64 operations, and ensuring they pass.
Included here:
- reduction ops (reduce_sum, reduce_prod)
- unaries: tanh, sigmoid (currently GPU only)
- binaries: pow (currently GPU only)

PiperOrigin-RevId: 175562417
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py50
-rw-r--r--tensorflow/compiler/tests/reduce_ops_test.py30
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py31
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc15
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc25
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc208
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.h12
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc164
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc33
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc13
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc20
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.h3
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc153
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h10
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc32
16 files changed, 623 insertions, 180 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index d412c572ae..654dc15e86 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -366,16 +366,52 @@ class BinaryOpsTest(XLATestCase):
self._testBinary(
gen_math_ops._real_div,
- np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype),
- np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype),
+ np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype),
+ np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype),
+ expected=np.array(
+ [1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2],
+ dtype=dtype))
+
+ # Test inf/nan scenarios.
+ self._testBinary(
+ gen_math_ops._real_div,
+ np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype),
+ np.array([0, 0, 0, 0, 0, 0], dtype=dtype),
expected=np.array(
[
- 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2,
- float("inf")
+ dtype(1 + 1j) / 0,
+ dtype(1) / 0,
+ dtype(1j) / 0,
+ dtype(-1) / 0,
+ dtype(-1j) / 0,
+ dtype(1 - 1j) / 0
],
dtype=dtype))
- # TODO(b/65408531): support+test pow for cplx
+ atan2_supported = self.device == "XLA_GPU"
+ if atan2_supported:
+ self._testBinary(
+ math_ops.pow,
+ dtype(3 + 2j),
+ dtype(4 - 5j),
+ expected=np.power(dtype(3 + 2j), dtype(4 - 5j)))
+ self._testBinary( # empty rhs
+ math_ops.pow,
+ np.array([1 + 2j, 2 - 3j], dtype=dtype),
+ np.zeros(shape=[0, 2], dtype=dtype),
+ expected=np.zeros(shape=[0, 2], dtype=dtype))
+ self._testBinary( # to zero power
+ math_ops.pow,
+ np.array([1 + 2j, 2 - 3j], dtype=dtype),
+ np.zeros(shape=[1, 2], dtype=dtype),
+ expected=np.ones(shape=[1, 2], dtype=dtype))
+ lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype)
+ rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype)
+ scalar = dtype(2 + 2j)
+ self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs))
+ self._testBinary(
+ math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs))
+ self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar))
lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
@@ -385,7 +421,9 @@ class BinaryOpsTest(XLATestCase):
self._testBinary(
gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))
- # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow)
+ if atan2_supported:
+ self._testBinary(
+ gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2)
self._testBinary(
gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index efda2cc207..965fdf684b 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -67,25 +67,37 @@ class ReduceOpsTest(XLATestCase):
np.arange(-10, -4).reshape(2, 3),
np.arange(-4, 2).reshape(2, 3),
]
- NONEMPTY_FLOAT_DATA = [
- np.arange(1, 7).reshape(2, 3),
- np.arange(-10, -4).reshape(2, 3),
- np.arange(-4, 2).reshape(2, 3),
+ COMPLEX_DATA = [
+ np.zeros(shape=(2, 0)).astype(np.complex64),
+ np.zeros(shape=(0, 30)).astype(np.complex64),
+ np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3),
+ np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3),
+ np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3),
]
+ NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0]
+ NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0]
BOOL_DATA = [
np.array([], dtype=np.bool).reshape(2, 0),
np.array([], dtype=np.bool).reshape(0, 3),
np.array([[False, True, False], [True, True, False]]),
]
- def testReduceSum(self):
+ def testReduceSumF32(self):
self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
self.FLOAT_DATA)
- def testReduceProd(self):
+ def testReduceSumC64(self):  # complex64 counterpart of testReduceSumF32
+ self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
+ self.COMPLEX_DATA)  # np.sum is the reference; COMPLEX_DATA includes zero-sized arrays
+
+ def testReduceProdF32(self):
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
self.FLOAT_DATA)
+ def testReduceProdC64(self):  # complex64 counterpart of testReduceProdF32
+ self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
+ self.COMPLEX_DATA)  # np.prod is the reference; COMPLEX_DATA includes zero-sized arrays
+
def testReduceMin(self):
def reference_min(inp, axis):
@@ -108,12 +120,16 @@ class ReduceOpsTest(XLATestCase):
self._testReduction(math_ops.reduce_max, reference_max, np.float32,
self.FLOAT_DATA)
- def testReduceMean(self):
+ def testReduceMeanF32(self):
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
# reducing across zero inputs.
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
self.NONEMPTY_FLOAT_DATA)
+ def testReduceMeanC64(self):  # complex64 counterpart of testReduceMeanF32
+ self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
+ self.NONEMPTY_COMPLEX_DATA)  # zero-sized inputs excluded; presumably for the same reason as the TODO on the F32 mean test -- confirm
+
def testReduceAll(self):
self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 76644380bd..a9a3f4f97f 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -330,13 +330,23 @@ class UnaryOpsTest(XLATestCase):
def testComplexOps(self):
for dtype in self.complex_types:
- # TODO(b/65408531): math_ops.acosh (needs pow)
- # TODO(b/65408531): math_ops.asinh (needs pow)
# TODO(b/65408531): Wider support for log (needs atan2).
atan2_supported = self.device == "XLA_GPU"
if atan2_supported:
self._assertOpOutputMatchesExpected(
+ math_ops.acosh,
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
+ expected=np.arccosh(
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
+
+ self._assertOpOutputMatchesExpected(
+ math_ops.asinh,
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
+ expected=np.arcsinh(
+ np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
+
+ self._assertOpOutputMatchesExpected(
math_ops.atanh,
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
expected=np.arctanh(
@@ -392,19 +402,26 @@ class UnaryOpsTest(XLATestCase):
expected=np.log1p(
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
- # TODO(b/34703906): math_ops.rsqrt (needs pow)
+ val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)
+ self._assertOpOutputMatchesExpected(
+ math_ops.rsqrt, val, expected=1 / np.sqrt(val))
- # TODO(b/34703906): math_ops.sigmoid (needs tanh)
+ self._assertOpOutputMatchesExpected(
+ math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val)))
- # TODO(b/34703906): math_ops.sqrt (needs pow)
+ self._assertOpOutputMatchesExpected(
+ math_ops.sqrt, val, expected=np.sqrt(val))
+
+ self._assertOpOutputMatchesExpected(
+ math_ops.tanh,
+ np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
+ expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.tan,
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
- # TODO(b/34703906): math_ops.tanh (as itself)
-
ctypes = {np.complex64: np.float32}
self._assertOpOutputMatchesExpected(
math_ops.abs,
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 35fe0d1a51..5c9b29f6e2 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -135,7 +135,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override;
+ Status HandleComplex(HloInstruction* complex) override;
+
Status HandleReal(HloInstruction* real) override;
+
Status HandleImag(HloInstruction* imag) override;
Status HandleConvolution(HloInstruction* convolution) override;
@@ -947,6 +950,18 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
return Status::OK();
}
+// Complex(Real(c), Imag(c)) -> c, applied only when both parts are taken from the same operand c.
+Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
+ auto real = complex->mutable_operand(0);  // operand 0 supplies the real part
+ auto imag = complex->mutable_operand(1);  // operand 1 supplies the imaginary part
+ if (real->opcode() == HloOpcode::kReal &&
+ imag->opcode() == HloOpcode::kImag &&
+ real->operand(0) == imag->operand(0)) {  // Real and Imag must read the same source
+ return ReplaceInstruction(complex, real->mutable_operand(0));
+ }
+ return Status::OK();  // pattern not matched: no replacement is made here
+}
+
// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
auto operand = real->mutable_operand(0);
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index c06e330bc1..620f0a54fa 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -371,6 +371,31 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
EXPECT_EQ(root, param0);
}
+// Test that complex(real(c), imag(c)) is simplified to c, where c is a C64 parameter.
+TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});  // shape of the real/imag parts
+ Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2});  // shape of the complex values
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r2c64, "param0"));
+ HloInstruction* real = builder.AddInstruction(
+ HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0));
+ HloInstruction* imag = builder.AddInstruction(
+ HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0));
+ HloInstruction* cplx = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, cplx);  // before simplification the root is the kComplex instruction
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());  // the pass must report a change
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);  // afterwards the root is the original parameter
+}
+
// Test that real(complex(r,i)) is simplified to r.
TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index a945657712..606868034a 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -93,14 +93,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
primitive_util::ComplexComponentType(to_type), module_);
if (primitive_util::IsSignedIntegralType(from_type)) {
- return ComposeComplex(
+ return EmitComposeComplex(
op,
ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
nullptr);
}
if (primitive_util::IsUnsignedIntegralType(from_type) ||
from_type == PRED) {
- return ComposeComplex(
+ return EmitComposeComplex(
op,
ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
nullptr);
@@ -178,9 +178,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
PrimitiveType to_component_type =
primitive_util::ComplexComponentType(to_type);
if (from_type == to_component_type) {
- return ComposeComplex(op, operand_value, nullptr);
+ return EmitComposeComplex(op, operand_value, nullptr);
}
- return ComposeComplex(
+ return EmitComposeComplex(
op,
ir_builder_->CreateFPCast(
operand_value,
@@ -269,15 +269,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
- auto real = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {0});
- };
- auto imag = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {1});
- };
switch (op->opcode()) {
// TODO(b/65209142): Angle/Log require atan2.
- // case HloOpcode::kAngle:
// case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -291,24 +284,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
primitive_util::ComplexComponentType(to_type);
auto to_ir_component_type =
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
- return ComposeComplex(
+ return EmitComposeComplex(
op,
- ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type),
- ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type));
+ ir_builder_->CreateFPCast(EmitExtractReal(operand_value),
+ to_ir_component_type),
+ ir_builder_->CreateFPCast(EmitExtractImag(operand_value),
+ to_ir_component_type));
}
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
auto exp_a = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::exp, {real(operand_value)},
- {real(operand_value)->getType()}, ir_builder_);
+ llvm::Intrinsic::exp, {EmitExtractReal(operand_value)},
+ {EmitExtractReal(operand_value)->getType()}, ir_builder_);
auto cos_b = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::cos, {imag(operand_value)},
- {imag(operand_value)->getType()}, ir_builder_);
+ llvm::Intrinsic::cos, {EmitExtractImag(operand_value)},
+ {EmitExtractImag(operand_value)->getType()}, ir_builder_);
auto sin_b = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::sin, {imag(operand_value)},
- {imag(operand_value)->getType()}, ir_builder_);
- return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
- ir_builder_->CreateFMul(exp_a, sin_b));
+ llvm::Intrinsic::sin, {EmitExtractImag(operand_value)},
+ {EmitExtractImag(operand_value)->getType()}, ir_builder_);
+ return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
+ ir_builder_->CreateFMul(exp_a, sin_b));
}
case HloOpcode::kCos: {
// cos(z) = .5(e^(iz) + e^(-iz))
@@ -318,8 +313,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
// cos(-x) = cos(x) and sin(-x) = -sin(x), so
// cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
// = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
- auto a = real(operand_value);
- auto b = imag(operand_value);
+ auto a = EmitExtractReal(operand_value);
+ auto b = EmitExtractImag(operand_value);
auto type = a->getType();
auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
{type}, ir_builder_);
@@ -331,7 +326,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
{type}, ir_builder_);
auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
{type}, ir_builder_);
- return ComposeComplex(
+ return EmitComposeComplex(
op,
ir_builder_->CreateFMul(
cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
@@ -348,8 +343,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
// cos(-x) = cos(x) and sin(-x) = -sin(x), so
// = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
// = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
- auto a = real(operand_value);
- auto b = imag(operand_value);
+ auto a = EmitExtractReal(operand_value);
+ auto b = EmitExtractImag(operand_value);
auto type = a->getType();
auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
{type}, ir_builder_);
@@ -361,7 +356,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
{type}, ir_builder_);
auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
{type}, ir_builder_);
- return ComposeComplex(
+ return EmitComposeComplex(
op,
ir_builder_->CreateFMul(
sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
@@ -370,33 +365,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
}
case HloOpcode::kAbs: {
auto sum_sq = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(real(operand_value), real(operand_value)),
- ir_builder_->CreateFMul(imag(operand_value), imag(operand_value)));
+ ir_builder_->CreateFMul(EmitExtractReal(operand_value),
+ EmitExtractReal(operand_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(operand_value),
+ EmitExtractImag(operand_value)));
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
{sum_sq->getType()}, ir_builder_);
}
case HloOpcode::kSign: { // Sign(c) = c / |c|
auto sum_sq = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(real(operand_value), real(operand_value)),
- ir_builder_->CreateFMul(imag(operand_value), imag(operand_value)));
+ ir_builder_->CreateFMul(EmitExtractReal(operand_value),
+ EmitExtractReal(operand_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(operand_value),
+ EmitExtractImag(operand_value)));
auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
auto type = cplx_abs->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
return ir_builder_->CreateSelect(
- oeq, ComposeComplex(op, zero, zero),
- ComposeComplex(
- op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs),
- ir_builder_->CreateFDiv(imag(operand_value), cplx_abs)));
+ oeq, EmitComposeComplex(op, zero, zero),
+ EmitComposeComplex(
+ op,
+ ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
+ ir_builder_->CreateFDiv(EmitExtractImag(operand_value),
+ cplx_abs)));
}
case HloOpcode::kNegate:
- return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)),
- ir_builder_->CreateFNeg(imag(operand_value)));
+ return EmitComposeComplex(
+ op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)),
+ ir_builder_->CreateFNeg(EmitExtractImag(operand_value)));
case HloOpcode::kReal:
- return real(operand_value);
+ return EmitExtractReal(operand_value);
case HloOpcode::kImag:
- return imag(operand_value);
+ return EmitExtractImag(operand_value);
default:
return Unimplemented("unary complex op '%s'",
HloOpcodeString(op->opcode()).c_str());
@@ -424,7 +426,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
switch (op->opcode()) {
// case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support
case HloOpcode::kComplex:
- return ComposeComplex(op, lhs_value, rhs_value);
+ return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
return ir_builder_->CreateFAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
@@ -479,54 +481,66 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
- auto real = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {0});
- };
- auto imag = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {1});
- };
switch (op->opcode()) {
case HloOpcode::kAdd:
- return ComposeComplex(
- op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)),
- ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value)));
+ return EmitComposeComplex(
+ op,
+ ir_builder_->CreateFAdd(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ ir_builder_->CreateFAdd(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value)));
case HloOpcode::kSubtract:
- return ComposeComplex(
- op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)),
- ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value)));
+ return EmitComposeComplex(
+ op,
+ ir_builder_->CreateFSub(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ ir_builder_->CreateFSub(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value)));
case HloOpcode::kMultiply:
- return ComposeComplex(
+ return EmitComposeComplex(
op,
ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
- ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))),
+ ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value))),
ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)),
- ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))));
+ ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value))));
case HloOpcode::kDivide: {
// (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
// = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
auto rhs_sum_sq = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)),
- ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value)));
+ ir_builder_->CreateFMul(EmitExtractReal(rhs_value),
+ EmitExtractReal(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(rhs_value),
+ EmitExtractImag(rhs_value)));
auto type = rhs_sum_sq->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
+ auto real_inf_or_nan =
+ ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero);
+ auto imag_inf_or_nan =
+ ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero);
return ir_builder_->CreateSelect(
- oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero),
- ComposeComplex(
+ oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
+ EmitComposeComplex(
op,
ir_builder_->CreateFDiv(
ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
- ir_builder_->CreateFMul(imag(lhs_value),
- imag(rhs_value))),
+ ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value))),
rhs_sum_sq),
ir_builder_->CreateFDiv(
ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)),
- ir_builder_->CreateFMul(real(lhs_value),
- imag(rhs_value))),
+ ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value))),
rhs_sum_sq)));
}
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
@@ -538,16 +552,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
// matches C++'s semantics.
case HloOpcode::kEq:
return ir_builder_->CreateAnd(
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value),
- real(rhs_value), ir_builder_),
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value),
- imag(rhs_value), ir_builder_));
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
+ EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value), ir_builder_),
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
+ EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value), ir_builder_));
case HloOpcode::kNe:
return ir_builder_->CreateOr(
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value),
- real(rhs_value), ir_builder_),
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value),
- imag(rhs_value), ir_builder_));
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
+ EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value), ir_builder_),
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
+ EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value), ir_builder_));
// TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic
// case HloOpcode::kPower:
@@ -1565,25 +1583,25 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
llvm::Value* next_accumulator;
if (primitive_util::IsComplexType(primitive_type)) {
- auto real = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {0});
- };
- auto imag = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {1});
- };
llvm::Value* product_real = ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
- ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value)));
+ ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value)));
llvm::Value* product_imag = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)),
- ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)));
+ ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value)));
next_accumulator = ir_builder_->CreateInsertValue(
current_accumulator,
- ir_builder_->CreateFAdd(real(current_accumulator), product_real),
+ ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
+ product_real),
{0});
next_accumulator = ir_builder_->CreateInsertValue(
next_accumulator,
- ir_builder_->CreateFAdd(imag(current_accumulator), product_imag),
+ ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
+ product_imag),
{1});
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
next_accumulator = ir_builder_->CreateFAdd(
@@ -1607,9 +1625,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
}
}
-llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op,
- llvm::Value* real,
- llvm::Value* imag) const {
+llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
+ return ir_builder_->CreateExtractValue(value, {0});  // element 0 of the complex aggregate is the real part
+}
+
+llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
+ return ir_builder_->CreateExtractValue(value, {1});  // element 1 of the complex aggregate is the imaginary part
+}
+
+llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
+ llvm::Value* real,
+ llvm::Value* imag) const {
auto cplx_type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
auto complex = ir_builder_->CreateInsertValue(
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 9d32436e38..cccb498f82 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -95,6 +95,13 @@ class ElementalIrEmitter {
virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
llvm::Value* x) const;
+ virtual llvm::Value* EmitExtractReal(llvm::Value* value) const;  // emits extraction of element 0 (real part) of `value`
+ virtual llvm::Value* EmitExtractImag(llvm::Value* value) const;  // emits extraction of element 1 (imaginary part) of `value`
+
+ // Composes a complex struct. imag may be nullptr for simple cast operations.
+ llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
+ llvm::Value* imag) const;  // the complex IR type is derived from op's element type
+
// A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
// the target array index, computes the source array index of its
// `operand_no`-th operand.
@@ -117,11 +124,6 @@ class ElementalIrEmitter {
// compiled executable outside of the HLO code itself.
const HloModuleConfig& hlo_module_config_;
- protected:
- // Composes a complex struct. imag may be nullptr for simple cast operations.
- llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real,
- llvm::Value* imag) const;
-
private:
// Returns a ElementGenerator for a RNG HloInstruction.
llvm_ir::ElementGenerator MakeRngElementGenerator(
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 1b94499bc6..6bf00cfb8a 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -230,6 +230,66 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
}
}
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexBinaryOp(  // emits kPower via libdevice calls; other opcodes are forwarded to ElementalIrEmitter
+ const HloInstruction* op, llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const {
+ PrimitiveType input_type = op->operand(0)->shape().element_type();
+ TF_RET_CHECK(primitive_util::IsComplexType(input_type));  // operand 0 must have a complex element type
+ PrimitiveType component_type =
+ primitive_util::ComplexComponentType(input_type);  // type of the real/imag components, used for every libdevice call below
+ switch (op->opcode()) {
+ case HloOpcode::kPower: {
+ // (a+bi)^(c+di) =
+ // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
+ // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
+ auto a = EmitExtractReal(lhs_value);  // lhs = a + bi
+ auto b = EmitExtractImag(lhs_value);
+ auto c = EmitExtractReal(rhs_value);  // rhs = c + di
+ auto d = EmitExtractImag(rhs_value);
+ auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
+ ir_builder_->CreateFMul(b, b));  // a*a + b*b, i.e. |lhs|^2
+ auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
+ auto half_c = ir_builder_->CreateFMul(one_half, c);
+
+ TF_ASSIGN_OR_RETURN(
+ auto aa_p_bb_to_half_c,
+ EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c},
+ {component_type, component_type},
+ component_type));  // (a*a+b*b)^(0.5c)
+ auto neg_d = ir_builder_->CreateFNeg(d);
+ TF_ASSIGN_OR_RETURN(
+ auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a},
+ {component_type, component_type},
+ component_type));  // atan2(b, a), the argument of lhs
+ auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs);
+ TF_ASSIGN_OR_RETURN(
+ auto e_to_neg_d_arg_lhs,
+ EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type},
+ component_type));  // exp(-d*atan2(b,a))
+ auto coeff =
+ ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);  // real scale factor applied to both output components
+ TF_ASSIGN_OR_RETURN(
+ auto ln_aa_p_bb,
+ EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type},
+ component_type));  // NOTE(review): for lhs == 0 this is log(0); with d == 0 the product below may be 0 * -inf = NaN -- confirm zero-base behavior
+ auto half_d = ir_builder_->CreateFMul(one_half, d);
+ auto q =
+ ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs),
+ ir_builder_->CreateFMul(half_d, ln_aa_p_bb));  // q = c*atan2(b,a) + 0.5d*ln(a*a+b*b)
+ TF_ASSIGN_OR_RETURN(
+ auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type},
+ component_type));
+ TF_ASSIGN_OR_RETURN(
+ auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type},
+ component_type));
+ return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q),
+ ir_builder_->CreateFMul(coeff, sin_q));  // result = coeff*cos(q) + i*coeff*sin(q)
+ }
+ default:
+ return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value);  // all other opcodes use the base-class implementation
+ }
+}
+
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
PrimitiveType input_type = op->operand(0)->shape().element_type();
@@ -237,18 +297,12 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
primitive_util::IsComplexType(input_type)
? primitive_util::ComplexComponentType(input_type)
: input_type;
- auto real = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {0});
- };
- auto imag = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {1});
- };
switch (op->opcode()) {
case HloOpcode::kLog: {
// log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
- auto a = real(operand_value);
- auto b = imag(operand_value);
+ auto a = EmitExtractReal(operand_value);
+ auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
ir_builder_->CreateFMul(b, b));
@@ -261,34 +315,33 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
{component_type, component_type},
component_type));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq),
- angle);
+ return EmitComposeComplex(
+ op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
}
- // TODO(b/65408531): Implement kPower on GPU, where atan2 is available.
- // case HloOpcode::kPower:
- // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di))
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
- auto b = imag(operand_value);
+ auto b = EmitExtractImag(operand_value);
TF_ASSIGN_OR_RETURN(
- auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)},
- {component_type}, component_type));
+ auto exp_a,
+ EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)},
+ {component_type}, component_type));
TF_ASSIGN_OR_RETURN(
auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type},
component_type));
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type},
component_type));
- return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
- ir_builder_->CreateFMul(exp_a, sin_b));
+ return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
+ ir_builder_->CreateFMul(exp_a, sin_b));
}
case HloOpcode::kCos: {
// cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
- auto a = real(operand_value);
+ auto a = EmitExtractReal(operand_value);
auto llvm_ty = a->getType();
TF_ASSIGN_OR_RETURN(
- auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)},
- {component_type}, component_type));
+ auto exp_b,
+ EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)},
+ {component_type}, component_type));
TF_ASSIGN_OR_RETURN(
auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
component_type));
@@ -299,7 +352,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
auto half_exp_neg_b =
ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
- return ComposeComplex(
+ return EmitComposeComplex(
op,
ir_builder_->CreateFMul(
cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
@@ -309,11 +362,12 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
case HloOpcode::kSin: {
// sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
- auto a = real(operand_value);
+ auto a = EmitExtractReal(operand_value);
auto llvm_ty = a->getType();
TF_ASSIGN_OR_RETURN(
- auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)},
- {component_type}, component_type));
+ auto exp_b,
+ EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)},
+ {component_type}, component_type));
TF_ASSIGN_OR_RETURN(
auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
component_type));
@@ -324,13 +378,71 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
auto half_exp_neg_b =
ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
- return ComposeComplex(
+ return EmitComposeComplex(
op,
ir_builder_->CreateFMul(
sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
ir_builder_->CreateFMul(
cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
}
+ case HloOpcode::kTanh: {
+ /*
+ tanh(x)=(exp(x)-exp(-x)) / (exp(x)+exp(-x)), where x=a+bi
+ e^(a+bi) = e^a*(cos(b)+sin(b)i)
+ so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
+ (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
+ cos(b)=cos(-b), sin(-b)=-sin(b)
+ so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
+ (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
+ =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
+ (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
+ =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
+ (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
+ This is a complex division, so we can multiply by denom_conj/denom_conj
+ =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
+ (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
+ ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
+ =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
+ i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
+ ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
+ */
+ auto a = EmitExtractReal(operand_value);
+ auto b = EmitExtractImag(operand_value);
+ TF_ASSIGN_OR_RETURN(
+ auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type},
+ component_type));
+ TF_ASSIGN_OR_RETURN(
+ auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type},
+ component_type));
+ TF_ASSIGN_OR_RETURN(
+ auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type},
+ component_type));
+ auto exp_neg_a = ir_builder_->CreateFDiv(
+ llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
+ auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub(
+ ir_builder_->CreateFMul(exp_a, exp_a),
+ ir_builder_->CreateFMul(exp_neg_a, exp_neg_a));
+ auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b);
+ auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b);
+ auto real_num = ir_builder_->CreateFAdd(
+ ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
+ ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
+ auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b);
+ auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a);
+ auto exp_a_plus_exp_neg_a_sq =
+ ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
+ auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a);
+ auto exp_a_minus_exp_neg_a_sq =
+ ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
+ auto imag_num = ir_builder_->CreateFMul(
+ cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq,
+ exp_a_minus_exp_neg_a_sq));
+ auto denom = ir_builder_->CreateFAdd(
+ ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
+ ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
+ return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom),
+ ir_builder_->CreateFDiv(imag_num, denom));
+ }
default:
return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value);
}
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 3defa1b696..6a537d0152 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -61,6 +61,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const override;
+ StatusOr<llvm::Value*> EmitComplexBinaryOp(
+ const HloInstruction* op, llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const override;
+
StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
llvm::Value* value) const override;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 9d55c7859d..af2a92e11e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -293,29 +293,30 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation(
computation, {old_output_location, source_address}, new_output_location));
// (old_output, success) = atomicCAS(output_address, old_output, new_output);
- llvm::Type* element_int_ir_type =
- ir_builder_.getIntNTy(element_ir_type->getScalarSizeInBits());
- // cmpxchg accetps integer only, so we bitcast the operands (old_output and
- // new_output) to integers of the same bit width, and bitcast the result
- // back to the original element type.
- llvm::Value* old_output =
- ir_builder_.CreateLoad(old_output_location, "old_output");
- llvm::Value* new_output =
- ir_builder_.CreateLoad(new_output_location, "new_output");
+ int num_bits = llvm_ir::GetSizeInBits(element_ir_type);
+ llvm::Type* element_int_ir_type = ir_builder_.getIntNTy(num_bits);
+ // cmpxchg accepts integers only, and bitcast refuses to operate on aggregate
+ // types, so we bitcast load and store addresses to intN* of the same bit
+ // width.
+ llvm::Value* old_output = ir_builder_.CreateLoad(
+ ir_builder_.CreateBitCast(old_output_location,
+ element_int_ir_type->getPointerTo()),
+ "old_output");
+ llvm::Value* new_output = ir_builder_.CreateLoad(
+ ir_builder_.CreateBitCast(new_output_location,
+ element_int_ir_type->getPointerTo()),
+ "new_output");
llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg(
ir_builder_.CreateBitCast(output_address,
element_int_ir_type->getPointerTo()),
- ir_builder_.CreateBitCast(old_output, element_int_ir_type),
- ir_builder_.CreateBitCast(new_output, element_int_ir_type),
- llvm::AtomicOrdering::SequentiallyConsistent,
+ old_output, new_output, llvm::AtomicOrdering::SequentiallyConsistent,
llvm::AtomicOrdering::SequentiallyConsistent);
// cmpxchg returns a pair. The first element is the original value at
// output_address and the second element is whether the swap is successful.
ir_builder_.CreateStore(
- ir_builder_.CreateBitCast(
- ir_builder_.CreateExtractValue(ret_value, 0, "old_output"),
- element_ir_type),
- old_output_location);
+ ir_builder_.CreateExtractValue(ret_value, 0, "old_output"),
+ ir_builder_.CreateBitCast(old_output_location,
+ element_int_ir_type->getPointerTo()));
ir_builder_.CreateCondBr(
ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb,
loop_body_bb);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 7b4662fc80..db78f4b84d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1081,16 +1081,25 @@ Status IrEmitterUnnested::EmitRowReduction(
// from the warp.
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
&ir_builder_);
+ int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
+ // bitcast cannot be applied to aggregate types (even packed ones), so we
+ // instead bitcast addresses of load/store to intN* of the same bit-width.
+ llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
+ ? ir_builder_.getIntNTy(bit_width)
+ : element_ir_type;
for (int shuffle_distance = 16; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
- partial_reduction_result_address, "partial_reduction_result");
+ ir_builder_.CreateBitCast(partial_reduction_result_address,
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
element_ir_type, nullptr, "result_from_other_lane");
ir_builder_.CreateStore(
EmitShuffleDown(partial_reduction_result,
ir_builder_.getInt32(shuffle_distance), &ir_builder_),
- result_from_other_lane);
+ ir_builder_.CreateBitCast(result_from_other_lane,
+ shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducer, {partial_reduction_result_address, result_from_other_lane},
partial_reduction_result_address));
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index d95409e399..086c8dae9e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -163,8 +163,9 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
// z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
// imaginary part of z.
return llvm::StructType::create(
- "complex64", llvm::Type::getFloatTy(module->getContext()),
- llvm::Type::getFloatTy(module->getContext()));
+ {llvm::Type::getFloatTy(module->getContext()),
+ llvm::Type::getFloatTy(module->getContext())},
+ "complex64", /*isPacked=*/true);
}
return cplx_t;
}
@@ -178,6 +179,21 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
}
}
+int GetSizeInBits(llvm::Type* type) {
+ const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
+ if (struct_ty) {
+ CHECK(struct_ty->isPacked());
+ int bits = 0;
+ for (auto element_type : struct_ty->elements()) {
+ bits += GetSizeInBits(element_type);
+ }
+ return bits;
+ }
+ int bits = type->getPrimitiveSizeInBits();
+ CHECK_GT(bits, 0) << "type is not sized";
+ return bits;
+}
+
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
if (ShapeUtil::IsTuple(shape)) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index f70d9f88b3..063ead2b64 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -129,6 +129,9 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
llvm::Module* module);
+// Returns the type size in bits. If "type" is a struct, it must be packed.
+int GetSizeInBits(llvm::Type* type);
+
// Returns the LLVM type which represents the given XLA shape. For example,
// if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 0b700fbb6f..c6e8b24d12 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -82,6 +82,25 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
{});
}
+XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>({});
+ auto result = builder.Neg(a);
+
+ ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>(
+ {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}});
+ auto result = builder.Neg(a);
+
+ ComputeAndCompareR1<complex64>(
+ &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}},
+ {}, error_spec_);
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({});
@@ -145,6 +164,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>(
+ {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
+ auto b = builder.ConstantR1<complex64>(
+ {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
+ auto add = builder.Add(a, b);
+
+ ComputeAndCompareR1<complex64>(
+ &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>({});
+ auto b = builder.ConstantR1<complex64>({});
+ auto add = builder.Add(a, b);
+
+ ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
+}
+
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
const int count = GetParam();
ComputationBuilder builder(client_, TestName());
@@ -222,6 +263,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
ComputeAndCompareR1<int32>(&builder, {}, {});
}
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>(
+ {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
+ auto b = builder.ConstantR1<complex64>(
+ {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
+ auto add = builder.Sub(a, b);
+
+ ComputeAndCompareR1<complex64>(
+ &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>({});
+ auto b = builder.ConstantR1<complex64>({});
+ auto add = builder.Sub(a, b);
+
+ ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
@@ -385,6 +448,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
}
}
+XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>(
+ {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
+ auto b = builder.ConstantR1<complex64>(
+ {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
+ auto div = builder.Div(a, b);
+
+ ComputeAndCompareR1<complex64>(
+ &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>({});
+ auto b = builder.ConstantR1<complex64>({});
+ auto div = builder.Div(a, b);
+
+ ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>(
@@ -496,6 +580,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
ComputeAndCompareR1<uint32>(&builder, expected, {});
}
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>(
+ {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
+ auto b = builder.ConstantR1<complex64>(
+ {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
+ auto add = builder.Mul(a, b);
+
+ ComputeAndCompareR1<complex64>(
+ &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<complex64>({});
+ auto b = builder.ConstantR1<complex64>({});
+ auto add = builder.Mul(a, b);
+
+ ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({false, false, true, true});
@@ -886,6 +992,53 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
ComputeAndCompareR1<bool>(&builder, {}, {});
}
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
+ SetFastMathDisabled(true);
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f},
+ {1.0f, 25.5f},
+ {2.25f, -3.0f},
+ {NAN, 0.0f},
+ {1.0f, 6.0f}});
+ auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f},
+ {1.0f, 5.0f},
+ {2.25f, -3.0f},
+ {10.0f, 0.0f},
+ {1.0f, NAN}});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<complex64>({});
+ auto rhs = builder.ConstantR1<complex64>({});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
+ // Disable fast-math because we're operating on NaNs.
+ SetFastMathDisabled(true);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f},
+ {1.0f, 25.5f},
+ {2.25f, -3.0f},
+ {NAN, 0.0f},
+ {1.0f, 6.0f}});
+ auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f},
+ {1.0f, 5.0f},
+ {2.25f, -3.0f},
+ {10.0f, 0.0f},
+ {1.0f, NAN}});
+ auto compare = builder.Ne(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
// Disable fast-math because we're operating on NaNs.
SetFastMathDisabled(true);
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index b578667735..1dc274c591 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -332,8 +332,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
ComputationBuilder* builder, NativeT expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
- std::is_same<NativeT, double>::value,
- "Floating point type required when specifying an ErrorSpec");
+ std::is_same<NativeT, double>::value ||
+ std::is_same<NativeT, complex64>::value,
+ "Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
@@ -355,8 +356,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
- std::is_same<NativeT, double>::value,
- "Floating point type required when specifying an ErrorSpec");
+ std::is_same<NativeT, double>::value ||
+ std::is_same<NativeT, complex64>::value,
+ "Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index b72dd2707c..bfb04fd9f9 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -386,35 +386,39 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
}
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) {
- constexpr bool kLhsRowMajor = false;
- constexpr bool kRhsRowMajor = false;
- TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+ TestNonsquareMatrixDot<float>(false, false);
}
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) {
- constexpr bool kLhsRowMajor = false;
- constexpr bool kRhsRowMajor = true;
- TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+ TestNonsquareMatrixDot<float>(false, true);
}
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) {
- constexpr bool kLhsRowMajor = true;
- constexpr bool kRhsRowMajor = false;
- TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+ TestNonsquareMatrixDot<float>(true, false);
}
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
- constexpr bool kLhsRowMajor = true;
- constexpr bool kRhsRowMajor = true;
- TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+ TestNonsquareMatrixDot<float>(true, true);
}
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) {
TestNonsquareMatrixDot<double>();
}
-XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
- TestNonsquareMatrixDot<complex64>();
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) {
+ TestNonsquareMatrixDot<complex64>(false, false);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) {
+ TestNonsquareMatrixDot<complex64>(false, true);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) {
+ TestNonsquareMatrixDot<complex64>(true, false);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) {
+ TestNonsquareMatrixDot<complex64>(true, true);
}
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {