author    Benoit Steiner <bsteiner@google.com>  2016-06-30 13:17:39 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-06-30 14:34:28 -0700
commit    5a8faefbad514c88ca5161ebc901bb9bd74d932d
tree      fc19c06558238d796d68dd3729ce1a251c056d5b
parent    1445e8054c65b776f63d43833da30dc3debdbc31
Improved the gradients for tanh and sigmoid. This improves the speed of the PTB word model from 6800 to 7800 words per second.
Change: 126342788
-rw-r--r--  tensorflow/core/kernels/BUILD                            |   1
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc       |   2
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc          |   2
-rw-r--r--  tensorflow/core/kernels/cwise_op_sigmoid.cc              |   9
-rw-r--r--  tensorflow/core/kernels/cwise_op_tanh.cc                 |   8
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.h               |  30
-rw-r--r--  tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h     |  71
-rw-r--r--  tensorflow/core/kernels/cwise_ops_gradients.h            | 107
-rw-r--r--  tensorflow/core/ops/math_ops.cc                          |  21
-rw-r--r--  tensorflow/python/BUILD                                  |   2
-rw-r--r--  tensorflow/python/ops/math_grad.py                       |   5
-rw-r--r--  tensorflow/python/ops/math_ops.py                        |   2
12 files changed, 258 insertions(+), 2 deletions(-)
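For reference, both fused kernels express the gradient purely in terms of the forward op's output y: d/dx tanh(x) = 1 - y^2 and d/dx sigmoid(x) = y * (1 - y). Below is a minimal standalone C++ check of these identities against central finite differences (an illustration only, not TensorFlow code):

#include <cmath>
#include <cstdio>

int main() {
  const double eps = 1e-6;
  for (double x = -2.0; x <= 2.0; x += 0.5) {
    // TanhGrad with dy = 1: grad = dy * (1 - y*y), where y = tanh(x).
    double y = std::tanh(x);
    double analytic = 1.0 - y * y;
    double numeric = (std::tanh(x + eps) - std::tanh(x - eps)) / (2 * eps);
    std::printf("tanh     x=%+.2f  analytic=%.6f  numeric=%.6f\n", x, analytic, numeric);

    // SigmoidGrad with dy = 1: grad = dy * y * (1 - y), where y = sigmoid(x).
    double s = 1.0 / (1.0 + std::exp(-x));
    double s_numeric = (1.0 / (1.0 + std::exp(-(x + eps))) -
                        1.0 / (1.0 + std::exp(-(x - eps)))) / (2 * eps);
    std::printf("sigmoid  x=%+.2f  analytic=%.6f  numeric=%.6f\n", x, s * (1.0 - s), s_numeric);
  }
  return 0;
}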
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 5cf48bfab5..142f63c6b4 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1753,6 +1753,7 @@ filegroup(
"cwise_ops.h",
"cwise_ops_common.cc",
"cwise_ops_common.h",
+ "cwise_ops_gradients.h",
"dense_update_ops.cc",
"dense_update_ops.h",
"example_parsing_ops.cc",
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
index a7ac9baca0..b59d22310e 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY3(sigmoid, Eigen::half, float, double);
+DEFINE_SIMPLE_BINARY3(sigmoid_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
index 1678086c35..66ee3c193e 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY3(tanh, Eigen::half, float, double);
+DEFINE_SIMPLE_BINARY3(tanh_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc
index 9d8a849bd3..cc1f9b8f03 100644
--- a/tensorflow/core/kernels/cwise_op_sigmoid.cc
+++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
@@ -22,4 +23,12 @@ REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half,
double);
#endif
+
+REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float,
+ Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float,
+ Eigen::half, double);
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc
index 6604d71d14..a4c4aad053 100644
--- a/tensorflow/core/kernels/cwise_op_tanh.cc
+++ b/tensorflow/core/kernels/cwise_op_tanh.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
@@ -21,4 +22,11 @@ REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double);
#endif
+
+REGISTER5(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float,
+ Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "TanhGrad", functor::tanh_grad, float,
+ Eigen::half, double);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 02a82c00bf..6ccbe46c7f 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -21,6 +21,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -130,6 +131,35 @@ class BinaryOp : public BinaryOpShared {
}
};
+// Basic coefficient-wise binary operations that are known not to require
+// any broadcasting. This is the case, for example, for the gradients of
+// unary operations.
+// Device: E.g., CPUDevice, GPUDevice.
+// Functor: defined above. E.g., functor::tanh_grad.
+template <typename Device, typename Functor>
+class SimpleBinaryOp : public OpKernel {
+ public:
+ typedef typename Functor::in_type Tin; // Input scalar data type.
+ typedef typename Functor::out_type Tout; // Output scalar data type.
+
+ explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& in0 = ctx->input(0);
+ const Tensor& in1 = ctx->input(1);
+
+ Tensor* out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
+ auto out_flat = out->flat<Tout>();
+ auto in0_flat = in0.flat<Tin>();
+ auto in1_flat = in1.flat<Tin>();
+ const Device& eigen_device = ctx->eigen_device<Device>();
+
+ functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
+ in0_flat, in1_flat);
+ }
+};
+
// Coefficient-wise unary operations:
// Device: E.g., CPUDevice, GPUDevice.
// Functor: defined in cwise_functors.h. E.g., functor::sqrt.
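Why this class exists: BinaryOp above goes through the broadcast machinery to handle inputs of arbitrary shapes, but for the gradient of a unary op the two inputs (the forward output and the incoming gradient) always have identical shapes, so SimpleBinaryOp can skip shape analysis and make one flat pass over the buffers. A standalone sketch of that fast path under the same assumption, with hypothetical names rather than the TF classes:

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for functor::tanh_grad's scalar path.
struct TanhGrad {
  double operator()(double y, double dy) const { return dy * (1.0 - y * y); }
};

// Both inputs are assumed to be the same size; the output simply takes
// input 0's shape, exactly as SimpleBinaryOp::Compute allocates it.
template <typename F>
void SimpleBinary(const std::vector<double>& in0,
                  const std::vector<double>& in1,
                  std::vector<double>* out, F f) {
  out->resize(in0.size());
  for (std::size_t i = 0; i < in0.size(); ++i) (*out)[i] = f(in0[i], in1[i]);
}

int main() {
  std::vector<double> y = {0.0, 0.5, -0.5};  // forward outputs
  std::vector<double> dy = {1.0, 1.0, 1.0};  // incoming gradients
  std::vector<double> dx;
  SimpleBinary(y, dy, &dx, TanhGrad());
  for (double v : dx) std::printf("%.4f\n", v);  // 1.0000 0.7500 0.7500
  return 0;
}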
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
new file mode 100644
index 0000000000..4394770708
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
@@ -0,0 +1,71 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if !GOOGLE_CUDA
+#error This file must only be included when building with Cuda support
+#endif
+
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+
+#define EIGEN_USE_GPU
+
+#include <complex>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+typedef std::complex<float> complex64;
+typedef std::complex<double> complex128;
+
+// Partial specialization of SimpleBinaryFunctor<Device=GPUDevice, Functor>.
+template <typename Functor>
+struct SimpleBinaryFunctor<GPUDevice, Functor> {
+ void operator()(const GPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in1,
+ typename Functor::tin_type in2) {
+ To32Bit(out).device(d) =
+ To32Bit(in1).binaryExpr(in2, typename Functor::func());
+ }
+};
+
+// Macros to explicitly instantiate kernels on GPU for multiple types
+// (T0, T1, etc.) for SimpleBinaryFunctor (e.g., functor::tanh_grad).
+#define DEFINE_SIMPLE_BINARY1(F, T) \
+ template struct SimpleBinaryFunctor<GPUDevice, F<T> >
+#define DEFINE_SIMPLE_BINARY2(F, T0, T1) \
+ DEFINE_SIMPLE_BINARY1(F, T0); \
+ DEFINE_SIMPLE_BINARY1(F, T1)
+#define DEFINE_SIMPLE_BINARY3(F, T0, T1, T2) \
+ DEFINE_SIMPLE_BINARY2(F, T0, T1); \
+ DEFINE_SIMPLE_BINARY1(F, T2)
+#define DEFINE_SIMPLE_BINARY4(F, T0, T1, T2, T3) \
+ DEFINE_SIMPLE_BINARY2(F, T0, T1); \
+ DEFINE_SIMPLE_BINARY2(F, T2, T3)
+#define DEFINE_SIMPLE_BINARY5(F, T0, T1, T2, T3, T4) \
+ DEFINE_SIMPLE_BINARY2(F, T0, T1); \
+ DEFINE_SIMPLE_BINARY3(F, T2, T3, T4)
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
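To make the instantiation macros concrete: expanding the DEFINE_SIMPLE_BINARY3(tanh_grad, Eigen::half, float, double) line from cwise_op_gpu_tanh.cu.cc above is purely mechanical and yields three explicit template instantiations:

template struct SimpleBinaryFunctor<GPUDevice, tanh_grad<Eigen::half> >;
template struct SimpleBinaryFunctor<GPUDevice, tanh_grad<float> >;
template struct SimpleBinaryFunctor<GPUDevice, tanh_grad<double> >;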
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
new file mode 100644
index 0000000000..a59f157281
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -0,0 +1,107 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+
+#define EIGEN_USE_THREADS
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace Eigen {
+namespace internal {
+
+// Gradient for the tanh function
+template <typename T>
+struct scalar_tanh_gradient_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_gradient_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+ operator()(const T& output, const T& output_gradient) const {
+ return output_gradient * (T(1) - output * output);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& output, const Packet& output_gradient) const {
+ return pmul(output_gradient,
+ psub(pset1<Packet>(T(1)), pmul(output, output)));
+ }
+};
+template <typename T>
+struct functor_traits<scalar_tanh_gradient_op<T>> {
+ enum {
+ Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
+ PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
+ };
+};
+
+// Gradient for the sigmoid function
+template <typename T>
+struct scalar_sigmoid_gradient_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_gradient_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+ operator()(const T& output, const T& output_gradient) const {
+ return output_gradient * output * (T(1) - output);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& output, const Packet& output_gradient) const {
+ return pmul(output_gradient,
+ pmul(output, psub(pset1<Packet>(T(1)), output)));
+ }
+};
+template <typename T>
+struct functor_traits<scalar_sigmoid_gradient_op<T>> {
+ enum {
+ Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
+ PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
+ };
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename Functor>
+struct SimpleBinaryFunctor {
+ void operator()(const Device& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1);
+};
+
+// Partial specialization of SimpleBinaryFunctor for CPU devices.
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Functor>
+struct SimpleBinaryFunctor<CPUDevice, Functor> {
+ void operator()(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1) {
+ out.device(d) = in0.binaryExpr(in1, typename Functor::func());
+ }
+};
+
+template <typename T>
+struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};
+
+template <typename T>
+struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> {
+};
+
+} // end namespace functor
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
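In the CPU specialization above, binaryExpr() pairs element i of in0 with element i of in1 and applies the functor, dispatching to packetOp() on vectorizable spans when functor_traits::PacketAccess is true. A small self-contained Eigen sketch of the same call pattern (a local stand-in functor, not the TensorFlow header; the Eigen include path may differ per build):

#include <iostream>
#include "unsupported/Eigen/CXX11/Tensor"

// Scalar-only stand-in for Eigen::internal::scalar_tanh_gradient_op.
struct TanhGradOp {
  float operator()(float y, float dy) const { return dy * (1.f - y * y); }
};

int main() {
  Eigen::Tensor<float, 1> y(4), dy(4);
  y.setValues({0.f, 0.25f, 0.5f, 0.75f});
  dy.setConstant(1.f);
  // Mirrors: out.device(d) = in0.binaryExpr(in1, typename Functor::func());
  Eigen::Tensor<float, 1> grad = y.binaryExpr(dy, TanhGradOp());
  for (int i = 0; i < 4; ++i) std::cout << grad(i) << "\n";
  return 0;
}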
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 0f9ee4942a..b220a2d2d6 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -238,6 +238,13 @@ tf.complex_abs(x) ==> [5.25594902, 6.60492229]
.Attr("T: {half, float, double, complex64, complex128}") \
.SetShapeFn(OpShapeInferenceFn(shape_inference::UnchangedShape))
+#define UNARY_GRADIENT_COMPLEX() \
+ Input("x: T") \
+ .Input("y: T") \
+ .Output("z: T") \
+ .Attr("T: {half, float, double, complex64, complex128}") \
+ .SetShapeFn(OpShapeInferenceFn(shape_inference::UnchangedShape))
+
REGISTER_OP("Neg")
.UNARY()
.Doc(R"doc(
@@ -292,6 +299,13 @@ REGISTER_OP("Tanh")
Computes hyperbolic tangent of `x` element-wise.
)doc");
+REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+Computes the gradient for the tanh of `x` wrt its input.
+
+Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`
+is the corresponding input gradient.
+)doc");
+
REGISTER_OP("Lgamma")
.UNARY_REAL()
.Doc(R"doc(
@@ -325,6 +339,13 @@ Computes sigmoid of `x` element-wise.
Specifically, `y = 1 / (1 + exp(-x))`.
)doc");
+REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+Computes the gradient of the sigmoid of `x` wrt its input.
+
+Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
+`dy` is the corresponding input gradient.
+)doc");
+
REGISTER_OP("Sin")
.UNARY_COMPLEX()
.Doc(R"doc(
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c2e5b0cc1c..c5418cf076 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -670,6 +670,8 @@ tf_gen_op_wrapper_py(
"MatMul",
"Sigmoid",
"Tanh",
+ "SigmoidGrad",
+ "TanhGrad",
],
require_shape_functions = True,
)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 8bfd9ce8bf..348ab9fd12 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -272,7 +273,7 @@ def _TanhGrad(op, grad):
with ops.control_dependencies([grad.op]):
if y.dtype.is_complex:
y = math_ops.conj(y)
- return grad * (1 - math_ops.square(y))
+ return gen_math_ops._tanh_grad(y, grad)
@ops.RegisterGradient("Erf")
@@ -374,7 +375,7 @@ def _SigmoidGrad(op, grad):
with ops.control_dependencies([grad.op]):
if y.dtype.is_complex:
y = math_ops.conj(y)
- return grad * (y * (1 - y))
+ return gen_math_ops._sigmoid_grad(y, grad)
@ops.RegisterGradient("Sign")
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 0a76450c5b..0bcf45db76 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1609,6 +1609,8 @@ ops.RegisterShape("BatchFFT2D")(common_shapes.unchanged_shape)
ops.RegisterShape("BatchIFFT2D")(common_shapes.unchanged_shape)
ops.RegisterShape("BatchFFT3D")(common_shapes.unchanged_shape)
ops.RegisterShape("BatchIFFT3D")(common_shapes.unchanged_shape)
+ops.RegisterShape("TanhGrad")(common_shapes.unchanged_shape)
+ops.RegisterShape("SigmoidGrad")(common_shapes.unchanged_shape)
@ops.RegisterShape("Add")