aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients/array_grad.cc
diff options
context:
space:
mode:
authorGravatar KB Sriram <kbsriram@gmail.com>2018-03-07 08:11:03 -0800
committerGravatar KB Sriram <kbsriram@gmail.com>2018-03-08 07:04:28 -0800
commitcee41f9d10b81ce3b49f566ddd448a7f3f2872c3 (patch)
treee6e1b096a366e43ebef088e5339883049cba002f /tensorflow/cc/gradients/array_grad.cc
parentf73d7c90ed05bcf9f36f6a3be0c29efa5fef0f6e (diff)
C++ gradient for StridedSlice
Diffstat (limited to 'tensorflow/cc/gradients/array_grad.cc')
-rw-r--r--tensorflow/cc/gradients/array_grad.cc36
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index 6545e4ee3e..ff348fadb2 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -385,6 +385,42 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);
+Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ Input x = Shape(scope, op.input(0));
+ Input begin = op.input(1);
+ Input end = op.input(2);
+ Input strides = op.input(3);
+ int64 begin_mask;
+ int64 end_mask;
+ int64 ellipsis_mask;
+ int64 new_axis_mask;
+ int64 shrink_axis_mask;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
+ grad_outputs->push_back(
+ StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
+ StridedSliceGrad::BeginMask(begin_mask)
+ .EndMask(end_mask)
+ .EllipsisMask(ellipsis_mask)
+ .NewAxisMask(new_axis_mask)
+ .ShrinkAxisMask(shrink_axis_mask)));
+ // No gradients returned for begin, end and strides
+ grad_outputs->push_back(NoGradient());
+ grad_outputs->push_back(NoGradient());
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow