aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2018-05-01 14:28:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 14:33:20 -0700
commit325d0ef21a48bea1cc618a2bd24a9776de417ce5 (patch)
treed41cf6304071e95bebd5747ca87dfca571e98634 /tensorflow/cc
parent46bf1e8934b3bc8edeff3f218a50b0ee5806e96b (diff)
Merge changes from github.
PiperOrigin-RevId: 194997009
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/gradients/array_grad.cc36
-rw-r--r--tensorflow/cc/gradients/array_grad_test.cc24
2 files changed, 60 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index 6545e4ee3e..ff348fadb2 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -385,6 +385,42 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);
+Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ Input x = Shape(scope, op.input(0));
+ Input begin = op.input(1);
+ Input end = op.input(2);
+ Input strides = op.input(3);
+ int64 begin_mask;
+ int64 end_mask;
+ int64 ellipsis_mask;
+ int64 new_axis_mask;
+ int64 shrink_axis_mask;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
+ grad_outputs->push_back(
+ StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
+ StridedSliceGrad::BeginMask(begin_mask)
+ .EndMask(end_mask)
+ .EllipsisMask(ellipsis_mask)
+ .NewAxisMask(new_axis_mask)
+ .ShrinkAxisMask(shrink_axis_mask)));
+ // No gradients returned for begin, end and strides
+ grad_outputs->push_back(NoGradient());
+ grad_outputs->push_back(NoGradient());
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc
index 4a215fcc92..de3bd0fc9e 100644
--- a/tensorflow/cc/gradients/array_grad_test.cc
+++ b/tensorflow/cc/gradients/array_grad_test.cc
@@ -354,5 +354,29 @@ TEST_F(ArrayGradTest, MirrorPadGradGrad_Symmetric) {
RunTest(x, x_shape, y, y_shape);
}
+TEST_F(ArrayGradTest, StridedSliceGrad) {
+ TensorShape x_shape({6, 4, 4});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+
+ // y = x[2:6:2, 1:3, 1:3]
+ auto y = StridedSlice(scope_, x, {2, 1, 1}, {6, 3, 3}, {2, 1, 1});
+ // y.shape = [2, 2, 2];
+ RunTest(x, x_shape, y, {2, 2, 2});
+
+ // y = x[2:6:2, 1:3, 1:3]
+ // begin_mask = 1<<1 (ignore begin_index = 1)
+ // end_mask = 1<<2 (ignore end_index = 2)
+ y = StridedSlice(scope_, x, {2, 1, 1}, {6, 3, 3}, {2, 1, 1},
+ StridedSlice::BeginMask(1 << 1).EndMask(1 << 2));
+ // y.shape = [2, 3, 3];
+ RunTest(x, x_shape, y, {2, 3, 3});
+
+ // y = [tf.newaxis, 2:6:2, 1:3, 1:3]
+ y = StridedSlice(scope_, x, {0, 2, 1, 1}, {0, 6, 3, 3}, {1, 2, 1, 1},
+ StridedSlice::NewAxisMask(1 << 0));
+ // y.shape = [1, 2, 2, 2];
+ RunTest(x, x_shape, y, {1, 2, 2, 2});
+}
+
} // namespace
} // namespace tensorflow