diff options
author | 2016-07-19 11:35:27 -0800 | |
---|---|---|
committer | 2016-07-19 15:05:45 -0700 | |
commit | f2f7dba849baeecd473423eca560ef0f28300327 (patch) | |
tree | c052e59d3222adb76e305029dd22841752582d7a | |
parent | fb91a77268c69020aa304dfaeb6cc701af94242e (diff) |
Implement c++ version of StridedSlice gradient.
Change: 127861326
-rw-r--r-- | tensorflow/core/ops/array_grad.cc | 53 | ||||
-rw-r--r-- | tensorflow/core/ops/array_grad_test.cc | 141 |
2 files changed, 186 insertions, 8 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index 2589e76b90..d3ffb907bc 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -337,18 +337,18 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) { } *g = FDH::Define( // Arg defs - {"x: T", "b: int32", "s: int32", "dy: T"}, + {"x: T", "begin: int32", "size: int32", "dy: T"}, // Ret val defs - {"dx: T", "db: int32", "ds: int32"}, + {"dx: T", "begin_grad: int32", "size_grad: int32"}, // Attr defs {"T: type"}, // Nodes - {// paddings = concat(1, [b, shape(x) - b - s]) + {// paddings = concat(1, [begin, shape(x) - begin - size]) FDH::Const("one", 1), - {{"b1"}, "ExpandDims", {"b", "one"}, {{"T", DT_INT32}}}, + {{"b1"}, "ExpandDims", {"begin", "one"}, {{"T", DT_INT32}}}, {{"xs"}, "Shape", {"x"}, {{"T", "$T"}}}, - {{"xs_b"}, "Sub", {"xs", "b"}, {{"T", DT_INT32}}}, - {{"xs_b_s"}, "Sub", {"xs_b", "s"}, {{"T", DT_INT32}}}, + {{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}}, + {{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}}, {{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}}, {{"b_and_a"}, "_ListToArray", @@ -362,11 +362,48 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) { {{"N", 2}, {"T", DT_INT32}}}, // dx = Pad(dy, paddings) {{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}}, - {{"db"}, "ZerosLike", {"b"}, {{"T", DT_INT32}}}, - {{"ds"}, "ZerosLike", {"s"}, {{"T", DT_INT32}}}}); + {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}}, + {{"size_grad"}, "ZerosLike", {"size"}, {{"T", DT_INT32}}}}); VLOG(1) << "SliceGrad " << DebugString(*g); return Status::OK(); } REGISTER_OP_GRADIENT("Slice", SliceGrad); +Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) { + DataType itype; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype)); + if (itype != DT_INT32) { + return errors::Unimplemented( + "SliceGrad for int64 index are not supported."); + } + + *g = FDH::Define( + // Arg defs + {"x: T", "begin: int32", "end: int32", "stride: int32", "dy: T"}, + // Ret val defs + {"dx: T", "begin_grad: int32", "end_grad: int32", "stride_grad: int32"}, + // Attr defs + {"T: type", "Index: {int32, int64}", "begin_mask: int", "end_mask: int", + "ellipsis_mask: int", "new_axis_mask: int", "shrink_axis_mask: int"}, + {// Nodes + {{{"xs"}, "Shape", {"x"}, {{"T", "$T"}}}, + {{"dx"}, + "StridedSliceGrad", + {"xs", "begin", "end", "stride", "dy"}, + {{"T", "$T"}, + {"Index", "$Index"}, + {"begin_mask", "$begin_mask"}, + {"end_mask", "$end_mask"}, + {"ellipsis_mask", "$ellipsis_mask"}, + {"new_axis_mask", "$new_axis_mask"}, + {"shrink_axis_mask", "$shrink_axis_mask"}}}, + {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}}, + {{"end_grad"}, "ZerosLike", {"end"}, {{"T", DT_INT32}}}, + {{"stride_grad"}, "ZerosLike", {"stride"}, {{"T", DT_INT32}}}}}); + + VLOG(1) << "StridedSliceGrad " << DebugString(*g); + return Status::OK(); +} +REGISTER_OP_GRADIENT("StridedSlice", StridedSliceGrad); + } // end namespace tensorflow diff --git a/tensorflow/core/ops/array_grad_test.cc b/tensorflow/core/ops/array_grad_test.cc index 9beaa8abe1..a051e456e5 100644 --- a/tensorflow/core/ops/array_grad_test.cc +++ b/tensorflow/core/ops/array_grad_test.cc @@ -435,4 +435,145 @@ TEST_F(ArrayGradTest, SliceGrad) { test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0})); } +std::vector<Tensor> StridedSliceGrad(const Tensor& x, const Tensor& begin, + const Tensor& end, const Tensor& strides, + const Tensor& dy, int32 begin_mask, + int32 end_mask, int32 ellipsis_mask, + int32 new_axis_mask, + int32 shrink_axis_mask) { + auto T = DT_FLOAT; + auto gdef = test::function::GDef( + {f::NDef("x", "Placeholder", {}, {{"dtype", T}}), + f::NDef("begin", "Placeholder", {}, {{"dtype", DT_INT32}}), + f::NDef("end", "Placeholder", {}, {{"dtype", DT_INT32}}), + f::NDef("strides", "Placeholder", {}, {{"dtype", DT_INT32}}), + f::NDef("dy", "Placeholder", {}, {{"dtype", T}}), + f::NDef( + "dx", "SymbolicGradient", {"x", "begin", "end", "strides", "dy"}, + {{"f", FDH::FunctionRef("StridedSlice", + { + {"T", T}, + {"Index", DT_INT32}, + {"begin_mask", begin_mask}, + {"end_mask", end_mask}, + {"new_axis_mask", new_axis_mask}, + {"shrink_axis_mask", shrink_axis_mask}, + {"ellipsis_mask", ellipsis_mask}, + })}, + {"Tin", DataTypeSlice{T, DT_INT32, DT_INT32, DT_INT32, T}}, + {"Tout", DataTypeSlice{T, DT_INT32, DT_INT32, DT_INT32}}})}); + VLOG(1) << DebugStringWhole(gdef); + auto sess = NewSession(); + TF_CHECK_OK(sess->Create(gdef)); + std::vector<Tensor> out; + TF_CHECK_OK(sess->Run({{"x:0", x}, + {"begin:0", begin}, + {"end:0", end}, + {"strides:0", strides}, + {"dy:0", dy}}, + {"dx:0", "dx:1", "dx:2", "dx:3"}, {}, &out)); + CHECK_EQ(out.size(), 4); + TF_CHECK_OK(sess->Close()); + delete sess; + return out; +} + +TEST_F(ArrayGradTest, StridedSliceGrad) { + Tensor x(DT_FLOAT, {2, 3, 4}); + x.flat<float>().setZero(); + + { + auto start = test::AsTensor<int32>({1, 1, 1}); + auto stop = test::AsTensor<int32>({2, 3, 3}); + auto strides = test::AsTensor<int32>({1, 1, 1}); + Tensor dy(DT_FLOAT, {1, 2, 2}); + test::FillIota<float>(&dy, 1); + int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 0, + ellipsis_mask = 0; + auto dx = + StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask, + ellipsis_mask, new_axis_mask, shrink_axis_mask); + test::ExpectClose(dx[0], + test::AsTensor<float>( + { + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0., + }, + {2, 3, 4})); + test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0})); + test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0})); + } + + // test equivalent of python tf.gradients(foo[1:2, 1:3, 1:3]) + { + auto start = test::AsTensor<int32>({1, 1, 1}); + auto stop = test::AsTensor<int32>({2, 3, 3}); + auto strides = test::AsTensor<int32>({1, 1, 1}); + Tensor dy(DT_FLOAT, {1, 2, 2}); + test::FillIota<float>(&dy, 1); + int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 0, + ellipsis_mask = 0; + auto dx = + StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask, + ellipsis_mask, new_axis_mask, shrink_axis_mask); + test::ExpectClose(dx[0], + test::AsTensor<float>( + { + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0., + }, + {2, 3, 4})); + test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0})); + test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0})); + } + + // test equivalent of python tf.gradients(foo[1, 1:, :-2, None]) + { + int dontcare = 66; + auto start = test::AsTensor<int32>({1, 1, dontcare, dontcare}); + auto stop = test::AsTensor<int32>({2, dontcare, -2, dontcare}); + auto strides = test::AsTensor<int32>({1, 1, 1, dontcare}); + Tensor dy(DT_FLOAT, {2, 2, 1}); + test::FillIota<float>(&dy, 1); + int begin_mask = 4, end_mask = 2, new_axis_mask = 8, shrink_axis_mask = 1, + ellipsis_mask = 0; + auto dx = + StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask, + ellipsis_mask, new_axis_mask, shrink_axis_mask); + test::ExpectClose(dx[0], + test::AsTensor<float>( + { + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0., 0., + }, + {2, 3, 4})); + test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0, 0})); + test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0, 0})); + } + + // test equivalent of tf.gradients(foo[1, ...]) i.e. foo[1, 0:3, 0:4] + { + int dontcare = 66; + auto start = test::AsTensor<int32>({1, dontcare}); + auto stop = test::AsTensor<int32>({2, dontcare}); + auto strides = test::AsTensor<int32>({1, 1}); + Tensor dy(DT_FLOAT, {3, 4}); + test::FillIota<float>(&dy, 1); + int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 1, + ellipsis_mask = 2; + auto dx = + StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask, + ellipsis_mask, new_axis_mask, shrink_axis_mask); + test::ExpectClose(dx[0], + test::AsTensor<float>( + { + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + }, + {2, 3, 4})); + test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0})); + test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0})); + } +} + } // namespace tensorflow |