aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2016-07-19 11:35:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-19 15:05:45 -0700
commitf2f7dba849baeecd473423eca560ef0f28300327 (patch)
treec052e59d3222adb76e305029dd22841752582d7a
parentfb91a77268c69020aa304dfaeb6cc701af94242e (diff)
Implement c++ version of StridedSlice gradient.
Change: 127861326
-rw-r--r--tensorflow/core/ops/array_grad.cc53
-rw-r--r--tensorflow/core/ops/array_grad_test.cc141
2 files changed, 186 insertions, 8 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 2589e76b90..d3ffb907bc 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -337,18 +337,18 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
}
*g = FDH::Define(
// Arg defs
- {"x: T", "b: int32", "s: int32", "dy: T"},
+ {"x: T", "begin: int32", "size: int32", "dy: T"},
// Ret val defs
- {"dx: T", "db: int32", "ds: int32"},
+ {"dx: T", "begin_grad: int32", "size_grad: int32"},
// Attr defs
{"T: type"},
// Nodes
- {// paddings = concat(1, [b, shape(x) - b - s])
+ {// paddings = concat(1, [begin, shape(x) - begin - size])
FDH::Const("one", 1),
- {{"b1"}, "ExpandDims", {"b", "one"}, {{"T", DT_INT32}}},
+ {{"b1"}, "ExpandDims", {"begin", "one"}, {{"T", DT_INT32}}},
{{"xs"}, "Shape", {"x"}, {{"T", "$T"}}},
- {{"xs_b"}, "Sub", {"xs", "b"}, {{"T", DT_INT32}}},
- {{"xs_b_s"}, "Sub", {"xs_b", "s"}, {{"T", DT_INT32}}},
+ {{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}},
+ {{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}},
{{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}},
{{"b_and_a"},
"_ListToArray",
@@ -362,11 +362,48 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
{{"N", 2}, {"T", DT_INT32}}},
// dx = Pad(dy, paddings)
{{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}},
- {{"db"}, "ZerosLike", {"b"}, {{"T", DT_INT32}}},
- {{"ds"}, "ZerosLike", {"s"}, {{"T", DT_INT32}}}});
+ {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
+ {{"size_grad"}, "ZerosLike", {"size"}, {{"T", DT_INT32}}}});
VLOG(1) << "SliceGrad " << DebugString(*g);
return Status::OK();
}
REGISTER_OP_GRADIENT("Slice", SliceGrad);
+Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) {
+  DataType itype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
+  if (itype != DT_INT32) {
+    return errors::Unimplemented(
+        "StridedSliceGrad for int64 indices is not supported.");
+  }
+
+  *g = FDH::Define(
+      // Arg defs
+      {"x: T", "begin: int32", "end: int32", "stride: int32", "dy: T"},
+      // Ret val defs
+      {"dx: T", "begin_grad: int32", "end_grad: int32", "stride_grad: int32"},
+      // Attr defs
+      {"T: type", "Index: {int32, int64}", "begin_mask: int", "end_mask: int",
+       "ellipsis_mask: int", "new_axis_mask: int", "shrink_axis_mask: int"},
+      {// Nodes
+       {{{"xs"}, "Shape", {"x"}, {{"T", "$T"}}},  // xs = shape(x)
+        {{"dx"},  // dx = StridedSliceGrad(shape(x), begin, end, stride, dy)
+         "StridedSliceGrad",
+         {"xs", "begin", "end", "stride", "dy"},
+         {{"T", "$T"},
+          {"Index", "$Index"},
+          {"begin_mask", "$begin_mask"},
+          {"end_mask", "$end_mask"},
+          {"ellipsis_mask", "$ellipsis_mask"},
+          {"new_axis_mask", "$new_axis_mask"},
+          {"shrink_axis_mask", "$shrink_axis_mask"}}},
+        // begin/end/stride are integer index inputs: their gradients are zero.
+        {{"end_grad"}, "ZerosLike", {"end"}, {{"T", DT_INT32}}},
+        {{"stride_grad"}, "ZerosLike", {"stride"}, {{"T", DT_INT32}}}}});
+
+  VLOG(1) << "StridedSliceGrad " << DebugString(*g);
+  return Status::OK();
+}
+REGISTER_OP_GRADIENT("StridedSlice", StridedSliceGrad);
+
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/array_grad_test.cc b/tensorflow/core/ops/array_grad_test.cc
index 9beaa8abe1..a051e456e5 100644
--- a/tensorflow/core/ops/array_grad_test.cc
+++ b/tensorflow/core/ops/array_grad_test.cc
@@ -435,4 +435,145 @@ TEST_F(ArrayGradTest, SliceGrad) {
test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
}
+std::vector<Tensor> StridedSliceGrad(const Tensor& x, const Tensor& begin,
+                                     const Tensor& end, const Tensor& strides,
+                                     const Tensor& dy, int32 begin_mask,
+                                     int32 end_mask, int32 ellipsis_mask,
+                                     int32 new_axis_mask,
+                                     int32 shrink_axis_mask) {  // Builds and runs SymbolicGradient of StridedSlice.
+  auto T = DT_FLOAT;
+  auto gdef = test::function::GDef(
+      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
+       f::NDef("begin", "Placeholder", {}, {{"dtype", DT_INT32}}),
+       f::NDef("end", "Placeholder", {}, {{"dtype", DT_INT32}}),
+       f::NDef("strides", "Placeholder", {}, {{"dtype", DT_INT32}}),
+       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
+       f::NDef(
+           "dx", "SymbolicGradient", {"x", "begin", "end", "strides", "dy"},
+           {{"f", FDH::FunctionRef("StridedSlice",
+                                   {
+                                       {"T", T},
+                                       {"Index", DT_INT32},
+                                       {"begin_mask", begin_mask},
+                                       {"end_mask", end_mask},
+                                       {"new_axis_mask", new_axis_mask},
+                                       {"shrink_axis_mask", shrink_axis_mask},
+                                       {"ellipsis_mask", ellipsis_mask},
+                                   })},
+            {"Tin", DataTypeSlice{T, DT_INT32, DT_INT32, DT_INT32, T}},
+            {"Tout", DataTypeSlice{T, DT_INT32, DT_INT32, DT_INT32}}})});
+  VLOG(1) << DebugStringWhole(gdef);
+  auto sess = NewSession();  // raw Session*; Close()d and deleted below
+  TF_CHECK_OK(sess->Create(gdef));
+  std::vector<Tensor> out;
+  TF_CHECK_OK(sess->Run({{"x:0", x},
+                         {"begin:0", begin},
+                         {"end:0", end},
+                         {"strides:0", strides},
+                         {"dy:0", dy}},
+                        {"dx:0", "dx:1", "dx:2", "dx:3"}, {}, &out));
+  CHECK_EQ(out.size(), 4);  // dx, begin_grad, end_grad, stride_grad
+  TF_CHECK_OK(sess->Close());
+  delete sess;
+  return out;
+}
+
+TEST_F(ArrayGradTest, StridedSliceGrad) {
+  Tensor x(DT_FLOAT, {2, 3, 4});
+  x.flat<float>().setZero();
+
+  {  // equivalent of tf.gradients(foo[1:2, 1:3, 1:4:2]) -- non-unit stride
+    auto start = test::AsTensor<int32>({1, 1, 1});
+    auto stop = test::AsTensor<int32>({2, 3, 4});
+    auto strides = test::AsTensor<int32>({1, 1, 2});
+    Tensor dy(DT_FLOAT, {1, 2, 2});
+    test::FillIota<float>(&dy, 1);
+    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 0,
+        ellipsis_mask = 0;
+    auto dx =
+        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
+                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
+    test::ExpectClose(dx[0],
+                      test::AsTensor<float>(
+                          {
+                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+                              0., 0., 0., 0., 0., 1., 0., 2., 0., 3., 0., 4.,
+                          },
+                          {2, 3, 4}));
+    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
+    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
+  }
+
+  // test equivalent of python tf.gradients(foo[1:2, 1:3, 1:3])
+  {
+    auto start = test::AsTensor<int32>({1, 1, 1});
+    auto stop = test::AsTensor<int32>({2, 3, 3});
+    auto strides = test::AsTensor<int32>({1, 1, 1});
+    Tensor dy(DT_FLOAT, {1, 2, 2});
+    test::FillIota<float>(&dy, 1);
+    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 0,
+        ellipsis_mask = 0;
+    auto dx =
+        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
+                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
+    test::ExpectClose(dx[0],
+                      test::AsTensor<float>(
+                          {
+                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+                              0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0.,
+                          },
+                          {2, 3, 4}));
+    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
+    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
+  }
+
+  // test equivalent of python tf.gradients(foo[1, 1:, :-2, None])
+  {
+    int dontcare = 66;
+    auto start = test::AsTensor<int32>({1, 1, dontcare, dontcare});
+    auto stop = test::AsTensor<int32>({2, dontcare, -2, dontcare});
+    auto strides = test::AsTensor<int32>({1, 1, 1, dontcare});
+    Tensor dy(DT_FLOAT, {2, 2, 1});
+    test::FillIota<float>(&dy, 1);
+    int begin_mask = 4, end_mask = 2, new_axis_mask = 8, shrink_axis_mask = 1,
+        ellipsis_mask = 0;
+    auto dx =
+        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
+                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
+    test::ExpectClose(dx[0],
+                      test::AsTensor<float>(
+                          {
+                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+                              0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0., 0.,
+                          },
+                          {2, 3, 4}));
+    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0, 0}));
+    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0, 0}));
+  }
+
+  // test equivalent of tf.gradients(foo[1, ...]) i.e. foo[1, 0:3, 0:4]
+  {
+    int dontcare = 66;
+    auto start = test::AsTensor<int32>({1, dontcare});
+    auto stop = test::AsTensor<int32>({2, dontcare});
+    auto strides = test::AsTensor<int32>({1, 1});
+    Tensor dy(DT_FLOAT, {3, 4});
+    test::FillIota<float>(&dy, 1);
+    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 1,
+        ellipsis_mask = 2;
+    auto dx =
+        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
+                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
+    test::ExpectClose(dx[0],
+                      test::AsTensor<float>(
+                          {
+                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+                              1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.,
+                          },
+                          {2, 3, 4}));
+    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0}));
+    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0}));
+  }
+}
+
} // namespace tensorflow