diff options
Diffstat (limited to 'tensorflow/cc/gradients/array_grad.cc')
-rw-r--r-- | tensorflow/cc/gradients/array_grad.cc | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index ff348fadb2..b353accddc 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -421,6 +421,58 @@ Status StridedSliceGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper); +Status SliceGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // Propagate the incoming gradient along all the selected values, + // and zero everywhere else. Use the Pad operator for this. + // + // First create an Nx2 padding where N is the number of input + // dimensions. The first column is the number of prepended zeros + // for each dimension, and the second column is the number of + // appended zeros. + // + // The first column is just the begin vector. + // The second column is the shape of the input element-wise + // subtracted by begin+size + + // Running example: + // input.shape = [3, 5, 3] + // begin = [1, 2, 1], size = [1, 3, 2] + Input input = op.input(0); + Input begin = op.input(1); + // input_rank = 3 + auto input_rank = Rank(scope, input); + // slice_size = [1, 3, 2] + auto slice_size = Shape(scope, op.output(0)); + // padding_shape = [3, 1] + auto padding_shape = Stack(scope, {input_rank, 1}); + // before_padding = [[1] + // [2] + // [1]] + Input before_padding = Reshape(scope, begin, padding_shape); + // after_padding_sizes = shape(input) - slice_size - begin + // = [3, 5, 3] - [1, 3, 2] - [1, 2, 1] + // = [1, 0, 0] + auto after_padding_sizes = + Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin); + // after_padding = [[1] + // [0] + // [0]] + Input after_padding = Reshape(scope, after_padding_sizes, padding_shape); + // paddings = [[1 1] + // [2 0] + // [1 0]] + auto paddings = + Concat(scope, {before_padding, after_padding}, Const(scope, 1)); + grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings)); + // Nothing propagated for "begin" and "size" inputs + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Slice", SliceGrad); + } // anonymous namespace } // namespace ops } // namespace tensorflow |