aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-09 20:11:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 20:15:28 -0700
commitde082e57fb7609c261af72f583e0d5c236023376 (patch)
tree0ac48de4042bf52caec8307c8bebf191ef692c06 /tensorflow
parent652be30e7140ca11b756a3ae0f9bd67f913af399 (diff)
Add c++ gradient for SplitV op.
PiperOrigin-RevId: 208153311
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/ops/array_grad.cc21
-rw-r--r--tensorflow/core/ops/array_grad_test.cc66
2 files changed, 78 insertions, 9 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 38bd851da8..1f2e57e9a9 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -244,6 +244,27 @@ Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Split", SplitGrad);
+Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "size_splits: Tlen", "dim: int32", "dy: num_split*T"},
+ // Ret val defs
+ {"dx: T", "d_size_splits: Tlen", "d_dim: int32"},
+ // Attr defs
+ {"T: type", "Tlen: type", "num_split: int"},
+ // Nodes
+ {
+ {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}},
+ {{"d_size_splits"}, "ZerosLike", {"size_splits"}, {{"T", "$Tlen"}}},
+ {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
+ });
+ // clang-format on
+ VLOG(1) << "SplitVGrad " << DebugString(*g);
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("SplitV", SplitVGrad);
+
Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) {
int N;
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N));
diff --git a/tensorflow/core/ops/array_grad_test.cc b/tensorflow/core/ops/array_grad_test.cc
index e665d17938..79d28a83cc 100644
--- a/tensorflow/core/ops/array_grad_test.cc
+++ b/tensorflow/core/ops/array_grad_test.cc
@@ -238,6 +238,39 @@ std::vector<Tensor> SplitGrad(int dim, const Tensor& x, const Tensor& dy0,
return out;
}
+std::vector<Tensor> SplitVGrad(const Tensor& x, const Tensor& size_splits,
+ int dim, const Tensor& dy0, const Tensor& dy1) {
+ auto T = DT_FLOAT;
+ auto Tlen = DT_INT64;
+ auto gdef = test::function::GDef(
+ {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("size_splits", "Placeholder", {}, {{"dtype", Tlen}}),
+ f::NDef("dim", "Placeholder", {}, {{"dtype", DT_INT32}}),
+ f::NDef("dy0", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("dy1", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("dx", "SymbolicGradient",
+ {"x", "size_splits", "dim", "dy0", "dy1"},
+ {{"f", FDH::FunctionRef("SplitV", {{"split_dim", dim},
+ {"num_split", 2},
+ {"T", T},
+ {"Tlen", Tlen}})},
+ {"Tin", DataTypeSlice{T, Tlen, DT_INT32, T, T}},
+ {"Tout", DataTypeSlice{T, Tlen, DT_INT32}}})});
+ VLOG(1) << DebugStringWhole(gdef);
+ auto sess = NewSession();
+ TF_CHECK_OK(sess->Create(gdef));
+ std::vector<Tensor> out;
+ TF_CHECK_OK(sess->Run({{"x:0", x},
+ {"size_splits:0", size_splits},
+ {"dim", test::AsScalar(dim)},
+ {"dy0:0", dy0},
+ {"dy1:0", dy1}},
+ {"dx:0", "dx:1", "dx:2"}, {}, &out));
+ CHECK_EQ(out.size(), 3);
+ TF_CHECK_OK(sess->Close());
+ return out;
+}
+
TEST(ArrayGradTest, SplitGrad) {
Tensor x(DT_FLOAT, {2, 4, 5});
x.flat<float>().setZero();
@@ -245,15 +278,30 @@ TEST(ArrayGradTest, SplitGrad) {
Tensor dy1(DT_FLOAT, {2, 2, 5});
test::FillIota<float>(&dy0, 0);
test::FillIota<float>(&dy1, 100);
- auto dx = SplitGrad(1, x, dy0, dy1);
- test::ExpectTensorEqual<int32>(dx[0], test::AsScalar(0));
- test::ExpectClose(
- dx[1], test::AsTensor<float>(
- {0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
- 100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,
- 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
- 110., 111., 112., 113., 114., 115., 116., 117., 118., 119.},
- {2, 4, 5}));
+ auto expected_dx = test::AsTensor<float>(
+ {0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
+ 100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,
+ 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
+ 110., 111., 112., 113., 114., 115., 116., 117., 118., 119.},
+ {2, 4, 5});
+ auto expected_d_dim = test::AsScalar(0);
+
+ // SplitGrad
+ {
+ auto dx = SplitGrad(1, x, dy0, dy1);
+ test::ExpectTensorEqual<int32>(dx[0], expected_d_dim);
+ test::ExpectClose(dx[1], expected_dx);
+ }
+ // SplitVGrad
+ {
+ Tensor size_splits(DT_INT64, {2});
+ size_splits.flat<int64>().setConstant(2);
+ auto expected_d_size_splits = test::AsTensor<int64>({0, 0}, {2});
+ auto dx = SplitVGrad(x, size_splits, 1, dy0, dy1);
+ test::ExpectClose(dx[0], expected_dx);
+ test::ExpectTensorEqual<int64>(dx[1], expected_d_size_splits);
+ test::ExpectTensorEqual<int32>(dx[2], expected_d_dim);
+ }
}
std::vector<Tensor> ReshapeGrad(const Tensor& x, const Tensor& s,