Diffstat (limited to 'tensorflow/cc/gradients/array_grad.cc')
-rw-r--r--  tensorflow/cc/gradients/array_grad.cc  52
1 file changed, 52 insertions(+), 0 deletions(-)
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index ff348fadb2..b353accddc 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -421,6 +421,58 @@ Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
+Status SliceGrad(const Scope& scope, const Operation& op,
+                 const std::vector<Output>& grad_inputs,
+                 std::vector<Output>* grad_outputs) {
+  // Propagate the incoming gradient along all the selected values,
+  // and zero everywhere else. Use the Pad operator for this.
+  //
+  // First create an Nx2 padding where N is the number of input
+  // dimensions. The first column is the number of prepended zeros
+  // for each dimension, and the second column is the number of
+  // appended zeros.
+  //
+  // The first column is just the begin vector.
+  // The second column is the input shape minus (begin + size),
+  // element-wise.
+
+  // Running example:
+  // input.shape = [3, 5, 3]
+  // begin = [1, 2, 1], size = [1, 3, 2]
+  Input input = op.input(0);
+  Input begin = op.input(1);
+  // input_rank = 3
+  auto input_rank = Rank(scope, input);
+  // slice_size = [1, 3, 2]
+  auto slice_size = Shape(scope, op.output(0));
+  // padding_shape = [3, 1]
+  auto padding_shape = Stack(scope, {input_rank, 1});
+  // before_padding = [[1]
+  //                   [2]
+  //                   [1]]
+  Input before_padding = Reshape(scope, begin, padding_shape);
+  // after_padding_sizes = shape(input) - slice_size - begin
+  //                     = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
+  //                     = [1, 0, 0]
+  auto after_padding_sizes =
+      Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
+  // after_padding = [[1]
+  //                  [0]
+  //                  [0]]
+  Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
+  // paddings = [[1 1]
+  //             [2 0]
+  //             [1 0]]
+  auto paddings =
+      Concat(scope, {before_padding, after_padding}, Const(scope, 1));
+  grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
+  // Nothing propagated for "begin" and "size" inputs
+  grad_outputs->push_back(NoGradient());
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Slice", SliceGrad);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
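
For illustration only, and not part of the diff above: a minimal, self-contained sketch of how the newly registered "Slice" gradient could be checked against a numeric finite-difference estimate, assuming TensorFlow's C++ gradient checker (ComputeGradientError from tensorflow/cc/framework/gradient_checker.h, whose exact signature can vary across versions) and the generated C++ ops. The shapes mirror the running example in the code comments; the standalone main() scaffolding is an illustrative assumption, not code from this commit.

// Illustrative sketch only; assumes the TensorFlow C++ client API
// (Scope, ops::Placeholder, ops::Slice) and the gradient checker in
// tensorflow/cc/framework/gradient_checker.h.
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"

using namespace tensorflow;
using namespace tensorflow::ops;

int main() {
  Scope scope = Scope::NewRootScope();
  // Running example from the comments above:
  // input.shape = [3, 5, 3], begin = [1, 2, 1], size = [1, 3, 2].
  TensorShape x_shape({3, 5, 3});
  auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Slice(scope, x, {1, 2, 1}, {1, 3, 2});
  // Compare the symbolic gradient (the Pad-based SliceGrad registered above)
  // against a numeric finite-difference gradient.
  float max_error;
  Status s = ComputeGradientError<float, float, float>(
      scope, {x}, {x_shape}, {y}, {TensorShape({1, 3, 2})}, &max_error);
  TF_CHECK_OK(s);
  // max_error should be close to 0 if the gradient is padded correctly.
  return 0;
}

In the TensorFlow tree itself, gradient registrations like this one are normally exercised through the test fixtures in tensorflow/cc/gradients/array_grad_test.cc; the standalone form above is only meant to make the check self-contained.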