diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-09 20:11:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-09 20:15:28 -0700 |
commit | de082e57fb7609c261af72f583e0d5c236023376 (patch) | |
tree | 0ac48de4042bf52caec8307c8bebf191ef692c06 /tensorflow | |
parent | 652be30e7140ca11b756a3ae0f9bd67f913af399 (diff) |
Add c++ gradient for SplitV op.
PiperOrigin-RevId: 208153311
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/ops/array_grad.cc | 21 | ||||
-rw-r--r-- | tensorflow/core/ops/array_grad_test.cc | 66 |
2 files changed, 78 insertions, 9 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index 38bd851da8..1f2e57e9a9 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -244,6 +244,27 @@ Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Split", SplitGrad); +Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"x: T", "size_splits: Tlen", "dim: int32", "dy: num_split*T"}, + // Ret val defs + {"dx: T", "d_size_splits: Tlen", "d_dim: int32"}, + // Attr defs + {"T: type", "Tlen: type", "num_split: int"}, + // Nodes + { + {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}}, + {{"d_size_splits"}, "ZerosLike", {"size_splits"}, {{"T", "$Tlen"}}}, + {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}}, + }); + // clang-format on + VLOG(1) << "SplitVGrad " << DebugString(*g); + return Status::OK(); +} +REGISTER_OP_GRADIENT("SplitV", SplitVGrad); + Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) { int N; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N)); diff --git a/tensorflow/core/ops/array_grad_test.cc b/tensorflow/core/ops/array_grad_test.cc index e665d17938..79d28a83cc 100644 --- a/tensorflow/core/ops/array_grad_test.cc +++ b/tensorflow/core/ops/array_grad_test.cc @@ -238,6 +238,39 @@ std::vector<Tensor> SplitGrad(int dim, const Tensor& x, const Tensor& dy0, return out; } +std::vector<Tensor> SplitVGrad(const Tensor& x, const Tensor& size_splits, + int dim, const Tensor& dy0, const Tensor& dy1) { + auto T = DT_FLOAT; + auto Tlen = DT_INT64; + auto gdef = test::function::GDef( + {f::NDef("x", "Placeholder", {}, {{"dtype", T}}), + f::NDef("size_splits", "Placeholder", {}, {{"dtype", Tlen}}), + f::NDef("dim", "Placeholder", {}, {{"dtype", DT_INT32}}), + f::NDef("dy0", "Placeholder", {}, {{"dtype", T}}), + f::NDef("dy1", "Placeholder", {}, {{"dtype", T}}), + f::NDef("dx", "SymbolicGradient", + {"x", "size_splits", "dim", "dy0", "dy1"}, + {{"f", FDH::FunctionRef("SplitV", {{"split_dim", dim}, + {"num_split", 2}, + {"T", T}, + {"Tlen", Tlen}})}, + {"Tin", DataTypeSlice{T, Tlen, DT_INT32, T, T}}, + {"Tout", DataTypeSlice{T, Tlen, DT_INT32}}})}); + VLOG(1) << DebugStringWhole(gdef); + auto sess = NewSession(); + TF_CHECK_OK(sess->Create(gdef)); + std::vector<Tensor> out; + TF_CHECK_OK(sess->Run({{"x:0", x}, + {"size_splits:0", size_splits}, + {"dim", test::AsScalar(dim)}, + {"dy0:0", dy0}, + {"dy1:0", dy1}}, + {"dx:0", "dx:1", "dx:2"}, {}, &out)); + CHECK_EQ(out.size(), 3); + TF_CHECK_OK(sess->Close()); + return out; +} + TEST(ArrayGradTest, SplitGrad) { Tensor x(DT_FLOAT, {2, 4, 5}); x.flat<float>().setZero(); @@ -245,15 +278,30 @@ TEST(ArrayGradTest, SplitGrad) { Tensor dy1(DT_FLOAT, {2, 2, 5}); test::FillIota<float>(&dy0, 0); test::FillIota<float>(&dy1, 100); - auto dx = SplitGrad(1, x, dy0, dy1); - test::ExpectTensorEqual<int32>(dx[0], test::AsScalar(0)); - test::ExpectClose( - dx[1], test::AsTensor<float>( - {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., - 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., - 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., - 110., 111., 112., 113., 114., 115., 116., 117., 118., 119.}, - {2, 4, 5})); + auto expected_dx = test::AsTensor<float>( + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., + 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., + 110., 111., 112., 113., 114., 115., 116., 117., 118., 119.}, + {2, 4, 5}); + auto expected_d_dim = test::AsScalar(0); + + // SplitGrad + { + auto dx = SplitGrad(1, x, dy0, dy1); + test::ExpectTensorEqual<int32>(dx[0], expected_d_dim); + test::ExpectClose(dx[1], expected_dx); + } + // SplitVGrad + { + Tensor size_splits(DT_INT64, {2}); + size_splits.flat<int64>().setConstant(2); + auto expected_d_size_splits = test::AsTensor<int64>({0, 0}, {2}); + auto dx = SplitVGrad(x, size_splits, 1, dy0, dy1); + test::ExpectClose(dx[0], expected_dx); + test::ExpectTensorEqual<int64>(dx[1], expected_d_size_splits); + test::ExpectTensorEqual<int32>(dx[2], expected_d_dim); + } } std::vector<Tensor> ReshapeGrad(const Tensor& x, const Tensor& s, |