author     Benjamin Kramer <kramerb@google.com>              2018-07-25 11:08:39 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-07-25 11:11:42 -0700
commit     a5285d999af961367437d72285f67c4a5a2878d4 (patch)
tree       9a5710a0f171b4aa486fcabc2e65a9ecc4238077
parent     0bc512505957e3685305b6a850f222c6eed88c7d (diff)
[XLA:GPU] Use a fast approximation for tanh
Just reuse the CPU implementation, which in turn is derived from Eigen. It claims to be accurate to within ±1%, which is good enough for fast math.

Refactor the CPU implementation into a common file and remove the VectorSupportLibrary dependency (it's not needed).

PiperOrigin-RevId: 206022260
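For reference, the rational interpolant this change reuses is small enough to sketch in scalar C++. The following is not part of the patch; it only mirrors the clamp, the coefficients, and the Horner-style evaluation that math_ops.cc below emits as LLVM IR, with FastTanhApprox as a hypothetical name:

#include <algorithm>
#include <array>
#include <cstddef>

// Hypothetical scalar version of the fast tanh approximation: clamp the
// input to [-9, 9], then evaluate x * P(x^2) / Q(x^2) with the same
// coefficients that EmitFastTanh uses.
float FastTanhApprox(float x) {
  x = std::min(9.0f, std::max(-9.0f, x));

  static constexpr std::array<float, 7> kNumeratorCoeffs{
      -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
      5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
      4.89352455891786e-03f};
  static constexpr std::array<float, 4> kDenominatorCoeffs{
      1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
      4.89352518554385e-03f};

  // Numerator polynomial in x^2, then multiplied by x.
  const float x2 = x * x;
  float numerator = kNumeratorCoeffs[0];
  for (std::size_t i = 1; i < kNumeratorCoeffs.size(); ++i) {
    numerator = x2 * numerator + kNumeratorCoeffs[i];
  }
  numerator *= x;

  // Denominator polynomial in x^2.
  float denominator = kDenominatorCoeffs[0];
  for (std::size_t i = 1; i < kDenominatorCoeffs.size(); ++i) {
    denominator = x2 * denominator + kDenominatorCoeffs[i];
  }
  return numerator / denominator;
}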
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD                    |   1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc       |  39
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                    |   1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc  |  11
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/BUILD                |  10
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/math_ops.cc          |  59
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/math_ops.h           |  32
-rw-r--r--  tensorflow/compiler/xla/tests/half_test.cc                   |   3
8 files changed, 120 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index ace9f96cfb..71f7f985d0 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -444,6 +444,7 @@ cc_library(
deps = [
":vector_support_library",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:math_ops",
"//tensorflow/core:lib",
"@llvm//:core",
"@llvm//:transform_utils",
diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
index ec0498e04e..cef5e57b0b 100644
--- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/logging.h"
@@ -54,44 +55,12 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
llvm::IRBuilder<> b(vector_tanh_body);
llvm::FastMathFlags fast_math_flags;
- fast_math_flags.setFast();
+ fast_math_flags.setFast(enable_fast_math);
b.setFastMathFlags(fast_math_flags);
- VectorSupportLibrary vsl(F32, vector_width, &b, "tanh_f32");
-
llvm::Value* input = &*vector_tanh_function->arg_begin();
- CHECK_EQ(input->getType(), vsl.vector_type());
-
- // This implements the same rational interpolant as implemented in Eigen3.
- llvm::Value* input_clamped =
- vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0));
-
- std::array<float, 7> numerator_coeffs{
- -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
- 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
- 4.89352455891786e-03f};
-
- std::array<float, 4> denominator_coeffs{
- 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
- 4.89352518554385e-03f};
-
- llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped);
- llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0]));
- for (int i = 1; i < numerator_coeffs.size(); i++) {
- numerator =
- vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i]));
- }
-
- numerator = vsl.Mul(input_clamped, numerator);
-
- llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0]));
- for (int i = 1; i < denominator_coeffs.size(); i++) {
- denominator = vsl.MulAdd(input_squared, denominator,
- GetIeeeF32(denominator_coeffs[i]));
- }
-
- llvm::Value* result = vsl.Div(numerator, denominator);
- b.CreateRet(result);
+ CHECK_EQ(vector_width, input->getType()->getVectorNumElements());
+ b.CreateRet(llvm_ir::EmitFastTanh(&b, input));
DCHECK(!llvm::verifyFunction(*vector_tanh_function));
return vector_tanh_function;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 08429c5b4d..6f1e766d1c 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -217,6 +217,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
+ "//tensorflow/compiler/xla/service/llvm_ir:math_ops",
"//tensorflow/core:lib",
"@llvm//:core",
"@llvm//:support",
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index b97a627d9b..cc38db27e2 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -277,6 +278,16 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
PrimitiveType output_type = op->shape().element_type();
switch (op->opcode()) {
case HloOpcode::kTanh:
+ // If we don't care much about precision, emit a fast approximation of
+ // tanh.
+ if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
+ // Upcast F16 to F32 if necessary.
+ llvm::Type* type =
+ input_type == F16 ? b_->getFloatTy() : operand_value->getType();
+ llvm::Value* input = b_->CreateFPCast(operand_value, type);
+ llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
+ return b_->CreateFPCast(fast_tanh, operand_value->getType());
+ }
return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type},
output_type);
default:
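The fast-math branch above trades precision for speed; the commit message cites roughly ±1% accuracy for the approximation. A quick standalone check of that claim (an illustration only, assuming the hypothetical FastTanhApprox sketch shown after the commit message) could look like:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Compare the approximation against std::tanh over the clamped range and
// report the worst relative error seen.
int main() {
  float max_rel_err = 0.0f;
  for (float x = -9.0f; x <= 9.0f; x += 1e-3f) {
    const float exact = std::tanh(x);
    const float approx = FastTanhApprox(x);  // hypothetical helper from above
    if (std::fabs(exact) > 1e-6f) {
      max_rel_err =
          std::max(max_rel_err, std::fabs(approx - exact) / std::fabs(exact));
    }
  }
  std::printf("max relative error: %g\n", max_rel_err);
  return 0;
}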
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 0573304912..309a186e58 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -223,3 +223,13 @@ cc_library(
"@llvm//:core",
],
)
+
+cc_library(
+ name = "math_ops",
+ srcs = ["math_ops.cc"],
+ hdrs = ["math_ops.h"],
+ deps = [
+ ":llvm_util",
+ "@llvm//:core",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc
new file mode 100644
index 0000000000..0e115cdabf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+
+namespace xla {
+namespace llvm_ir {
+
+llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) {
+ llvm::Type* type = input->getType();
+
+ // Clamp the input to [-9, 9].
+ llvm::Value* input_clamped = llvm_ir::EmitFloatMin(
+ llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b),
+ llvm::ConstantFP::get(type, 9.0), b);
+
+ static constexpr std::array<float, 7> numerator_coeffs{
+ -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
+ 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
+ 4.89352455891786e-03f};
+
+ static constexpr std::array<float, 4> denominator_coeffs{
+ 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
+ 4.89352518554385e-03f};
+
+ llvm::Value* input_squared = b->CreateFMul(input_clamped, input_clamped);
+ llvm::Value* numerator = llvm::ConstantFP::get(type, numerator_coeffs[0]);
+ for (int i = 1; i < numerator_coeffs.size(); i++) {
+ numerator = b->CreateFAdd(b->CreateFMul(input_squared, numerator),
+ llvm::ConstantFP::get(type, numerator_coeffs[i]));
+ }
+
+ numerator = b->CreateFMul(input_clamped, numerator);
+
+ llvm::Value* denominator = llvm::ConstantFP::get(type, denominator_coeffs[0]);
+ for (int i = 1; i < denominator_coeffs.size(); i++) {
+ denominator =
+ b->CreateFAdd(b->CreateFMul(input_squared, denominator),
+ llvm::ConstantFP::get(type, denominator_coeffs[i]));
+ }
+
+ return b->CreateFDiv(numerator, denominator);
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/math_ops.h b/tensorflow/compiler/xla/service/llvm_ir/math_ops.h
new file mode 100644
index 0000000000..6c8bc3a076
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+
+namespace xla {
+namespace llvm_ir {
+
+// Emits an approximation of tanh. The implementation uses the same rational
+// interpolant as implemented in Eigen3.
+llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input);
+
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_
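Both call sites touched by this patch use the new helper the same way: position an llvm::IRBuilder<> at the desired insertion point and hand it an f32 (or vector-of-f32) value. A minimal caller sketch, with EmitScalarTanh as a hypothetical wrapper name:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"

// Hypothetical wrapper: emit the fast tanh approximation for `input` at the
// builder's current insertion point and return the resulting IR value, which
// has the same type as the input.
llvm::Value* EmitScalarTanh(llvm::IRBuilder<>* b, llvm::Value* input) {
  return xla::llvm_ir::EmitFastTanh(b, input);
}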
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
index 73a47eda72..249a4b2493 100644
--- a/tensorflow/compiler/xla/tests/half_test.cc
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -48,7 +48,8 @@ class UnaryOpTest : public HalfTestBase,
public ::testing::WithParamInterface<UnaryOpTestParam> {};
XLA_TEST_P(UnaryOpTest, Ops) {
- std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1)});
+ std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1), half(9.0),
+ half(42.0), half(-9.0), half(-100.0)});
XlaBuilder builder(TestName());
XlaOp x_opnd;
auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",