aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_grad.cc
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-08-22 11:10:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 11:15:26 -0700
commit7f3580dbbc99185be01b19265ba576d58df0ac25 (patch)
tree7fb402c4b236dd90cb6422c2457d1dfea07e43bc /tensorflow/core/ops/array_grad.cc
parenta3e8e2928506ca7cd1ed25a01ae81cedfdb4b404 (diff)
SymbolicGradient for gather_nd
PiperOrigin-RevId: 209795871
Diffstat (limited to 'tensorflow/core/ops/array_grad.cc')
-rw-r--r--tensorflow/core/ops/array_grad.cc21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 1f2e57e9a9..3d03bc1d5f 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -354,6 +354,27 @@ Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Transpose", TransposeGrad);
+Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"params: Tparams", "indices: Tindices", "doutput: Tparams"},
+ // Ret val defs
+ {"dparams: Tparams", "dindices: Tindices"},
+ // Attr defs
+ {"Tparams: type", "Tindices: type"},
+ // Nodes
+ {
+ {{"x_shape"}, "Shape", {"params"}, {{"T", "$Tparams"}}},
+ {{"dparams"}, "ScatterNd", {"indices", "doutput", "x_shape"},
+ {{"T", "$Tparams"}, {"Tindices", "$Tindices"}}},
+ {{"dindices"}, "ZerosLike", {"indices"}, {{"T", "$Tindices"}}},
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad);
+
Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
*g = FDH::Define(
// Arg defs