diff options
author | Cao Zongyan <zongyan.cao@alibaba-inc.com> | 2018-08-29 17:05:43 +0800 |
---|---|---|
committer | Cao Zongyan <zongyan.cao@alibaba-inc.com> | 2018-08-29 17:05:43 +0800 |
commit | 4e72dd865a3fc83baa69f6b7c08720a1b546a464 (patch) | |
tree | 9ac73a11393bf248e4f85c64feafda9316081781 /tensorflow/cc | |
parent | cb5c61a3e11a37fb39a246aaf8ed6d02dd9ae9ab (diff) |
Refine LeakyRelu codes.
1. Add C++ gradient of gradient definition of LeakyReLu and revalant UT.
2. Using forward compatibility layer for python code changes.
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/gradients/nn_grad.cc | 18 | ||||
-rw-r--r-- | tensorflow/cc/gradients/nn_grad_test.cc | 16 |
2 files changed, 32 insertions, 2 deletions
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 0fc23d0bf7..2a32a2ed6f 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -149,13 +149,27 @@ Status LeakyReluGradHelper(const Scope& scope, const Operation& op, float alpha; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); internal::LeakyReluGrad::Attrs attrs; - attrs.Alpha(alpha); - auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), attrs); + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), + attrs.Alpha(alpha)); grad_outputs->push_back(dx); return scope.status(); } REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); +Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 5ebece7b6e..bf0db1f59d 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -173,6 +174,21 @@ TEST_F(NNGradTest, LeakyReluGrad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGradGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor<float>( + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, + {5, 2}); + Tensor features = test::AsTensor<float>( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + auto y = ops::internal::LeakyReluGrad(scope_, x, features); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); |