diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2016-12-19 16:34:57 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-19 16:46:09 -0800 |
commit | 91b9c31f41833c2040cdd247378deb4b66707fc6 (patch) | |
tree | d1218e1ea3595d78f95edd539331df5f6cfd36f3 | |
parent | 4fbbb3040ba9e4b2cf7d4b40088e9f2cac28bdbf (diff) |
C++ grad: A couple Reverse* ops gradients.
Change: 142500394
-rw-r--r-- | tensorflow/cc/gradients/array_grad.cc | 26 | ||||
-rw-r--r-- | tensorflow/cc/gradients/array_grad_test.cc | 17 |
2 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 1ae60118ec..1265dd9f09 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -187,6 +187,32 @@ Status TransposeGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Transpose", TransposeGrad); +Status ReverseSequenceGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + auto seq_lengths = op.input(1); + int batch_dim; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim)); + int seq_dim; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim)); + grad_outputs->push_back( + ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, + ReverseSequence::BatchDim(batch_dim))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad); + +Status ReverseGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + auto reverse_dims = op.input(1); + grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims)); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Reverse", ReverseGrad); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 48f32c8057..2d6fabd8c8 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -195,5 +195,22 @@ TEST_F(ArrayGradTest, TransposeGrad) { RunTest(x, x_shape, y, y_shape); } +TEST_F(ArrayGradTest, ReverseSequenceGrad) { + TensorShape shape({5, 2, 5}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto seq_lengths = Const(scope_, {1, 2, 3, 4, 5}); + // batch_dim defaults to 0. + auto y = ReverseSequence(scope_, x, seq_lengths, /* seq_dim */ 2); + RunTest(x, shape, y, shape); +} + +TEST_F(ArrayGradTest, ReverseGrad) { + TensorShape shape({5, 2, 5}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto reverse_dims = Const(scope_, {true, false, true}); + auto y = Reverse(scope_, x, reverse_dims); + RunTest(x, shape, y, shape); +} + } // namespace } // namespace tensorflow |