diff options
author | Petros Mol <pmol@google.com> | 2017-06-14 09:06:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-14 09:10:14 -0700 |
commit | df511d09b051914cbc4fc559807a3f0d07dfee71 (patch) | |
tree | 5c70cb1e00beddce3a6dfd60d35c0b1b3d93bf01 /tensorflow | |
parent | ade1560e651e67d6cf33dc05a3ab26abf364446b (diff) |
[XLA] Add a Cos unary operation that computes the elementwise cosine
PiperOrigin-RevId: 158984883
Diffstat (limited to 'tensorflow')
15 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index abe4949f5d..07ca596150 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -44,6 +44,7 @@ namespace { // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); +XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 6e95e623ba..cefa4af23c 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -967,6 +967,11 @@ ComputationDataHandle ComputationBuilder::Sign( return UnaryOp(UNOP_SIGN, operand); } +ComputationDataHandle ComputationBuilder::Cos( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_COS, operand); +} + ComputationDataHandle ComputationBuilder::Tanh( const ComputationDataHandle& operand) { return UnaryOp(UNOP_TANH, operand); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 129e66c24a..13b44a71a5 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -508,6 +508,9 @@ class ComputationBuilder { // Enqueues a sign instruction onto the computation. ComputationDataHandle Sign(const ComputationDataHandle& operand); + // Enqueues a cosine instruction onto the computation. + ComputationDataHandle Cos(const ComputationDataHandle& operand); + // Enqueues a tanh instruction onto the computation. ComputationDataHandle Tanh(const ComputationDataHandle& operand); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index bfa6f241a3..40ff037e73 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -148,6 +148,9 @@ class DfsHloVisitor { virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) { return HandleElementwiseUnary(log, HloOpcode::kLog, operand); } + virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) { + return HandleElementwiseUnary(cos, HloOpcode::kCos, operand); + } virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index dbc65e80eb..c99ebceb45 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -172,6 +172,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, {operand_value->getType()}, ir_builder_); + case HloOpcode::kCos: + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value}, + {operand_value->getType()}, + ir_builder_); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, @@ -664,6 +668,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCeil: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 6abc733646..48c33d62c5 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -214,6 +214,7 @@ string InstructionSequenceGraph( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kConvert: + case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kExp: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9e020f9391..f926cb4bc7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -122,6 +122,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: @@ -744,6 +745,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kIsFinite: case HloOpcode::kFloor: @@ -1113,6 +1115,7 @@ bool HloInstruction::Identical( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: case HloOpcode::kDot: @@ -1834,6 +1837,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleLog(this, operands_[0]); case HloOpcode::kTanh: return visitor->HandleTanh(this, operands_[0]); + case HloOpcode::kCos: + return visitor->HandleCos(this, operands_[0]); case HloOpcode::kIsFinite: return visitor->HandleIsFinite(this, operands_[0]); case HloOpcode::kLogicalNot: @@ -2080,6 +2085,7 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCeil: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 5bda6b6dab..342c43dc5a 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -42,6 +42,8 @@ string HloOpcodeString(HloOpcode opcode) { return "convert"; case HloOpcode::kConvolution: return "convolution"; + case HloOpcode::kCos: + return "cosine"; case HloOpcode::kCrossReplicaSum: return "cross-replica-sum"; case HloOpcode::kCustomCall: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 65aef63dcd..8e0fa7b4f1 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -41,6 +41,7 @@ enum class HloOpcode { kConvert, kConvolution, kCopy, + kCos, kCrossReplicaSum, kCustomCall, kDivide, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 06fa8bc619..9bace7edaa 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -43,6 +43,7 @@ namespace xla { case HloOpcode::kConstant: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 670e1ca84a..2508f4c13d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -184,6 +184,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, switch (operation) { case UNOP_FLOOR: case UNOP_CEIL: + case UNOP_COS: case UNOP_EXP: case UNOP_LOG: case UNOP_TANH: diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 417ed584aa..3e7942075c 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -49,6 +49,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kAbs; case UNOP_CEIL: return HloOpcode::kCeil; + case UNOP_COS: + return HloOpcode::kCos; case UNOP_EXP: return HloOpcode::kExp; case UNOP_FLOOR: diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index ecb2a6c767..7c3a1d2580 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1297,6 +1297,15 @@ TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { {param0_data.get()}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f}); + auto result = builder.Cos(a); + + ComputeAndCompareR1<float>(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, + error_spec_); +} + TEST_F(ArrayElementwiseOpTest, TanhF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f}); diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 23ef79d0d7..633d16c4c3 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -603,6 +603,9 @@ enum UnaryOperation { // Elementwise, tests if values are finite (not NaN or inf) UNOP_IS_FINITE = 11; + + // Elementwise, computes the cosine of x. + UNOP_COS = 12; } message UnaryOpRequest { diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index ed055691ce..120260448d 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -583,6 +583,8 @@ ComputationBuilder supports these element-wise unary functions: <b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`. +<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`. + <b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`. <b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`. |