diff options
author | 2017-12-19 10:11:58 -0800 | |
---|---|---|
committer | 2017-12-19 10:16:51 -0800 | |
commit | 2368d4114465b3ebd6bd597cd5919b295cd4348b (patch) | |
tree | 9dc275dd8eec7a80ee19ba09f0334f0f27da7b7f /tensorflow/compiler/xla/service/elemental_ir_emitter.h | |
parent | 2071f7ff4fe13d0b5a7b8d9dceaeb3c211e37199 (diff) |
[XLA] Add support for atan2 on CPU
This leans on the libm's atan2 for the actual routine but allows us to share
the implementation of other complex operations between CPU and GPU.
PiperOrigin-RevId: 179569666
Diffstat (limited to 'tensorflow/compiler/xla/service/elemental_ir_emitter.h')
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.h | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index cccb498f82..1a48eb5fcb 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -39,7 +39,7 @@ class ElementalIrEmitter { module_(module), hlo_module_config_(hlo_module_config) {} - virtual ~ElementalIrEmitter() {} + virtual ~ElementalIrEmitter() = default; virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, llvm::Value* operand_value) const; @@ -92,6 +92,26 @@ class ElementalIrEmitter { virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + + virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; |