author    Lakshay Garg <lakshayg@outlook.in>  2017-06-09 08:46:08 +0530
committer gunan <gunan@google.com>            2017-06-08 20:16:08 -0700
commit    e9de087fa7f59c39bbe12ac2c83c5547c83f746c (patch)
tree      0993c3053b6af29c5720f06a38074b3995645bb5
parent    9bb581c7b79a25e69a821724e7c9b2bcd7c8fced (diff)
Implemented sinh and cosh (#10427)
* Implemented sinh and cosh (#7531)
* Removed Eigen::half from the cosh and sinh definitions; refer to Issue #7531 and Pull Request #7628
* Fixed the gradient for sinh in math_grad_test.cc
-rw-r--r--  tensorflow/cc/gradients/math_grad.cc                26
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc           52
-rw-r--r--  tensorflow/core/kernels/cwise_op_cosh.cc            37
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_cosh.cu.cc     26
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_sinh.cu.cc     26
-rw-r--r--  tensorflow/core/kernels/cwise_op_sinh.cc            37
-rw-r--r--  tensorflow/core/kernels/cwise_ops.h                  6
-rw-r--r--  tensorflow/core/ops/math_grad.cc                    20
-rw-r--r--  tensorflow/core/ops/math_grad_test.cc               20
-rw-r--r--  tensorflow/core/ops/math_ops.cc                      8
-rw-r--r--  tensorflow/core/ops/ops.pbtxt                       48
-rw-r--r--  tensorflow/docs_src/api_guides/python/math_ops.md    2
-rw-r--r--  tensorflow/go/op/wrappers.go                        30
-rw-r--r--  tensorflow/python/kernel_tests/basic_gpu_test.py     2
-rw-r--r--  tensorflow/python/kernel_tests/cwise_ops_test.py    10
-rw-r--r--  tensorflow/python/ops/math_grad.py                  18
-rw-r--r--  tensorflow/python/ops/math_ops.py                    2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.pbtxt         8
18 files changed, 378 insertions(+), 0 deletions(-)
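
For orientation, a minimal usage sketch of the new ops through the Python API (1.x session style; the printed values are approximate and not part of this patch):

import tensorflow as tf

x = tf.constant([-1.0, 0.0, 1.0])
with tf.Session() as sess:
    print(sess.run(tf.sinh(x)))  # approx. [-1.1752  0.      1.1752]
    print(sess.run(tf.cosh(x)))  # approx. [ 1.5431  1.      1.5431]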
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 8c1a01f518..71d9a8ed7b 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -162,6 +162,32 @@ Status Log1pGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);
+Status SinhGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = sinh(x)
+ // dy/dx = cosh(x)
+ auto dydx = Cosh(scope, op.input(0));
+ // grad(x) = grad(y) * conj(dy/dx)
+ grad_outputs->push_back(
+ Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Sinh", SinhGrad);
+
+Status CoshGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = cosh(x)
+ // dy/dx = sinh(x)
+ auto dydx = Sinh(scope, op.input(0));
+ // grad(x) = grad(y) * conj(dy/dx)
+ grad_outputs->push_back(
+ Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Cosh", CoshGrad);
+
Status TanhGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
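Both new C++ gradients multiply the incoming gradient by the conjugate of dy/dx, TensorFlow's convention for complex types; for real inputs conj is the identity, so SinhGrad reduces to grad * cosh(x). A sketch checking the real case from Python, assuming the patch is applied:

import tensorflow as tf

x = tf.constant([0.5, 1.0, 2.0])
dydx, = tf.gradients(tf.sinh(x), [x])
with tf.Session() as sess:
    g, c = sess.run([dydx, tf.cosh(x)])
    print(g, c)  # elementwise equal: d/dx sinh(x) = cosh(x)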
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index de6baa1769..1653b04378 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -45,6 +45,8 @@ class CWiseUnaryGradTest : public ::testing::Test {
EXPM1,
LOG,
LOG1P,
+ SINH,
+ COSH,
TANH,
SIGMOID,
SIGN,
@@ -111,6 +113,12 @@ class CWiseUnaryGradTest : public ::testing::Test {
case LOG1P:
y = Log1p(scope_, x);
break;
+ case SINH:
+ y = Sinh(scope_, x);
+ break;
+ case COSH:
+ y = Cosh(scope_, x);
+ break;
case TANH:
y = Tanh(scope_, x);
break;
@@ -337,6 +345,50 @@ TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
}
+TEST_F(CWiseUnaryGradTest, Sinh) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * std::cosh(x);
+ };
+ TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
+ auto x_fn = [this](const int i) {
+ return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+ };
+ auto dy_fn = [this](const complex64& x) {
+ return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+ };
+ auto dx_fn = [this](const complex64& x, const complex64& dy) {
+ return dy * conjugate(std::cosh(x));
+ };
+ TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Cosh) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * std::sinh(x);
+ };
+ TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
+ auto x_fn = [this](const int i) {
+ return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+ };
+ auto dy_fn = [this](const complex64& x) {
+ return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+ };
+ auto dx_fn = [this](const complex64& x, const complex64& dy) {
+ return dy * conjugate(std::sinh(x));
+ };
+ TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
+}
+
TEST_F(CWiseUnaryGradTest, Tanh) {
auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
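The complex cases above encode the expected backpropagated values in closed form: dx = dy * conj(cosh(x)) for Sinh and dx = dy * conj(sinh(x)) for Cosh. A plain NumPy sketch of that expectation, with hypothetical sample values:

import numpy as np

x = np.array([1 + 0j, 0 + 1j, 2 - 1j], dtype=np.complex64)
dy = np.array([1 + 1j, 2 - 2j, 3 + 0j], dtype=np.complex64)
print(dy * np.conj(np.cosh(x)))  # expected dx for Sinh
print(dy * np.conj(np.sinh(x)))  # expected dx for Cosh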
diff --git a/tensorflow/core/kernels/cwise_op_cosh.cc b/tensorflow/core/kernels/cwise_op_cosh.cc
new file mode 100644
index 0000000000..bca99a4f89
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_cosh.cc
@@ -0,0 +1,37 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(UnaryOp, CPU, "Cosh", functor::cosh, float, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cosh") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<TYPE>("T"), \
+ UnaryOp<SYCLDevice, functor::cosh<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Cosh", functor::cosh, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_cosh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_cosh.cu.cc
new file mode 100644
index 0000000000..267a381d1a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_cosh.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(cosh, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sinh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sinh.cu.cc
new file mode 100644
index 0000000000..f8329e50d6
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_sinh.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(sinh, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_sinh.cc b/tensorflow/core/kernels/cwise_op_sinh.cc
new file mode 100644
index 0000000000..055f0b12e1
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_sinh.cc
@@ -0,0 +1,37 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(UnaryOp, CPU, "Sinh", functor::sinh, float, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sinh") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<TYPE>("T"), \
+ UnaryOp<SYCLDevice, functor::sinh<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+#endif  // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Sinh", functor::sinh, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 423307fd4c..6d80e4bfc1 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -484,6 +484,12 @@ template <typename T>
struct sign : base<T, Eigen::internal::scalar_sign_op<T> > {};
template <typename T>
+struct sinh : base<T, Eigen::internal::scalar_sinh_op<T> > {};
+
+template <typename T>
+struct cosh : base<T, Eigen::internal::scalar_cosh_op<T> > {};
+
+template <typename T>
struct tanh : base<T, Eigen::internal::scalar_tanh_op<T> > {};
template <typename T>
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index a530d286f7..9a58a31757 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -155,6 +155,26 @@ Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Log1p", Log1pGrad);
+Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}},
+ {{"dx"}, "Mul", {"dy", "cosh"}}, // dy * cosh(x)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Sinh", SinhGrad);
+
+Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}},
+ {{"dx"}, "Mul", {"dy", "sinh"}}, // dy * sinh(x)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Cosh", CoshGrad);
+
Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
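GradForUnaryCwise assembles a FunctionDef that computes dx = dy * cosh(x) for Sinh and dx = dy * sinh(x) for Cosh. The identities themselves can be sanity-checked by central finite differences, independent of TensorFlow:

import numpy as np

x = np.linspace(-2.0, 2.0, 9)
eps = 1e-5
fd_sinh = (np.sinh(x + eps) - np.sinh(x - eps)) / (2 * eps)  # approximates d/dx sinh(x)
fd_cosh = (np.cosh(x + eps) - np.cosh(x - eps)) / (2 * eps)  # approximates d/dx cosh(x)
assert np.allclose(fd_sinh, np.cosh(x), atol=1e-6)
assert np.allclose(fd_cosh, np.sinh(x), atol=1e-6)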
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 8670ca307c..8c258f8d81 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -500,6 +500,26 @@ TEST_F(MathGradTest, Log1p) {
test::ExpectClose(ans, dx);
}
+TEST_F(MathGradTest, Sinh) {
+ auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
+ TensorShape({2, 3}));
+ auto g = [](float x) { return std::cosh(x); };
+ auto dx = test::AsTensor<float>(
+ {g(-3.f), g(-2.f), g(-1.f), g(1.f), g(2.f), g(3.f)}, TensorShape({2, 3}));
+ auto ans = SymGrad("Sinh", x);
+ test::ExpectClose(ans, dx);
+}
+
+TEST_F(MathGradTest, Cosh) {
+ auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
+ TensorShape({2, 3}));
+ auto g = [](float x) { return std::sinh(x); };
+ auto dx = test::AsTensor<float>(
+ {g(-3.f), g(-2.f), g(-1.f), g(1.f), g(2.f), g(3.f)}, TensorShape({2, 3}));
+ auto ans = SymGrad("Cosh", x);
+ test::ExpectClose(ans, dx);
+}
+
TEST_F(MathGradTest, Tanh) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
TensorShape({2, 3}));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 41ca570732..8b9dd39536 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -293,6 +293,14 @@ Computes natural logarithm of (1 + x) element-wise.
I.e., \\(y = \log_e (1 + x)\\).
)doc");
+REGISTER_OP("Sinh").UNARY_COMPLEX().Doc(R"doc(
+Computes hyperbolic sine of x element-wise.
+)doc");
+
+REGISTER_OP("Cosh").UNARY_COMPLEX().Doc(R"doc(
+Computes hyperbolic cosine of x element-wise.
+)doc");
+
REGISTER_OP("Tanh").UNARY_COMPLEX().Doc(R"doc(
Computes hyperbolic tangent of `x` element-wise.
)doc");
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index c02bd4b8a0..ff29dfcb04 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4971,6 +4971,30 @@ op {
summary: "Computes cos of x element-wise."
}
op {
+ name: "Cosh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ summary: "Computes hyperbolic cosine of x element-wise."
+}
+op {
name: "CountUpTo"
input_arg {
name: "ref"
@@ -21741,6 +21765,30 @@ op {
summary: "Computes sin of x element-wise."
}
op {
+ name: "Sinh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ summary: "Computes hyperbolic sine of x element-wise."
+}
+op {
name: "Size"
input_arg {
name: "input"
diff --git a/tensorflow/docs_src/api_guides/python/math_ops.md b/tensorflow/docs_src/api_guides/python/math_ops.md
index b3c12a61dd..3d9f203297 100644
--- a/tensorflow/docs_src/api_guides/python/math_ops.md
+++ b/tensorflow/docs_src/api_guides/python/math_ops.md
@@ -59,6 +59,8 @@ mathematical functions to your graph.
* @{tf.acos}
* @{tf.asin}
* @{tf.atan}
+* @{tf.cosh}
+* @{tf.sinh}
* @{tf.lgamma}
* @{tf.digamma}
* @{tf.erf}
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 9f048d3ea0..b3a951f3a2 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -13828,6 +13828,36 @@ func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Computes hyperbolic sine of x element-wise.
+func Sinh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Sinh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes hyperbolic cosine of x element-wise.
+func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Cosh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes hyperbolic tangent of `x` element-wise.
func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py
index 013aa1ba8a..b5b17ff80a 100644
--- a/tensorflow/python/kernel_tests/basic_gpu_test.py
+++ b/tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -113,6 +113,7 @@ class MathBuiltinUnaryTest(test.TestCase):
self._compare(data, np.arctan, math_ops.atan, use_gpu)
self._compare(data, np.ceil, math_ops.ceil, use_gpu)
self._compare(data, np.cos, math_ops.cos, use_gpu)
+ self._compare(data, np.cosh, math_ops.cosh, use_gpu)
self._compare(data, np.exp, math_ops.exp, use_gpu)
self._compare(data, np.floor, math_ops.floor, use_gpu)
self._compare(data, np.log, math_ops.log, use_gpu)
@@ -120,6 +121,7 @@ class MathBuiltinUnaryTest(test.TestCase):
self._compare(data, np.negative, math_ops.negative, use_gpu)
self._compare(data, self._rsqrt, math_ops.rsqrt, use_gpu)
self._compare(data, np.sin, math_ops.sin, use_gpu)
+ self._compare(data, np.sinh, math_ops.sinh, use_gpu)
self._compare(data, np.sqrt, math_ops.sqrt, use_gpu)
self._compare(data, np.square, math_ops.square, use_gpu)
self._compare(data, np.tan, math_ops.tan, use_gpu)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 0846470abc..54810cdc34 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -200,6 +200,8 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.expm1, math_ops.expm1)
self._compareBoth(z, np.log, math_ops.log)
self._compareBoth(z, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
self._compareBoth(x, np.tanh, math_ops.tanh)
self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid)
@@ -245,6 +247,8 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.expm1, math_ops.expm1)
self._compareBoth(x, np.log, math_ops.log)
self._compareBoth(x, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
self._compareBoth(x, np.tanh, math_ops.tanh)
self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
self._compareBoth(x, np.sign, math_ops.sign)
@@ -285,6 +289,8 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.expm1, math_ops.expm1)
self._compareBoth(z, np.log, math_ops.log)
self._compareBoth(z, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
self._compareBoth(x, np.tanh, math_ops.tanh)
self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
self._compareBoth(y, np.sign, math_ops.sign)
@@ -389,6 +395,8 @@ class UnaryOpTest(test.TestCase):
self._compareCpu(x, np.expm1, math_ops.expm1)
self._compareCpu(y, np.log, math_ops.log)
self._compareCpu(y, np.log1p, math_ops.log1p)
+ self._compareCpu(x, np.sinh, math_ops.sinh)
+ self._compareCpu(x, np.cosh, math_ops.cosh)
self._compareCpu(x, np.tanh, math_ops.tanh)
self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
self._compareCpu(x, np.sin, math_ops.sin)
@@ -423,6 +431,8 @@ class UnaryOpTest(test.TestCase):
self._compareCpu(x, np.expm1, math_ops.expm1)
self._compareCpu(y, np.log, math_ops.log)
self._compareCpu(y, np.log1p, math_ops.log1p)
+ self._compareCpu(x, np.sinh, math_ops.sinh)
+ self._compareCpu(x, np.cosh, math_ops.cosh)
self._compareCpu(x, np.tanh, math_ops.tanh)
self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
self._compareCpu(x, np.sin, math_ops.sin)
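Each _compareBoth/_compareCpu call above checks a kernel against its NumPy reference. As a standalone sketch (not part of the patch), the sinh/cosh comparison looks roughly like:

import numpy as np
import tensorflow as tf

x = np.linspace(-3.0, 3.0, 7).astype(np.float32)
with tf.Session() as sess:
    tf_sinh, tf_cosh = sess.run([tf.sinh(x), tf.cosh(x)])
np.testing.assert_allclose(tf_sinh, np.sinh(x), rtol=1e-5)
np.testing.assert_allclose(tf_cosh, np.cosh(x), rtol=1e-5)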
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 024158e709..a0f505e47b 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -369,6 +369,24 @@ def _Log1pGrad(op, grad):
return grad * math_ops.reciprocal(1 + x)
+@ops.RegisterGradient("Sinh")
+def _SinhGrad(op, grad):
+ """Returns grad * cosh(x)."""
+ x = op.inputs[0]
+ with ops.control_dependencies([grad.op]):
+ x = math_ops.conj(x)
+ return grad * math_ops.cosh(x)
+
+
+@ops.RegisterGradient("Cosh")
+def _CoshGrad(op, grad):
+ """Returns grad * sinh(x)."""
+ x = op.inputs[0]
+ with ops.control_dependencies([grad.op]):
+ x = math_ops.conj(x)
+ return grad * math_ops.sinh(x)
+
+
@ops.RegisterGradient("Tanh")
def _TanhGrad(op, grad):
"""Returns grad * (1 - tanh(x) * tanh(x))."""
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 89b7746e71..1ab1dd1fe9 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -45,6 +45,8 @@ See the @{$python/math_ops} guide.
@@expm1
@@log
@@log1p
+@@sinh
+@@cosh
@@ceil
@@floor
@@maximum
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 39c1feaed9..91abff6e13 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -769,6 +769,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "cosh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "count_nonzero"
argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'dtype\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \"<dtype: \'int64\'>\", \'None\', \'None\'], "
}
@@ -1641,6 +1645,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "sinh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "size"
argspec: "args=[\'input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int32\'>\"], "
}