From b462815ccc279c79a5592a0a2d4492718da02c5d Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Wed, 18 Jul 2018 08:31:46 -0700 Subject: Add gradient for tensorflow::ops::Fill See https://github.com/tensorflow/tensorflow/issues/20926 --- tensorflow/cc/gradients/array_grad.cc | 18 ++++++++++++++++++ tensorflow/cc/gradients/array_grad_test.cc | 8 ++++++++ 2 files changed, 26 insertions(+) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index b353accddc..e9173227aa 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Split", SplitGrad); +Status FillGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = fill(fill_shape, x) + // No gradient returned for the fill_shape argument. + grad_outputs->push_back(NoGradient()); + // The gradient for x (which must be a scalar) is just the sum of + // all the gradients from the shape it fills. + // We use ReduceSum to implement this, which needs an argument providing + // the indices of all the dimensions of the incoming gradient. + // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))]) + auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]), + Const(scope, 1)); + grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims)); + return scope.status(); +} +REGISTER_GRADIENT_OP("Fill", FillGrad); + Status DiagGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index d09275b648..f41de3dc20 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) { RunTest({x}, {x_shape}, y.output, {y_shape, y_shape}); } +TEST_F(ArrayGradTest, FillGrad) { + TensorShape x_shape({}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + TensorShape y_shape({2, 5, 3}); + auto y = Fill(scope_, {2, 5, 3}, x); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(ArrayGradTest, DiagGrad) { TensorShape x_shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); -- cgit v1.2.3