diff options
author | 2016-02-06 21:38:25 -0800 | |
---|---|---|
committer | 2016-02-06 21:54:43 -0800 | |
commit | 23eccfb17635bce1c19b668986dceae1281ccee8 (patch) | |
tree | f9471e8c27057feae3bd1b5ade62295fe4bad583 /tensorflow/core/ops/array_grad.cc | |
parent | aa976ef6a20865b41ab2521756d2af4d8ebb6d5a (diff) |
Adds c++ grad for ExpandDims and Transpose.
Change: 114043957
Diffstat (limited to 'tensorflow/core/ops/array_grad.cc')
-rw-r--r-- | tensorflow/core/ops/array_grad.cc | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index 45c1152a33..4bd50c2aef 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -47,6 +47,7 @@ Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) { return Status::OK(); } REGISTER_OP_GRADIENT("Reshape", ReshapeGrad); +REGISTER_OP_GRADIENT("ExpandDims", ReshapeGrad); Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off @@ -260,4 +261,23 @@ Status FillGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Fill", FillGrad); +Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) { + *g = FDH::Define( + // Arg defs + {"x: T", "p: int32", "dy: T"}, + // Ret val defs + {"dx: T", "dp: int32"}, + // Attr defs + {"T: {float, double}"}, + // Nodes + { + {{"q"}, "InvertPermutation", {"p"}, {}}, + {{"dx"}, "Transpose", {"dy", "q"}, {{"T", "$T"}}}, + {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}}, + }); + VLOG(1) << "TransposeGrad " << DebugString(*g); + return Status::OK(); +} +REGISTER_OP_GRADIENT("Transpose", TransposeGrad); + } // end namespace tensorflow |