about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-02-06 21:38:25 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-02-06 21:54:43 -0800
commit 23eccfb17635bce1c19b668986dceae1281ccee8 (patch)
tree f9471e8c27057feae3bd1b5ade62295fe4bad583
parent aa976ef6a20865b41ab2521756d2af4d8ebb6d5a (diff)
Adds c++ grad for ExpandDims and Transpose.
Change: 114043957
-rw-r--r-- tensorflow/core/ops/array_grad.cc      20
-rw-r--r-- tensorflow/core/ops/array_grad_test.cc 119
2 files changed, 139 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 45c1152a33..4bd50c2aef 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -47,6 +47,7 @@ Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) {
return Status::OK();
}
REGISTER_OP_GRADIENT("Reshape", ReshapeGrad);
+REGISTER_OP_GRADIENT("ExpandDims", ReshapeGrad);
Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
@@ -260,4 +261,23 @@ Status FillGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Fill", FillGrad);
+Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "p: int32", "dy: T"},
+ // Ret val defs
+ {"dx: T", "dp: int32"},
+ // Attr defs
+ {"T: {float, double}"},
+ // Nodes
+ {
+ {{"q"}, "InvertPermutation", {"p"}, {}},
+ {{"dx"}, "Transpose", {"dy", "q"}, {{"T", "$T"}}},
+ {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}},
+ });
+ VLOG(1) << "TransposeGrad " << DebugString(*g);
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("Transpose", TransposeGrad);
+
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/array_grad_test.cc b/tensorflow/core/ops/array_grad_test.cc
index 9f6ff6f840..401286ace1 100644
--- a/tensorflow/core/ops/array_grad_test.cc
+++ b/tensorflow/core/ops/array_grad_test.cc
@@ -197,4 +197,123 @@ TEST_F(ArrayGradTest, SplitGrad) {
{2, 4, 5}));
}
+std::vector<Tensor> ReshapeGrad(const Tensor& x, const Tensor& s,
+ const Tensor& dy) {
+ auto T = DT_FLOAT;
+ auto gdef = test::function::GDef(
+ {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("s", "Placeholder", {}, {{"dtype", DT_INT32}}),
+ f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("dx", "SymbolicGradient", {"x", "s", "dy"},
+ {{"f", FDH::FunctionRef("Reshape", {{"T", T}})},
+ {"Tin", DataTypeSlice{T, DT_INT32, T}},
+ {"Tout", DataTypeSlice{T, DT_INT32}}})});
+ VLOG(1) << DebugStringWhole(gdef);
+ auto sess = NewSession();
+ TF_CHECK_OK(sess->Create(gdef));
+ std::vector<Tensor> out;
+ TF_CHECK_OK(sess->Run({{"x:0", x}, {"s:0", s}, {"dy:0", dy}},
+ {"dx:0", "dx:1"}, {}, &out));
+ CHECK_EQ(out.size(), 2);
+ TF_CHECK_OK(sess->Close());
+ delete sess;
+ return out;
+}
+
+TEST_F(ArrayGradTest, ReshapeGrad) {
+ Tensor x(DT_FLOAT, {2, 4, 5});
+ x.flat<float>().setZero();
+ auto s = test::AsTensor<int32>({8, 5});
+ Tensor dy(DT_FLOAT, {8, 5});
+ test::FillIota<float>(&dy, 73);
+ auto dx = ReshapeGrad(x, s, dy);
+ test::ExpectClose(
+ dx[0], test::AsTensor<float>(
+ {73., 74., 75., 76., 77., 78., 79., 80., 81., 82.,
+ 83., 84., 85., 86., 87., 88., 89., 90., 91., 92.,
+ 93., 94., 95., 96., 97., 98., 99., 100., 101., 102.,
+ 103., 104., 105., 106., 107., 108., 109., 110., 111., 112.},
+ {2, 4, 5}));
+ test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0}));
+}
+
+std::vector<Tensor> ExpandDimsGrad(const Tensor& x, const Tensor& s,
+ const Tensor& dy) {
+ auto T = DT_FLOAT;
+ auto gdef = test::function::GDef(
+ {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("s", "Placeholder", {}, {{"dtype", DT_INT32}}),
+ f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("dx", "SymbolicGradient", {"x", "s", "dy"},
+ {{"f", FDH::FunctionRef("ExpandDims", {{"T", T}})},
+ {"Tin", DataTypeSlice{T, DT_INT32, T}},
+ {"Tout", DataTypeSlice{T, DT_INT32}}})});
+ VLOG(1) << DebugStringWhole(gdef);
+ auto sess = NewSession();
+ TF_CHECK_OK(sess->Create(gdef));
+ std::vector<Tensor> out;
+ TF_CHECK_OK(sess->Run({{"x:0", x}, {"s:0", s}, {"dy:0", dy}},
+ {"dx:0", "dx:1"}, {}, &out));
+ CHECK_EQ(out.size(), 2);
+ TF_CHECK_OK(sess->Close());
+ delete sess;
+ return out;
+}
+
+TEST_F(ArrayGradTest, ExpandDimsGrad) {
+ Tensor x(DT_FLOAT, {2, 4, 5});
+ x.flat<float>().setZero();
+ auto s = test::AsTensor<int32>({1});
+ Tensor dy(DT_FLOAT, {2, 1, 4, 5});
+ test::FillIota<float>(&dy, 73);
+ auto dx = ExpandDimsGrad(x, s, dy);
+ test::ExpectClose(
+ dx[0], test::AsTensor<float>(
+ {73., 74., 75., 76., 77., 78., 79., 80., 81., 82.,
+ 83., 84., 85., 86., 87., 88., 89., 90., 91., 92.,
+ 93., 94., 95., 96., 97., 98., 99., 100., 101., 102.,
+ 103., 104., 105., 106., 107., 108., 109., 110., 111., 112.},
+ {2, 4, 5}));
+ test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0}));
+}
+
+std::vector<Tensor> TransposeGrad(const Tensor& x, const Tensor& p,
+ const Tensor& dy) {
+ auto T = DT_FLOAT;
+ auto gdef = test::function::GDef(
+ {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("p", "Placeholder", {}, {{"dtype", DT_INT32}}),
+ f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("dx", "SymbolicGradient", {"x", "p", "dy"},
+ {{"f", FDH::FunctionRef("Transpose", {{"T", T}})},
+ {"Tin", DataTypeSlice{T, DT_INT32, T}},
+ {"Tout", DataTypeSlice{T, DT_INT32}}})});
+ VLOG(1) << DebugStringWhole(gdef);
+ auto sess = NewSession();
+ TF_CHECK_OK(sess->Create(gdef));
+ std::vector<Tensor> out;
+ TF_CHECK_OK(sess->Run({{"x:0", x}, {"p:0", p}, {"dy:0", dy}},
+ {"dx:0", "dx:1"}, {}, &out));
+ CHECK_EQ(out.size(), 2);
+ TF_CHECK_OK(sess->Close());
+ delete sess;
+ return out;
+}
+
+TEST_F(ArrayGradTest, TransposeGrad) {
+ Tensor x(DT_FLOAT, {2, 4, 5});
+ x.flat<float>().setZero();
+ auto p = test::AsTensor<int32>({2, 0, 1});
+ Tensor dy(DT_FLOAT, {5, 2, 4});
+ test::FillIota<float>(&dy, 0);
+ auto dx = TransposeGrad(x, p, dy);
+ test::ExpectClose(dx[0], test::AsTensor<float>(
+ {0., 8., 16., 24., 32., 1., 9., 17., 25., 33.,
+ 2., 10., 18., 26., 34., 3., 11., 19., 27., 35.,
+ 4., 12., 20., 28., 36., 5., 13., 21., 29., 37.,
+ 6., 14., 22., 30., 38., 7., 15., 23., 31., 39.},
+ {2, 4, 5}));
+ test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
+}
+
} // namespace tensorflow