diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-22 12:42:59 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-22 12:46:28 -0800 |
commit | e4532d20973c4c00854492362665317551661c18 (patch) | |
tree | 398527e29bd30d39237adb4785be5069fdb646fa /tensorflow/cc/gradients | |
parent | 673641c2d6a27fa97ee05453d671853731a4c602 (diff) |
Merge changes from github.
PiperOrigin-RevId: 179953488
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r-- | tensorflow/cc/gradients/math_grad.cc | 18 | ||||
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 8 |
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index d7446b9560..ebc0c77828 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -728,6 +728,24 @@ Status LgammaGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Lgamma", LgammaGrad); +Status SelectGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + auto comparator = op.input(0); + auto x = op.input(1); + auto zeros = ZerosLike(scope, x); + auto grad = grad_inputs[0]; + + auto gx_1 = Where3(scope, comparator, grad, zeros); + auto gx_2 = Where3(scope, comparator, zeros, grad); + + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(gx_1); + grad_outputs->push_back(gx_2); + return scope.status(); +} +REGISTER_GRADIENT_OP("Select", SelectGrad); + Status MinOrMaxGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 6313f41da5..29def3c3ea 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -865,5 +865,13 @@ TEST_F(NaryGradTest, Minimum) { RunTest(x, x_init_value, y, shape); } +TEST_F(NaryGradTest, Select) { + TensorShape shape({3, 4}); + auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Where3(scope_, Greater(scope_, x1, x2), x1, x2); + RunTest({x1, x2}, {shape, shape}, {y}, {shape}); +} + } // namespace } // namespace tensorflow |