aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_grad.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-02-06 21:38:25 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-02-06 21:54:43 -0800
commit23eccfb17635bce1c19b668986dceae1281ccee8 (patch)
treef9471e8c27057feae3bd1b5ade62295fe4bad583 /tensorflow/core/ops/array_grad.cc
parentaa976ef6a20865b41ab2521756d2af4d8ebb6d5a (diff)
Adds c++ grad for ExpandDims and Transpose.
Change: 114043957
Diffstat (limited to 'tensorflow/core/ops/array_grad.cc')
-rw-r--r--tensorflow/core/ops/array_grad.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 45c1152a33..4bd50c2aef 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -47,6 +47,7 @@ Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) {
return Status::OK();
}
REGISTER_OP_GRADIENT("Reshape", ReshapeGrad);
+REGISTER_OP_GRADIENT("ExpandDims", ReshapeGrad);
Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
@@ -260,4 +261,23 @@ Status FillGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Fill", FillGrad);
+Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "p: int32", "dy: T"},
+ // Ret val defs
+ {"dx: T", "dp: int32"},
+ // Attr defs
+ {"T: {float, double}"},
+ // Nodes
+ {
+ {{"q"}, "InvertPermutation", {"p"}, {}},
+ {{"dx"}, "Transpose", {"dy", "q"}, {{"T", "$T"}}},
+ {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}},
+ });
+ VLOG(1) << "TransposeGrad " << DebugString(*g);
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("Transpose", TransposeGrad);
+
} // end namespace tensorflow