From aa25cc078c9b55e5ca3e0f59df43e169bfee8f3c Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Thu, 16 Aug 2018 19:04:37 +0800 Subject: Add LeakyRelu C++ Op and its gradient implementation. LeakyRelu, defined as 'y = { x (x>=0) or alpha*x (x<0) }', was computed by combined Ops 'max(x, alpha*x)' in current codes. Hence its gradient calculation for back propagation would contain a serial of element-wise Ops. This looks really unnecessary for such a simple op and it could be done within just one Op with less memory accesses. --- tensorflow/cc/gradients/nn_grad.cc | 13 +++++++++++++ tensorflow/cc/gradients/nn_grad_test.cc | 13 +++++++++++++ 2 files changed, 26 insertions(+) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 588e96cb19..0fc23d0bf7 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -143,6 +143,19 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper); +Status LeakyReluGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + attrs.Alpha(alpha); + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index aa72cf7ba2..5ebece7b6e 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -41,6 +41,7 @@ using ops::MaxPoolV2; using ops::Placeholder; using ops::Relu; using ops::Relu6; +using ops::LeakyRelu; using ops::Selu; using ops::Softmax; using ops::Softplus; @@ -160,6 +161,18 @@ TEST_F(NNGradTest, Relu6Grad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = LeakyRelu(scope_, x); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); -- cgit v1.2.3 From 4e72dd865a3fc83baa69f6b7c08720a1b546a464 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Wed, 29 Aug 2018 17:05:43 +0800 Subject: Refine LeakyRelu codes. 1. Add C++ gradient of gradient definition of LeakyReLu and revalant UT. 2. Using forward compatibility layer for python code changes. --- tensorflow/cc/gradients/nn_grad.cc | 18 ++++++- tensorflow/cc/gradients/nn_grad_test.cc | 16 ++++++ tensorflow/python/kernel_tests/relu_op_test.py | 70 ++++++++++++++------------ tensorflow/python/ops/nn_ops.py | 5 +- 4 files changed, 73 insertions(+), 36 deletions(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 0fc23d0bf7..2a32a2ed6f 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -149,13 +149,27 @@ Status LeakyReluGradHelper(const Scope& scope, const Operation& op, float alpha; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); internal::LeakyReluGrad::Attrs attrs; - attrs.Alpha(alpha); - auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), attrs); + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), + attrs.Alpha(alpha)); grad_outputs->push_back(dx); return scope.status(); } REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); +Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 5ebece7b6e..bf0db1f59d 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -173,6 +174,21 @@ TEST_F(NNGradTest, LeakyReluGrad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGradGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor( + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, + {5, 2}); + Tensor features = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + auto y = ops::internal::LeakyReluGrad(scope_, x, features); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index ccb3a231bb..7066f28883 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -283,8 +284,9 @@ class LeakyReluTest(test.TestCase): np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), alpha=0.1, use_gpu=True) - # The gradient test for ReLU is a bit tricky as the derivative is not well - # defined at around zero and we want to avoid that in terms of input values. + # The gradient test for Leaky ReLU is a bit tricky as the derivative is not + # well defined at around zero and we want to avoid that in terms of input + # values. def testGradientFloat32(self): with self.test_session(): x = constant_op.constant( @@ -319,39 +321,41 @@ class LeakyReluTest(test.TestCase): self.assertLess(err, 1e-10) def testGradGradFloat32(self): - with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - name="x") - y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float32, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) - print("leaky_relu (float32) gradient of gradient err = ", err) - self.assertLess(err, 1e-4) + with compat.forward_compatibility_horizon(2018, 10, 2): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float32) gradient of gradient err = ", err) + self.assertLess(err, 1e-4) def testGradGradFloat64(self): - with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - dtype=dtypes.float64, - name="x") - y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float64, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) - print("leaky_relu (float64) gradient of gradient err = ", err) - self.assertLess(err, 1e-10) + with compat.forward_compatibility_horizon(2018, 10, 2): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float64) gradient of gradient err = ", err) + self.assertLess(err, 1e-10) def testGradientScalar(self): with self.test_session() as sess: diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 31b8f3945d..52ea202636 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1601,7 +1601,10 @@ def leaky_relu(features, alpha=0.2, name=None): features = ops.convert_to_tensor(features, name="features") if features.dtype.is_integer: features = math_ops.to_float(features) - return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) + if compat.forward_compatible(2018, 10, 1): + return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) + alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha") + return math_ops.maximum(alpha * features, features, name=name) def _flatten_outer_dims(logits): -- cgit v1.2.3 From 2586eb3bfeeef3af357e438ae5aff92d2bac12a5 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Mon, 3 Sep 2018 11:48:35 +0800 Subject: Code fix against ci_build error results. --- tensorflow/cc/gradients/nn_grad_test.cc | 3 +- tensorflow/core/kernels/relu_op.cc | 8 ++-- tensorflow/core/kernels/relu_op.h | 8 ++-- tensorflow/core/kernels/relu_op_functor.h | 1 - tensorflow/python/kernel_tests/relu_op_test.py | 50 ++++++++++++------------- tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 ++ 6 files changed, 39 insertions(+), 35 deletions(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index bf0db1f59d..d8c2a1a0fc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -180,8 +180,7 @@ TEST_F(NNGradTest, LeakyReluGradGrad) { // Avoid input values where Leaky ReLU gradient is not well defined (around // zero). Tensor x_init_value = test::AsTensor( - {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, - {5, 2}); + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2}); Tensor features = test::AsTensor( {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, {5, 2}); diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index c4f2ef5632..cafa49cbb6 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -106,15 +106,15 @@ namespace functor { \ template <> \ void LeakyRelu::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor features, \ - T alpha, typename TTypes::Tensor activations); \ + const GPUDevice& d, typename TTypes::ConstTensor features, T alpha, \ + typename TTypes::Tensor activations); \ extern template struct LeakyRelu; \ \ template <> \ void LeakyReluGrad::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor gradients, \ - typename TTypes::ConstTensor features, \ - T alpha, typename TTypes::Tensor backprops); \ + typename TTypes::ConstTensor features, T alpha, \ + typename TTypes::Tensor backprops); \ extern template struct LeakyReluGrad; \ \ template <> \ diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index c55190065c..fa79ab03ae 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -143,8 +143,8 @@ class LeakyReluOp : public UnaryElementWiseOp> { void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::LeakyRelu functor; - functor(context->eigen_device(), input.flat(), - alpha_, output->flat()); + functor(context->eigen_device(), input.flat(), alpha_, + output->flat()); } private: @@ -183,7 +183,9 @@ class LeakyReluGradOp template void LeakyReluGradOp::OperateNoTemplate(OpKernelContext* context, - const Tensor& g, const Tensor& a, T alpha, Tensor* output) { + const Tensor& g, + const Tensor& a, T alpha, + Tensor* output) { if (!ReluHelpers::ValidateSameSize(context, g, a)) return; functor::LeakyReluGrad functor; functor(context->eigen_device(), g.flat(), a.flat(), alpha, diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index 7f0951451d..548d5a277d 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -91,7 +91,6 @@ struct Relu6Grad { } }; - // Functor used by LeakyReluOp to do the computations. template struct LeakyRelu { diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 7066f28883..3e24b8a2c4 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -323,37 +323,37 @@ class LeakyReluTest(test.TestCase): def testGradGradFloat32(self): with compat.forward_compatibility_horizon(2018, 10, 2): with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - name="x") - y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float32, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) print("leaky_relu (float32) gradient of gradient err = ", err) self.assertLess(err, 1e-4) def testGradGradFloat64(self): with compat.forward_compatibility_horizon(2018, 10, 2): with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - dtype=dtypes.float64, - name="x") - y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float64, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) print("leaky_relu (float64) gradient of gradient err = ", err) self.assertLess(err, 1e-10) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 4de662fe33..9e8d320f06 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1324,6 +1324,10 @@ tf_module { name: "lbeta" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "leaky_relu" + argspec: "args=[\'features\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.2\', \'None\'], " + } member_method { name: "less" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " -- cgit v1.2.3 From a95281ce1b449d8f92a3799ff9c1dbf661b70bc4 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Wed, 5 Sep 2018 09:02:40 +0800 Subject: Avoid golden API file changing. --- tensorflow/cc/gradients/nn_grad_test.cc | 3 +-- tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt | 1 + tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index d8c2a1a0fc..f5a09e09dc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -42,7 +42,6 @@ using ops::MaxPoolV2; using ops::Placeholder; using ops::Relu; using ops::Relu6; -using ops::LeakyRelu; using ops::Selu; using ops::Softmax; using ops::Softplus; @@ -165,7 +164,7 @@ TEST_F(NNGradTest, Relu6Grad) { TEST_F(NNGradTest, LeakyReluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); - auto y = LeakyRelu(scope_, x); + auto y = ops::internal::LeakyRelu(scope_, x); // Avoid input values where Leaky ReLU gradient is not well defined (around // zero). Tensor x_init_value = test::AsTensor( diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt index 4a61889f54..280148e032 100644 --- a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt @@ -1,4 +1,5 @@ op { graph_op_name: "LeakyRelu" + visibility: HIDDEN summary: "Computes rectified linear: `max(features, features * alpha)`." } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 9e8d320f06..4de662fe33 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1324,10 +1324,6 @@ tf_module { name: "lbeta" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "leaky_relu" - argspec: "args=[\'features\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.2\', \'None\'], " - } member_method { name: "less" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " -- cgit v1.2.3 From f0886f7269de900d226455d4831722f6fc94a71b Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Tue, 25 Sep 2018 09:59:17 +0800 Subject: Fix build dependencies in tensorflow/cc/BUILD. --- tensorflow/cc/BUILD | 1 + tensorflow/python/kernel_tests/relu_op_test.py | 4 ++-- tensorflow/python/ops/nn_ops.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index f56521dac0..e99d15f85d 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -410,6 +410,7 @@ tf_cc_test( srcs = ["gradients/nn_grad_test.cc"], deps = [ ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":grad_testutil", ":gradient_checker", diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 86d9c90e83..d97a1613b9 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -351,7 +351,7 @@ class LeakyReluTest(test.TestCase): self.assertLess(err, 1e-10) def testGradGradFloat32(self): - with compat.forward_compatibility_horizon(2018, 10, 2): + with compat.forward_compatibility_horizon(2018, 11, 2): with self.test_session(): x = constant_op.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], @@ -369,7 +369,7 @@ class LeakyReluTest(test.TestCase): self.assertLess(err, 1e-4) def testGradGradFloat64(self): - with compat.forward_compatibility_horizon(2018, 10, 2): + with compat.forward_compatibility_horizon(2018, 11, 2): with self.test_session(): x = constant_op.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index d646245ce3..2861f40586 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1601,7 +1601,7 @@ def leaky_relu(features, alpha=0.2, name=None): features = ops.convert_to_tensor(features, name="features") if features.dtype.is_integer: features = math_ops.to_float(features) - if compat.forward_compatible(2018, 10, 1): + if compat.forward_compatible(2018, 11, 1): return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha") return math_ops.maximum(alpha * features, features, name=name) -- cgit v1.2.3