diff options
author | 2018-08-22 11:10:49 -0700 | |
---|---|---|
committer | 2018-08-22 11:15:26 -0700 | |
commit | 7f3580dbbc99185be01b19265ba576d58df0ac25 (patch) | |
tree | 7fb402c4b236dd90cb6422c2457d1dfea07e43bc /tensorflow/core/ops/array_grad.cc | |
parent | a3e8e2928506ca7cd1ed25a01ae81cedfdb4b404 (diff) |
SymbolicGradient for gather_nd
PiperOrigin-RevId: 209795871
Diffstat (limited to 'tensorflow/core/ops/array_grad.cc')
-rw-r--r-- | tensorflow/core/ops/array_grad.cc | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index 1f2e57e9a9..3d03bc1d5f 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -354,6 +354,27 @@ Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Transpose", TransposeGrad); +Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"params: Tparams", "indices: Tindices", "doutput: Tparams"}, + // Ret val defs + {"dparams: Tparams", "dindices: Tindices"}, + // Attr defs + {"Tparams: type", "Tindices: type"}, + // Nodes + { + {{"x_shape"}, "Shape", {"params"}, {{"T", "$Tparams"}}}, + {{"dparams"}, "ScatterNd", {"indices", "doutput", "x_shape"}, + {{"T", "$Tparams"}, {"Tindices", "$Tindices"}}}, + {{"dindices"}, "ZerosLike", {"indices"}, {{"T", "$Tindices"}}}, + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad); + Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) { *g = FDH::Define( // Arg defs |