aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/cc/gradients/array_grad.cc26
-rw-r--r--tensorflow/cc/gradients/array_grad_test.cc17
2 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index 1ae60118ec..1265dd9f09 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -187,6 +187,32 @@ Status TransposeGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);
+Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto seq_lengths = op.input(1);
+ int batch_dim;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim));
+ int seq_dim;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim));
+ grad_outputs->push_back(
+ ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
+ ReverseSequence::BatchDim(batch_dim)));
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);
+
+Status ReverseGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto reverse_dims = op.input(1);
+ grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Reverse", ReverseGrad);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc
index 48f32c8057..2d6fabd8c8 100644
--- a/tensorflow/cc/gradients/array_grad_test.cc
+++ b/tensorflow/cc/gradients/array_grad_test.cc
@@ -195,5 +195,22 @@ TEST_F(ArrayGradTest, TransposeGrad) {
RunTest(x, x_shape, y, y_shape);
}
+TEST_F(ArrayGradTest, ReverseSequenceGrad) {
+ TensorShape shape({5, 2, 5});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ auto seq_lengths = Const(scope_, {1, 2, 3, 4, 5});
+ // batch_dim defaults to 0.
+ auto y = ReverseSequence(scope_, x, seq_lengths, /* seq_dim */ 2);
+ RunTest(x, shape, y, shape);
+}
+
+TEST_F(ArrayGradTest, ReverseGrad) {
+ TensorShape shape({5, 2, 5});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ auto reverse_dims = Const(scope_, {true, false, true});
+ auto y = Reverse(scope_, x, reverse_dims);
+ RunTest(x, shape, y, shape);
+}
+
} // namespace
} // namespace tensorflow